44import os
55import copy
66import json
7+ import importlib
78from datetime import datetime
89from pathlib import Path
910import folder_paths
@@ -136,15 +137,36 @@ def mgpu_mm_log_method(self, msg):
136137 )
137138logger .mgpu_mm_log = mgpu_mm_log_method .__get__ (logger , type (logger ))
138139
def _normalize_module_name(module_name):
    """Normalize a custom node directory name for tolerant matching.

    Only the final path component is used. It is lower-cased and every
    non-alphanumeric character (hyphens, underscores, dots, ...) is dropped,
    so ``"ComfyUI-GGUF"`` and ``"comfyui_gguf"`` produce the same key.

    Args:
        module_name: Directory name or path of a custom node package.

    Returns:
        The normalized key, or an empty string when the final component
        contains no alphanumeric characters.
    """
    # normpath removes trailing separators first; without it a name such as
    # "ComfyUI-GGUF/" has an empty basename and would normalize to "".
    leaf_name = os.path.basename(os.path.normpath(module_name))
    return "".join(char for char in leaf_name.lower() if char.isalnum())
143+
def check_module_exists(module_path):
    """Check if a custom node module exists in a ComfyUI custom_nodes directory.

    Every directory returned by ``folder_paths.get_folder_paths("custom_nodes")``
    is searched in two passes:

    1. an exact match: ``<custom_nodes>/<module_path>`` is a directory;
    2. a tolerant match: some sub-directory has the same normalized name
       (see ``_normalize_module_name``) as ``module_path``.

    Args:
        module_path: Directory name of the custom node package to look for.

    Returns:
        True if a matching directory is found, False otherwise.
    """
    if not module_path:
        # os.path.join(root, "") is the root directory itself, which always
        # exists, so an empty name must be rejected before the exact pass.
        logger.debug(f"[MultiGPU] Module {module_path} not found - skipping")
        return False

    custom_nodes_paths = folder_paths.get_folder_paths("custom_nodes")
    normalized_module_path = _normalize_module_name(module_path)

    # Pass 1: exact directory name in any custom_nodes root.
    for custom_nodes_path in custom_nodes_paths:
        full_path = os.path.join(custom_nodes_path, module_path)
        logger.debug(f"[MultiGPU] Checking for module at {full_path} ")
        if os.path.isdir(full_path):
            logger.debug(f"[MultiGPU] Found exact module match for {module_path} at {full_path} ")
            return True

    # Pass 2: tolerant match on normalized names. Skipped when the requested
    # name normalizes to "", which would otherwise match any directory whose
    # name has no alphanumeric characters.
    if normalized_module_path:
        for custom_nodes_path in custom_nodes_paths:
            try:
                with os.scandir(custom_nodes_path) as entries:
                    for entry in entries:
                        try:
                            is_directory = entry.is_dir()
                        except OSError:
                            # One unreadable entry must not hide the rest of
                            # the directory from the scan.
                            continue
                        if not is_directory:
                            continue
                        if _normalize_module_name(entry.name) == normalized_module_path:
                            logger.debug(f"[MultiGPU] Found normalized module match for {module_path} at {entry.path} ")
                            return True
            except OSError:
                # Missing or unreadable custom_nodes root: try the next one.
                continue

    logger.debug(f"[MultiGPU] Module {module_path} not found - skipping")
    return False
148170
149171current_device = mm .get_torch_device ()
150172current_text_encoder_device = mm .text_encoder_device ()
@@ -216,6 +238,44 @@ def unet_offload_device_patched():
216238 logger .debug (f"[MultiGPU Core Patching] unet_offload_device_patched returning device: { device } (current_unet_offload_device={ current_unet_offload_device } )" )
217239 return device
218240
def _patch_comfy_kitchen_dlpack_device_guard():
    """Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device.

    Replaces ``comfy_kitchen.backends.cuda._wrap_for_dlpack`` with a wrapper
    that makes the tensor's CUDA device current for the duration of the call
    and restores the previously current device afterwards.

    Returns:
        True when the guard is installed (now or by an earlier call), False
        when ``comfy_kitchen.backends.cuda`` cannot be imported or does not
        expose ``_wrap_for_dlpack``.
    """
    try:
        cuda_backend = importlib.import_module("comfy_kitchen.backends.cuda")
    except ImportError:
        logger.debug("[MultiGPU] comfy_kitchen not found - skipping CUDA DLPack compat patch")
        return False

    original_wrap = getattr(cuda_backend, "_wrap_for_dlpack", None)
    if original_wrap is None:
        logger.debug("[MultiGPU] comfy_kitchen.backends.cuda._wrap_for_dlpack not found - skipping compat patch")
        return False

    # The marker attribute is set on our wrapper below; seeing it here means
    # the patch is already in place and must not be stacked a second time.
    if getattr(original_wrap, "_multigpu_cuda_device_guard", False):
        return True

    def wrap_for_dlpack_with_device_guard(*args, **kwargs):
        # The tensor is taken from the first positional argument, falling back
        # to the "tensor" keyword.
        tensor = args[0] if args else kwargs.get("tensor")

        # Index of the device to restore afterwards; stays None unless the
        # current device is actually changed.
        restore_index = None
        is_indexed_cuda_tensor = (
            isinstance(tensor, torch.Tensor)
            and tensor.is_cuda
            and tensor.device.index is not None
        )
        if is_indexed_cuda_tensor:
            active_index = torch.cuda.current_device()
            if active_index != tensor.device.index:
                torch.cuda.set_device(tensor.device.index)
                restore_index = active_index

        try:
            return original_wrap(*args, **kwargs)
        finally:
            if restore_index is not None:
                torch.cuda.set_device(restore_index)

    wrap_for_dlpack_with_device_guard._multigpu_cuda_device_guard = True
    cuda_backend._wrap_for_dlpack = wrap_for_dlpack_with_device_guard
    logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch")
    return True
278+
219279logger .info (f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device" )
220280logger .info (f"[MultiGPU DEBUG] Initial current_device: { current_device } " )
221281logger .info (f"[MultiGPU DEBUG] Initial current_text_encoder_device: { current_text_encoder_device } " )
@@ -224,8 +284,10 @@ def unet_offload_device_patched():
224284mm .get_torch_device = get_torch_device_patched
225285mm .text_encoder_device = text_encoder_device_patched
226286mm .unet_offload_device = unet_offload_device_patched
287+ _patch_comfy_kitchen_dlpack_device_guard ()
227288
228289from .nodes import (
290+ DeviceSelectorMultiGPU ,
229291 UnetLoaderGGUF ,
230292 UnetLoaderGGUFAdvanced ,
231293 CLIPLoaderGGUF ,
@@ -246,29 +308,6 @@ def unet_offload_device_patched():
246308 UNetLoaderLP ,
247309)
248310
249- from .wanvideo import (
250- LoadWanVideoT5TextEncoder ,
251- WanVideoTextEncode ,
252- WanVideoTextEncodeCached ,
253- WanVideoTextEncodeSingle ,
254- WanVideoVAELoader ,
255- WanVideoTinyVAELoader ,
256- WanVideoBlockSwap ,
257- WanVideoImageToVideoEncode ,
258- WanVideoDecode ,
259- WanVideoModelLoader ,
260- WanVideoSampler ,
261- WanVideoVACEEncode ,
262- WanVideoEncode ,
263- LoadWanVideoClipTextEncoder ,
264- WanVideoClipVisionEncode ,
265- WanVideoControlnetLoader ,
266- FantasyTalkingModelLoader ,
267- Wav2VecModelLoader ,
268- WanVideoUni3C_ControlnetLoader ,
269- DownloadAndLoadWav2VecModel ,
270- )
271-
272311from .wrappers import (
273312 override_class ,
274313 override_class_offload ,
@@ -294,9 +333,57 @@ def unet_offload_device_patched():
294333 CheckpointLoaderAdvancedDisTorch2MultiGPU
295334)
296335
def _load_wanvideo_nodes():
    """Import the WanVideo wrapper classes and build their node mapping.

    The ``.wanvideo`` import lives inside this function, so it only executes
    when the loader is called rather than when this module is imported.

    Returns:
        A dict mapping ``"<ImportedName>MultiGPU"`` to the class imported
        under that name from ``.wanvideo``, in the order listed below.
    """
    from .wanvideo import (
        LoadWanVideoT5TextEncoder,
        WanVideoTextEncode,
        WanVideoTextEncodeCached,
        WanVideoTextEncodeSingle,
        WanVideoVAELoader,
        WanVideoTinyVAELoader,
        WanVideoBlockSwap,
        WanVideoImageToVideoEncode,
        WanVideoDecode,
        WanVideoModelLoader,
        WanVideoSampler,
        WanVideoVACEEncode,
        WanVideoEncode,
        LoadWanVideoClipTextEncoder,
        WanVideoClipVisionEncode,
        WanVideoControlnetLoader,
        FantasyTalkingModelLoader,
        Wav2VecModelLoader,
        WanVideoUni3C_ControlnetLoader,
        DownloadAndLoadWav2VecModel,
    )

    # (imported name, class) pairs; order here is the order of the result.
    wrapped_classes = (
        ("LoadWanVideoT5TextEncoder", LoadWanVideoT5TextEncoder),
        ("WanVideoTextEncode", WanVideoTextEncode),
        ("WanVideoTextEncodeCached", WanVideoTextEncodeCached),
        ("WanVideoTextEncodeSingle", WanVideoTextEncodeSingle),
        ("WanVideoVAELoader", WanVideoVAELoader),
        ("WanVideoTinyVAELoader", WanVideoTinyVAELoader),
        ("WanVideoBlockSwap", WanVideoBlockSwap),
        ("WanVideoImageToVideoEncode", WanVideoImageToVideoEncode),
        ("WanVideoDecode", WanVideoDecode),
        ("WanVideoModelLoader", WanVideoModelLoader),
        ("WanVideoSampler", WanVideoSampler),
        ("WanVideoVACEEncode", WanVideoVACEEncode),
        ("WanVideoEncode", WanVideoEncode),
        ("LoadWanVideoClipTextEncoder", LoadWanVideoClipTextEncoder),
        ("WanVideoClipVisionEncode", WanVideoClipVisionEncode),
        ("WanVideoControlnetLoader", WanVideoControlnetLoader),
        ("FantasyTalkingModelLoader", FantasyTalkingModelLoader),
        ("Wav2VecModelLoader", Wav2VecModelLoader),
        ("WanVideoUni3C_ControlnetLoader", WanVideoUni3C_ControlnetLoader),
        ("DownloadAndLoadWav2VecModel", DownloadAndLoadWav2VecModel),
    )

    # Each registration key is the imported name plus the "MultiGPU" suffix,
    # e.g. "WanVideoSampler" -> "WanVideoSamplerMultiGPU".
    return {f"{name}MultiGPU": node_class for name, node_class in wrapped_classes}
382+
297383NODE_CLASS_MAPPINGS = {
298384 "CheckpointLoaderAdvancedMultiGPU" : CheckpointLoaderAdvancedMultiGPU ,
299385 "CheckpointLoaderAdvancedDisTorch2MultiGPU" : CheckpointLoaderAdvancedDisTorch2MultiGPU ,
386+ "DeviceSelectorMultiGPU" : DeviceSelectorMultiGPU ,
300387 "UNetLoaderLP" : UNetLoaderLP ,
301388}
302389
@@ -342,8 +429,14 @@ def register_and_count(module_names, node_map):
342429
343430 count = 0
344431 if found :
432+ try :
433+ resolved_node_map = node_map () if callable (node_map ) else node_map
434+ except Exception as exc :
435+ logger .warning (f"[MultiGPU] Failed to register nodes for { module_names [0 ]} : { exc } " )
436+ resolved_node_map = {}
437+
345438 initial_len = len (NODE_CLASS_MAPPINGS )
346- for key , value in node_map .items ():
439+ for key , value in resolved_node_map .items ():
347440 NODE_CLASS_MAPPINGS [key ] = value
348441 count = len (NODE_CLASS_MAPPINGS ) - initial_len
349442
@@ -401,29 +494,7 @@ def register_and_count(module_names, node_map):
401494}
402495register_and_count (["PuLID_ComfyUI" , "pulid_comfyui" ], pulid_nodes )
403496
404- wanvideo_nodes = {
405- "LoadWanVideoT5TextEncoderMultiGPU" : LoadWanVideoT5TextEncoder ,
406- "WanVideoTextEncodeMultiGPU" : WanVideoTextEncode ,
407- "WanVideoTextEncodeCachedMultiGPU" : WanVideoTextEncodeCached ,
408- "WanVideoTextEncodeSingleMultiGPU" : WanVideoTextEncodeSingle ,
409- "WanVideoVAELoaderMultiGPU" : WanVideoVAELoader ,
410- "WanVideoTinyVAELoaderMultiGPU" : WanVideoTinyVAELoader ,
411- "WanVideoBlockSwapMultiGPU" : WanVideoBlockSwap ,
412- "WanVideoImageToVideoEncodeMultiGPU" : WanVideoImageToVideoEncode ,
413- "WanVideoDecodeMultiGPU" : WanVideoDecode ,
414- "WanVideoModelLoaderMultiGPU" : WanVideoModelLoader ,
415- "WanVideoSamplerMultiGPU" : WanVideoSampler ,
416- "WanVideoVACEEncodeMultiGPU" : WanVideoVACEEncode ,
417- "WanVideoEncodeMultiGPU" : WanVideoEncode ,
418- "LoadWanVideoClipTextEncoderMultiGPU" : LoadWanVideoClipTextEncoder ,
419- "WanVideoClipVisionEncodeMultiGPU" : WanVideoClipVisionEncode ,
420- "WanVideoControlnetLoaderMultiGPU" : WanVideoControlnetLoader ,
421- "FantasyTalkingModelLoaderMultiGPU" : FantasyTalkingModelLoader ,
422- "Wav2VecModelLoaderMultiGPU" : Wav2VecModelLoader ,
423- "WanVideoUni3C_ControlnetLoaderMultiGPU" : WanVideoUni3C_ControlnetLoader ,
424- "DownloadAndLoadWav2VecModelMultiGPU" : DownloadAndLoadWav2VecModel ,
425- }
426- register_and_count (["ComfyUI-WanVideoWrapper" , "comfyui-wanvideowrapper" ], wanvideo_nodes )
497+ register_and_count (["ComfyUI-WanVideoWrapper" , "comfyui-wanvideowrapper" ], _load_wanvideo_nodes )
427498
428499for item in registration_data :
429500 logger .info (fmt_reg .format (item ['name' ], item ['found' ], str (item ['count' ])))
0 commit comments