Skip to content

Commit 375cbd0

Browse files
committed
Add LoadWanVideoClipTextEncoder and WanVideoClipVisionEncode for multi-GPU support
1 parent 2f605e0 commit 375cbd0

2 files changed

Lines changed: 84 additions & 0 deletions

File tree

__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def text_encoder_device_patched():
224224
WanVideoSampler,
225225
WanVideoVACEEncode,
226226
WanVideoEncode,
227+
LoadWanVideoClipTextEncoder,
228+
WanVideoClipVisionEncode,
227229
)
228230

229231
from .wrappers import (
@@ -373,6 +375,8 @@ def register_and_count(module_names, node_map):
373375
"WanVideoSamplerMultiGPU": WanVideoSampler,
374376
"WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode,
375377
"WanVideoEncodeMultiGPU": WanVideoEncode,
378+
"LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
379+
"WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
376380
}
377381
register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes)
378382

wanvideo.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,44 @@ def loadmodel(self, model_name, precision, device=None, quantization="disabled")
261261

262262
return text_encoder, device
263263

264+
class LoadWanVideoClipTextEncoder:
    """Load the Wan CLIP text/vision encoder onto a user-selected device.

    MultiGPU wrapper around the WanVideoWrapper loader node: records the
    chosen device, maps it to the wrapped loader's symbolic device keyword,
    and delegates the actual model loading to the original node.
    """

    @classmethod
    def INPUT_TYPES(s):
        devices = get_device_list()
        # Prefer the second entry (typically the first GPU) when more
        # than one device is available; otherwise fall back to the only one.
        default_device = devices[1] if len(devices) > 1 else devices[0]
        model_choices = (
            folder_paths.get_filename_list("clip_vision")
            + folder_paths.get_filename_list("text_encoders")
        )
        return {
            "required": {
                "model_name": (model_choices, {"tooltip": "These models are loaded from 'ComfyUI/models/clip_vision'"}),
                "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}),
            },
            "optional": {
                "device": (devices, {"default": default_device}),
            },
        }

    RETURN_TYPES = ("CLIP_VISION", "MULTIGPUDEVICE")
    RETURN_NAMES = ("wan_clip_vision", "load_device")
    FUNCTION = "loadmodel"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "Loads Wan clip_vision model from 'ComfyUI/models/clip_vision'"

    def loadmodel(self, model_name, precision, device=None):
        from . import set_current_device

        set_current_device(device)

        # The wrapped loader only understands symbolic device names:
        # "cpu" maps to its offload device, anything else to the main one.
        load_device = "offload_device" if device == "cpu" else "main_device"

        wrapped_loader = NODE_CLASS_MAPPINGS["LoadWanVideoClipTextEncoder"]()
        # NOTE(review): the wrapped loader appears to return a tuple, which
        # is passed through as-is and unpacked downstream (see the
        # clip_vision[0] access in WanVideoClipVisionEncode) — confirm.
        clip_model = wrapped_loader.loadmodel(model_name, precision, load_device)

        return clip_model, device
301+
264302
class WanVideoTextEncodeCached:
265303
@classmethod
266304
def INPUT_TYPES(s):
@@ -632,3 +670,45 @@ def encode(self, vae, load_device, image, enable_vae_tiling, tile_x, tile_y, til
632670
return original_encode.encode(vae[0], image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength, latent_strength, mask)
633671
finally:
634672
encode_module.device = original_module_device
673+
674+
class WanVideoClipVisionEncode:
    """Encode image(s) with the Wan CLIP vision model on a chosen device.

    MultiGPU wrapper: temporarily points the wrapped node module's global
    ``device`` at the selected compute device, runs the original encode
    node, and always restores the previous value afterwards.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_vision": ("CLIP_VISION",),
                "load_device": ("MULTIGPUDEVICE",),
                "image_1": ("IMAGE", {"tooltip": "Image to encode"}),
                "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}),
                "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}),
                "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}),
                "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}),
                "force_offload": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "image_2": ("IMAGE", ),
                "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}),
                "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}),
                "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "multigpu/WanVideoWrapper"

    # NOTE(review): the Python default for ``ratio`` (1.0) differs from the
    # INPUT_TYPES default (0.5). The UI always supplies the value, so this
    # only affects direct callers — confirm which default is intended.
    def process(self, clip_vision, load_device, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0):
        from . import set_current_device

        wrapped_node = NODE_CLASS_MAPPINGS["WanVideoClipVisionEncode"]()
        host_module = inspect.getmodule(wrapped_node)
        saved_device = host_module.device

        set_current_device(load_device)
        # Patch the wrapped module's global device so the encode runs on
        # the device selected through this node.
        host_module.device = mm.get_torch_device()

        try:
            # clip_vision arrives as a tuple from the MultiGPU loader; the
            # original node expects the bare model, hence the [0] unwrap.
            return wrapped_node.process(clip_vision[0], image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2, negative_image, tiles, ratio)
        finally:
            # Always restore the module's original device, even on failure.
            host_module.device = saved_device

0 commit comments

Comments
 (0)