Commit 657fdac

Fix WanVideo multi-GPU device mismatch issue
Problem: WanVideoWrapper caches its device at module load time, so timesteps and tensors are created on the wrong device when a workflow loops between models on different GPUs.

Solution: the WanVideoSamplerMultiGPU wrapper updates the module-level device variable to match the current model's device before sampling.

Changes:
- Added comprehensive logging to trace device allocation through the pipeline
- Identified module-level device caching as the root cause
- Simplified WanVideoSamplerMultiGPU to only update the device variable
- Verified the fix works for multi-model workflows with looping
1 parent 582ca6a commit 657fdac
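
The root cause described above reduces to a standard Python pitfall: a module-level variable is bound once at import time, and every later call reuses whatever it holds. A minimal sketch of the failure mode and of the repatch the sampler wrapper applies; the FakeWanVideoWrapper module below is an illustrative stand-in, not the real package, and the final assert assumes a machine with at least two CUDA devices:

import sys
import types

import torch

# Stand-in for a wrapper package that caches its device at import time.
fake = types.ModuleType("FakeWanVideoWrapper.nodes")
fake.device = torch.device("cuda:0")  # frozen when the module first loads
sys.modules[fake.__name__] = fake

def make_timesteps(steps: int) -> torch.Tensor:
    # Reads the module-level variable on every call, not the model's device.
    return torch.linspace(1.0, 0.0, steps, device=fake.device)

# A second model loaded on cuda:1 would receive cuda:0 timesteps and hit a
# device mismatch. The fix: before sampling, point every matching module at
# the current model's device, as WanVideoSamplerMultiGPU does for modules
# whose name contains 'WanVideoWrapper'.
model_device = torch.device("cuda:1")
for name in list(sys.modules):
    mod = sys.modules[name]
    if "FakeWanVideoWrapper" in name and hasattr(mod, "device"):
        mod.device = model_device

assert make_timesteps(10).device == model_device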

2 files changed

Lines changed: 50 additions & 112 deletions

__init__.py

Lines changed: 4 additions & 3 deletions
@@ -28,8 +28,8 @@
     MMAudioModelLoader, MMAudioFeatureUtilsLoader, MMAudioSampler,
     PulidModelLoader, PulidInsightFaceLoader, PulidEvaClipLoader,
     HyVideoModelLoader, HyVideoVAELoader, DownloadAndLoadHyVideoTextEncoder,
-    WanVideoModelLoader, WanVideoVAELoader, LoadWanVideoT5TextEncoder, LoadWanVideoClipTextEncoder,
-    WanVideoTextEncode, WanVideoBlockSwap, WanVideoModelLoader_TWO
+    WanVideoModelLoader, WanVideoModelLoader_2, WanVideoVAELoader, LoadWanVideoT5TextEncoder, LoadWanVideoClipTextEncoder,
+    WanVideoTextEncode, WanVideoBlockSwap, WanVideoSampler
 )
 
 current_device = mm.get_torch_device()
@@ -761,12 +761,13 @@ def check_module_exists(module_path):
 if check_module_exists("ComfyUI-WanVideoWrapper") or check_module_exists("comfyui-wanvideowrapper"):
     # WanVideo uses custom implementation, not the standard override
     NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU"] = WanVideoModelLoader
-    NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU_TWO"] = WanVideoModelLoader_TWO
+    NODE_CLASS_MAPPINGS["WanVideoModelLoaderMultiGPU_2"] = WanVideoModelLoader_2
     NODE_CLASS_MAPPINGS["WanVideoVAELoaderMultiGPU"] = WanVideoVAELoader
     NODE_CLASS_MAPPINGS["LoadWanVideoT5TextEncoderMultiGPU"] = LoadWanVideoT5TextEncoder
     NODE_CLASS_MAPPINGS["LoadWanVideoClipTextEncoderMultiGPU"] = LoadWanVideoClipTextEncoder
     NODE_CLASS_MAPPINGS["WanVideoTextEncodeMultiGPU"] = WanVideoTextEncode
     NODE_CLASS_MAPPINGS["WanVideoBlockSwapMultiGPU"] = WanVideoBlockSwap
+    NODE_CLASS_MAPPINGS["WanVideoSamplerMultiGPU"] = WanVideoSampler
 
 
 logging.info(f"MultiGPU: Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")

nodes.py

Lines changed: 46 additions & 109 deletions
@@ -882,125 +882,62 @@ def loadmodel(self, model_name, precision, device):
         return original_loader.loadmodel(model_name, precision, load_device)
 
 
-class WanVideoModelLoader_TWO:
-    """Second instance of WanVideoModelLoader for multi-model workflows to avoid race conditions"""
+
+class WanVideoModelLoader_2:
+    """Second instance for multi-model workflows to maintain separate device patches"""
     @classmethod
     def INPUT_TYPES(s):
-        # Exact same inputs as WanVideoModelLoader
-        from . import get_device_list
-        devices = get_device_list()
-
-        return {
-            "required": {
-                "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"),
-                    {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' folder",}),
-                "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
-                "quantization": (
-                    ["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"],
-                    {"default": "disabled", "tooltip": "optional quantization method"}
-                ),
-                "device": (devices, {"default": devices[1] if len(devices) > 1 else devices[0], "tooltip": "Device to load the model to"}),
-            },
-            "optional": {
-                "attention_mode": ([
-                    "sdpa",
-                    "flash_attn_2",
-                    "flash_attn_3",
-                    "sageattn",
-                    "sageattn_3",
-                    "flex_attention",
-                    "radial_sage_attention",
-                ], {"default": "sdpa"}),
-                "compile_args": ("WANCOMPILEARGS", ),
-                "block_swap_args": ("BLOCKSWAPARGS", ),
-                "lora": ("WANVIDLORA", {"default": None}),
-                "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}),
-                "vace_model": ("VACEPATH", {"default": None, "tooltip": "VACE model to use when not using model that has it included"}),
-                "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}),
-                "multitalk_model": ("MULTITALKMODEL", {"default": None, "tooltip": "Multitalk model"}),
-            }
-        }
-    RETURN_TYPES = ("WANVIDEOMODEL",)
-    RETURN_NAMES = ("model", )
+        # Delegate to the primary loader
+        return WanVideoModelLoader.INPUT_TYPES()
+
+    RETURN_TYPES = WanVideoModelLoader.RETURN_TYPES
+    RETURN_NAMES = WanVideoModelLoader.RETURN_NAMES
     FUNCTION = "loadmodel"
     CATEGORY = "WanVideoWrapper"
-    DESCRIPTION = "Second model loader for multi-model workflows - avoids race conditions when loading multiple models"
-
+    DESCRIPTION = "Second model loader instance for workflows using multiple models on different devices"
+
     def loadmodel(self, model, base_precision, device, quantization,
-                  compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, vram_management_args=None, vace_model=None, fantasytalking_model=None, multitalk_model=None):
-        # Just call the first loader's implementation directly
+                  compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None,
+                  vram_management_args=None, vace_model=None, fantasytalking_model=None, multitalk_model=None):
+        # Just use the first loader's implementation
+        loader = WanVideoModelLoader()
+        return loader.loadmodel(model, base_precision, device, quantization,
+                                compile_args, attention_mode, block_swap_args, lora,
+                                vram_management_args, vace_model, fantasytalking_model, multitalk_model)
+
+
+class WanVideoSampler:
+    """Wrapper that ensures correct device patching before sampling"""
+    @classmethod
+    def INPUT_TYPES(s):
+        # Get original sampler's inputs
+        from nodes import NODE_CLASS_MAPPINGS
+        original_types = NODE_CLASS_MAPPINGS["WanVideoSampler"].INPUT_TYPES()
+        return original_types
+
+    RETURN_TYPES = ("LATENT", "LATENT",)
+    RETURN_NAMES = ("samples", "denoised_samples",)
+    FUNCTION = "process"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model"
+
+    def process(self, model, **kwargs):
         import logging
-        import comfy.model_management as mm
-        import torch
-
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] ========== CUSTOM IMPLEMENTATION ==========")
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] User selected device: {device}")
+        import sys
 
-        # Convert device string to torch device
-        selected_device = torch.device(device)
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Torch device: {selected_device}")
+        # Get the model's device and update WanVideo modules to match
+        model_device = model.load_device
+        logging.info(f"[MultiGPU WanVideoSampler] Model device: {model_device}")
 
-        # Determine load_device parameter for original loader
-        # If user selected CPU, use "offload_device", otherwise use "main_device"
-        load_device = "offload_device" if device == "cpu" else "main_device"
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Mapped to load_device: {load_device}")
+        # Update the device variable in WanVideo modules
+        for module_name in sys.modules.keys():
+            if 'WanVideoWrapper' in module_name and hasattr(sys.modules[module_name], 'device'):
+                sys.modules[module_name].device = model_device
 
+        # Call original sampler
         from nodes import NODE_CLASS_MAPPINGS
-        original_loader = NODE_CLASS_MAPPINGS["WanVideoModelLoader"]()
-
-        # Patch BOTH WanVideo modules with the selected device
-        import sys
-        import inspect
-        loader_module = inspect.getmodule(original_loader)
-
-        if loader_module:
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Patching WanVideo modules to use {selected_device}")
-
-            # Save original devices
-            original_device = getattr(loader_module, 'device', None)
-            original_offload = getattr(loader_module, 'offload_device', None)
-
-            # Check if there's a model offload device override (from block swap config)
-            model_offload_override = getattr(loader_module, '_model_offload_device_override', None)
-
-            # Patch nodes_model_loading.py module
-            setattr(loader_module, 'device', selected_device)
-            if model_offload_override:
-                # Use the model offload override for offload_device
-                setattr(loader_module, 'offload_device', model_offload_override)
-                logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Using model offload override: {model_offload_override}")
-            elif device == "cpu":
-                setattr(loader_module, 'offload_device', selected_device)
-
-            # Patch nodes.py module as well
-            nodes_module_name = loader_module.__name__.replace('.nodes_model_loading', '.nodes')
-            if nodes_module_name in sys.modules:
-                nodes_module = sys.modules[nodes_module_name]
-                setattr(nodes_module, 'device', selected_device)
-
-                # Check for model offload override in nodes module too
-                nodes_model_offload_override = getattr(nodes_module, '_model_offload_device_override', None)
-                if nodes_model_offload_override:
-                    setattr(nodes_module, 'offload_device', nodes_model_offload_override)
-                    logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Using model offload override for nodes.py: {nodes_model_offload_override}")
-                elif device == "cpu":
-                    setattr(nodes_module, 'offload_device', selected_device)
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Both modules patched successfully")
-
-            # Call original loader with our patches in place
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Calling original loader with patched device")
-            result = original_loader.loadmodel(model, base_precision, load_device, quantization,
-                                               compile_args, attention_mode, block_swap_args, lora, vram_management_args, vace_model, fantasytalking_model, multitalk_model)
-
-            # Leave patches in place for subsequent operations
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Model loaded on {selected_device}")
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] ========== COMPLETE ==========")
-
-            return result
-        else:
-            logging.error(f"[MultiGPU WanVideoModelLoader_TWO] Could not patch modules, falling back")
-            return original_loader.loadmodel(model, base_precision, load_device, quantization,
-                                             compile_args, attention_mode, block_swap_args, lora, vram_management_args, vace_model, fantasytalking_model, multitalk_model)
+        original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]()
+        return original_sampler.process(model, **kwargs)
 
 
 class WanVideoBlockSwap:
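
The commit message mentions logging added to trace device allocation through the pipeline. In that spirit, here is a small hedged diagnostic in the same style; the helper name is illustrative and not part of this commit. It dumps the module-level device attribute of every loaded WanVideoWrapper module, and a stale value here is exactly the symptom this commit fixes:

import logging
import sys

def dump_wanvideo_devices(tag: str) -> None:
    """Log the module-level 'device' of every loaded WanVideoWrapper module."""
    for name, mod in list(sys.modules.items()):
        if "WanVideoWrapper" in name and hasattr(mod, "device"):
            logging.info("[%s] %s.device = %s", tag, name, mod.device)

# Example: call before and after WanVideoSamplerMultiGPU's repatch loop to
# confirm every wrapper module follows the current model's device.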
