Skip to content

Commit 582ca6a

Browse files
committed
WanVideoWrapper MultiGPU integration - custom wrapper nodes
- Created custom implementations for all WanVideo nodes with explicit device selection - Added WanVideoBlockSwap with dual device control (swap_device and model_offload_device) - Created WanVideoModelLoader_TWO for multi-model workflows to avoid race conditions - Discovered core ComfyUI bug: safetensors loader ignores device index (uses device.type instead of str(device)) - All wrapper nodes use runtime module patching to override WanVideoWrapper's cached device variables - Extensive logging added for debugging device assignments
1 parent a05823f commit 582ca6a

4 files changed

Lines changed: 4361 additions & 26 deletions

File tree

__init__.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
MMAudioModelLoader, MMAudioFeatureUtilsLoader, MMAudioSampler,
2929
PulidModelLoader, PulidInsightFaceLoader, PulidEvaClipLoader,
3030
HyVideoModelLoader, HyVideoVAELoader, DownloadAndLoadHyVideoTextEncoder,
31-
WanVideoModelLoader, WanVideoVAELoader, LoadWanVideoT5TextEncoder
31+
WanVideoModelLoader, WanVideoVAELoader, LoadWanVideoT5TextEncoder, LoadWanVideoClipTextEncoder,
32+
WanVideoTextEncode, WanVideoBlockSwap, WanVideoModelLoader_TWO
3233
)
3334

3435
current_device = mm.get_torch_device()
@@ -41,6 +42,7 @@ def get_torch_device_patched():
4142
device = torch.device("cpu")
4243
else:
4344
device = torch.device(current_device)
45+
logging.info(f"[MultiGPU get_torch_device_patched] Returning device: {device} (current_device={current_device})")
4446
return device
4547

4648
def text_encoder_device_patched():
@@ -49,10 +51,15 @@ def text_encoder_device_patched():
4951
device = torch.device("cpu")
5052
else:
5153
device = torch.device(current_text_encoder_device)
54+
logging.info(f"[MultiGPU text_encoder_device_patched] Returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
5255
return device
5356

57+
logging.info(f"[MultiGPU] Patching mm.get_torch_device and mm.text_encoder_device")
58+
logging.info(f"[MultiGPU] Initial current_device: {current_device}")
59+
logging.info(f"[MultiGPU] Initial current_text_encoder_device: {current_text_encoder_device}")
5460
mm.get_torch_device = get_torch_device_patched
5561
mm.text_encoder_device = text_encoder_device_patched
62+
logging.info(f"[MultiGPU] Patches applied successfully")
5663

5764

5865
def create_model_hash(model, caller):
@@ -528,10 +535,18 @@ def INPUT_TYPES(s):
528535

529536
def override(self, *args, device=None, **kwargs):
530537
global current_device
538+
539+
logging.info(f"[MultiGPU override_class] Called with device={device}, current_device={current_device}")
540+
531541
if device is not None:
532542
current_device = device
543+
logging.info(f"[MultiGPU override_class] Setting current_device to {device}")
544+
533545
fn = getattr(super(), cls.FUNCTION)
546+
logging.info(f"[MultiGPU override_class] Calling wrapped function: {cls.__name__}.{cls.FUNCTION}")
534547
out = fn(*args, **kwargs)
548+
logging.info(f"[MultiGPU override_class] Wrapped function completed successfully")
549+
535550
return out
536551

537552
return NodeOverride
@@ -552,10 +567,13 @@ def INPUT_TYPES(s):
552567

553568
def override(self, *args, device=None, **kwargs):
554569
global current_text_encoder_device
570+
555571
if device is not None:
556572
current_text_encoder_device = device
573+
557574
fn = getattr(super(), cls.FUNCTION)
558575
out = fn(*args, **kwargs)
576+
559577
return out
560578

561579
return NodeOverride
@@ -741,9 +759,14 @@ def check_module_exists(module_path):
741759
NODE_CLASS_MAPPINGS["DownloadAndLoadHyVideoTextEncoderMultiGPU"] = override_class(DownloadAndLoadHyVideoTextEncoder)
742760

743761
if check_module_exists("ComfyUI-WanVideoWrapper") or check_module_exists("comfyui-wanvideowrapper"):
744-
NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU"] = override_class(WanVideoModelLoader)
745-
NODE_CLASS_MAPPINGS["WanVideoVAELoaderMultiGPU"] = override_class(WanVideoVAELoader)
746-
NODE_CLASS_MAPPINGS["LoadWanVideoT5TextEncoderMultiGPU"] = override_class(LoadWanVideoT5TextEncoder)
762+
# WanVideo uses custom implementation, not the standard override
763+
NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU"] = WanVideoModelLoader
764+
NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU_TWO"] = WanVideoModelLoader_TWO
765+
NODE_CLASS_MAPPINGS["WanVideoVAELoaderMultiGPU"] = WanVideoVAELoader
766+
NODE_CLASS_MAPPINGS["LoadWanVideoT5TextEncoderMultiGPU"] = LoadWanVideoT5TextEncoder
767+
NODE_CLASS_MAPPINGS["LoadWanVideoClipTextEncoderMultiGPU"] = LoadWanVideoClipTextEncoder
768+
NODE_CLASS_MAPPINGS["WanVideoTextEncodeMultiGPU"] = WanVideoTextEncode
769+
NODE_CLASS_MAPPINGS["WanVideoBlockSwapMultiGPU"] = WanVideoBlockSwap
747770

748771

749772
logging.info(f"MultiGPU: Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")

0 commit comments

Comments
 (0)