@@ -882,125 +882,62 @@ def loadmodel(self, model_name, precision, device):
         return original_loader.loadmodel(model_name, precision, load_device)
 
 
-class WanVideoModelLoader_TWO:
-    """Second instance of WanVideoModelLoader for multi-model workflows to avoid race conditions"""
+
+class WanVideoModelLoader_2:
+    """Second instance for multi-model workflows to maintain separate device patches"""
     @classmethod
     def INPUT_TYPES(s):
-        # Exact same inputs as WanVideoModelLoader
-        from . import get_device_list
-        devices = get_device_list()
-
-        return {
-            "required": {
-                "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"),
-                          {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' folder"}),
-                "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}),
-                "quantization": (
-                    ["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn", "fp8_e4m3fn_scaled", "fp8_e5m2_scaled"],
-                    {"default": "disabled", "tooltip": "optional quantization method"}
-                ),
-                "device": (devices, {"default": devices[1] if len(devices) > 1 else devices[0], "tooltip": "Device to load the model to"}),
-            },
-            "optional": {
-                "attention_mode": ([
-                    "sdpa",
-                    "flash_attn_2",
-                    "flash_attn_3",
-                    "sageattn",
-                    "sageattn_3",
-                    "flex_attention",
-                    "radial_sage_attention",
-                ], {"default": "sdpa"}),
-                "compile_args": ("WANCOMPILEARGS",),
-                "block_swap_args": ("BLOCKSWAPARGS",),
-                "lora": ("WANVIDLORA", {"default": None}),
-                "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}),
-                "vace_model": ("VACEPATH", {"default": None, "tooltip": "VACE model to use when not using a model that has it included"}),
-                "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}),
-                "multitalk_model": ("MULTITALKMODEL", {"default": None, "tooltip": "Multitalk model"}),
-            }
-        }
-    RETURN_TYPES = ("WANVIDEOMODEL",)
-    RETURN_NAMES = ("model",)
+        # Delegate to the primary loader
+        return WanVideoModelLoader.INPUT_TYPES()
+
+    RETURN_TYPES = WanVideoModelLoader.RETURN_TYPES
+    RETURN_NAMES = WanVideoModelLoader.RETURN_NAMES
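+    # Mirroring the primary loader keeps this node's inputs and outputs
+    # identical even if the upstream definitions change.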
     FUNCTION = "loadmodel"
     CATEGORY = "WanVideoWrapper"
-    DESCRIPTION = "Second model loader for multi-model workflows - avoids race conditions when loading multiple models"
-
+    DESCRIPTION = "Second model loader instance for workflows using multiple models on different devices"
+
     def loadmodel(self, model, base_precision, device, quantization,
-                  compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, vram_management_args=None, vace_model=None, fantasytalking_model=None, multitalk_model=None):
-        # Just call the first loader's implementation directly
+                  compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None,
+                  vram_management_args=None, vace_model=None, fantasytalking_model=None, multitalk_model=None):
+        # Just use the first loader's implementation
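+        # (the primary loader handles mapping `device` to main/offload and
+        # patching the WanVideoWrapper module globals, so nothing extra is
+        # needed here)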
+        loader = WanVideoModelLoader()
+        return loader.loadmodel(model, base_precision, device, quantization,
+                                compile_args, attention_mode, block_swap_args, lora,
+                                vram_management_args, vace_model, fantasytalking_model, multitalk_model)
+
+
+class WanVideoSampler:
+    """Wrapper that ensures correct device patching before sampling"""
+    @classmethod
+    def INPUT_TYPES(s):
+        # Get the original sampler's inputs
+        from nodes import NODE_CLASS_MAPPINGS
+        original_types = NODE_CLASS_MAPPINGS["WanVideoSampler"].INPUT_TYPES()
+        return original_types
+
+    RETURN_TYPES = ("LATENT", "LATENT",)
+    RETURN_NAMES = ("samples", "denoised_samples",)
+    FUNCTION = "process"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model"
+
+    def process(self, model, **kwargs):
         import logging
-        import comfy.model_management as mm
-        import torch
-
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] ========== CUSTOM IMPLEMENTATION ==========")
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] User selected device: {device}")
+        import sys
 
-        # Convert device string to torch device
-        selected_device = torch.device(device)
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Torch device: {selected_device}")
+        # Get the model's device and update WanVideo modules to match
+        model_device = model.load_device
+        logging.info(f"[MultiGPU WanVideoSampler] Model device: {model_device}")
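+        # (`load_device` is the device the MultiGPU loader assigned to this model)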
 
-        # Determine load_device parameter for original loader
-        # If user selected CPU, use "offload_device", otherwise use "main_device"
-        load_device = "offload_device" if device == "cpu" else "main_device"
-        logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Mapped to load_device: {load_device}")
+        # Update the device variable in WanVideo modules; snapshot the keys so
+        # a concurrent import cannot change sys.modules mid-iteration
+        for module_name in list(sys.modules):
+            if 'WanVideoWrapper' in module_name and hasattr(sys.modules[module_name], 'device'):
+                sys.modules[module_name].device = model_device
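+        # NOTE: the wrapped WanVideo code reads this module-level `device`
+        # global at call time, so the patch takes effect for this run and
+        # persists until another MultiGPU node changes it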
 
+        # Call original sampler
         from nodes import NODE_CLASS_MAPPINGS
-        original_loader = NODE_CLASS_MAPPINGS["WanVideoModelLoader"]()
-
-        # Patch BOTH WanVideo modules with the selected device
-        import sys
-        import inspect
-        loader_module = inspect.getmodule(original_loader)
-
-        if loader_module:
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Patching WanVideo modules to use {selected_device}")
-
-            # Save original devices
-            original_device = getattr(loader_module, 'device', None)
-            original_offload = getattr(loader_module, 'offload_device', None)
-
-            # Check if there's a model offload device override (from block swap config)
-            model_offload_override = getattr(loader_module, '_model_offload_device_override', None)
-
-            # Patch nodes_model_loading.py module
-            setattr(loader_module, 'device', selected_device)
-            if model_offload_override:
-                # Use the model offload override for offload_device
-                setattr(loader_module, 'offload_device', model_offload_override)
-                logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Using model offload override: {model_offload_override}")
-            elif device == "cpu":
-                setattr(loader_module, 'offload_device', selected_device)
-
-            # Patch nodes.py module as well
-            nodes_module_name = loader_module.__name__.replace('.nodes_model_loading', '.nodes')
-            if nodes_module_name in sys.modules:
-                nodes_module = sys.modules[nodes_module_name]
-                setattr(nodes_module, 'device', selected_device)
-
-                # Check for model offload override in nodes module too
-                nodes_model_offload_override = getattr(nodes_module, '_model_offload_device_override', None)
-                if nodes_model_offload_override:
-                    setattr(nodes_module, 'offload_device', nodes_model_offload_override)
-                    logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Using model offload override for nodes.py: {nodes_model_offload_override}")
-                elif device == "cpu":
-                    setattr(nodes_module, 'offload_device', selected_device)
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Both modules patched successfully")
-
-            # Call original loader with our patches in place
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Calling original loader with patched device")
-            result = original_loader.loadmodel(model, base_precision, load_device, quantization,
-                                               compile_args, attention_mode, block_swap_args, lora,
-                                               vram_management_args, vace_model, fantasytalking_model, multitalk_model)
-
-            # Leave patches in place for subsequent operations
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] Model loaded on {selected_device}")
-            logging.info(f"[MultiGPU WanVideoModelLoader_TWO] ========== COMPLETE ==========")
-
-            return result
-        else:
-            logging.error(f"[MultiGPU WanVideoModelLoader_TWO] Could not patch modules, falling back")
-            return original_loader.loadmodel(model, base_precision, load_device, quantization,
-                                             compile_args, attention_mode, block_swap_args, lora,
-                                             vram_management_args, vace_model, fantasytalking_model, multitalk_model)
+        original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]()
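+        # (assumes the original node class is still registered under
+        # "WanVideoSampler" in ComfyUI's mappings, not shadowed by this wrapper)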
+        return original_sampler.process(model, **kwargs)
 
 
 class WanVideoBlockSwap: