
Commit 7af256a

Author: Your Name (committed)
fix: harden startup compat and restore low-risk MultiGPU helpers
1 parent ac3df4e commit 7af256a

4 files changed

Lines changed: 148 additions & 57 deletions


__init__.py

Lines changed: 125 additions & 54 deletions
@@ -4,6 +4,7 @@
 import os
 import copy
 import json
+import importlib
 from datetime import datetime
 from pathlib import Path
 import folder_paths
@@ -136,15 +137,36 @@ def mgpu_mm_log_method(self, msg):
     )
 logger.mgpu_mm_log = mgpu_mm_log_method.__get__(logger, type(logger))
 
+def _normalize_module_name(module_name):
+    """Normalize a custom node directory name for tolerant matching."""
+    return "".join(char for char in os.path.basename(module_name).lower() if char.isalnum())
+
 def check_module_exists(module_path):
     """Check if a custom node module exists in ComfyUI custom_nodes directory."""
-    full_path = os.path.join(folder_paths.get_folder_paths("custom_nodes")[0], module_path)
-    logger.debug(f"[MultiGPU] Checking for module at {full_path}")
-    if not os.path.exists(full_path):
-        logger.debug(f"[MultiGPU] Module {module_path} not found - skipping")
-        return False
-    logger.debug(f"[MultiGPU] Found {module_path}, creating compatible MultiGPU nodes")
-    return True
+    custom_nodes_paths = folder_paths.get_folder_paths("custom_nodes")
+    normalized_module_path = _normalize_module_name(module_path)
+
+    for custom_nodes_path in custom_nodes_paths:
+        full_path = os.path.join(custom_nodes_path, module_path)
+        logger.debug(f"[MultiGPU] Checking for module at {full_path}")
+        if os.path.isdir(full_path):
+            logger.debug(f"[MultiGPU] Found exact module match for {module_path} at {full_path}")
+            return True
+
+    for custom_nodes_path in custom_nodes_paths:
+        try:
+            with os.scandir(custom_nodes_path) as entries:
+                for entry in entries:
+                    if not entry.is_dir():
+                        continue
+                    if _normalize_module_name(entry.name) == normalized_module_path:
+                        logger.debug(f"[MultiGPU] Found normalized module match for {module_path} at {entry.path}")
+                        return True
+        except OSError:
+            continue
+
+    logger.debug(f"[MultiGPU] Module {module_path} not found - skipping")
+    return False
 
 current_device = mm.get_torch_device()
 current_text_encoder_device = mm.text_encoder_device()
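The new lookup first tries the exact directory name across every registered custom_nodes path, then falls back to a normalized comparison, so differently cased or punctuated installs still match. A minimal standalone sketch of the normalization rule (illustration only, not part of the commit):

    import os

    def normalize(name):
        # Same rule as _normalize_module_name: lowercase, keep alphanumerics only
        return "".join(ch for ch in os.path.basename(name).lower() if ch.isalnum())

    # Both spellings used later in register_and_count map to the same key
    assert normalize("ComfyUI-WanVideoWrapper") == normalize("comfyui-wanvideowrapper")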
@@ -216,6 +238,44 @@ def unet_offload_device_patched():
     logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
     return device
 
+def _patch_comfy_kitchen_dlpack_device_guard():
+    """Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device."""
+    try:
+        comfy_kitchen_cuda = importlib.import_module("comfy_kitchen.backends.cuda")
+    except ImportError:
+        logger.debug("[MultiGPU] comfy_kitchen not found - skipping CUDA DLPack compat patch")
+        return False
+
+    wrap_for_dlpack = getattr(comfy_kitchen_cuda, "_wrap_for_dlpack", None)
+    if wrap_for_dlpack is None:
+        logger.debug("[MultiGPU] comfy_kitchen.backends.cuda._wrap_for_dlpack not found - skipping compat patch")
+        return False
+
+    if getattr(wrap_for_dlpack, "_multigpu_cuda_device_guard", False):
+        return True
+
+    def wrap_for_dlpack_with_device_guard(*args, **kwargs):
+        tensor = args[0] if args else kwargs.get("tensor")
+        previous_device_index = None
+        switched_device = False
+
+        if isinstance(tensor, torch.Tensor) and tensor.is_cuda and tensor.device.index is not None:
+            previous_device_index = torch.cuda.current_device()
+            if previous_device_index != tensor.device.index:
+                torch.cuda.set_device(tensor.device.index)
+                switched_device = True
+
+        try:
+            return wrap_for_dlpack(*args, **kwargs)
+        finally:
+            if switched_device and previous_device_index is not None:
+                torch.cuda.set_device(previous_device_index)
+
+    wrap_for_dlpack_with_device_guard._multigpu_cuda_device_guard = True
+    comfy_kitchen_cuda._wrap_for_dlpack = wrap_for_dlpack_with_device_guard
+    logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch")
+    return True
+
 logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
 logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
 logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
@@ -224,8 +284,10 @@ def unet_offload_device_patched():
 mm.get_torch_device = get_torch_device_patched
 mm.text_encoder_device = text_encoder_device_patched
 mm.unet_offload_device = unet_offload_device_patched
+_patch_comfy_kitchen_dlpack_device_guard()
 
 from .nodes import (
+    DeviceSelectorMultiGPU,
     UnetLoaderGGUF,
     UnetLoaderGGUFAdvanced,
     CLIPLoaderGGUF,
@@ -246,29 +308,6 @@ def unet_offload_device_patched():
     UNetLoaderLP,
 )
 
-from .wanvideo import (
-    LoadWanVideoT5TextEncoder,
-    WanVideoTextEncode,
-    WanVideoTextEncodeCached,
-    WanVideoTextEncodeSingle,
-    WanVideoVAELoader,
-    WanVideoTinyVAELoader,
-    WanVideoBlockSwap,
-    WanVideoImageToVideoEncode,
-    WanVideoDecode,
-    WanVideoModelLoader,
-    WanVideoSampler,
-    WanVideoVACEEncode,
-    WanVideoEncode,
-    LoadWanVideoClipTextEncoder,
-    WanVideoClipVisionEncode,
-    WanVideoControlnetLoader,
-    FantasyTalkingModelLoader,
-    Wav2VecModelLoader,
-    WanVideoUni3C_ControlnetLoader,
-    DownloadAndLoadWav2VecModel,
-)
-
 from .wrappers import (
     override_class,
     override_class_offload,
@@ -294,9 +333,57 @@ def unet_offload_device_patched():
     CheckpointLoaderAdvancedDisTorch2MultiGPU
 )
 
+def _load_wanvideo_nodes():
+    from .wanvideo import (
+        LoadWanVideoT5TextEncoder,
+        WanVideoTextEncode,
+        WanVideoTextEncodeCached,
+        WanVideoTextEncodeSingle,
+        WanVideoVAELoader,
+        WanVideoTinyVAELoader,
+        WanVideoBlockSwap,
+        WanVideoImageToVideoEncode,
+        WanVideoDecode,
+        WanVideoModelLoader,
+        WanVideoSampler,
+        WanVideoVACEEncode,
+        WanVideoEncode,
+        LoadWanVideoClipTextEncoder,
+        WanVideoClipVisionEncode,
+        WanVideoControlnetLoader,
+        FantasyTalkingModelLoader,
+        Wav2VecModelLoader,
+        WanVideoUni3C_ControlnetLoader,
+        DownloadAndLoadWav2VecModel,
+    )
+
+    return {
+        "LoadWanVideoT5TextEncoderMultiGPU": LoadWanVideoT5TextEncoder,
+        "WanVideoTextEncodeMultiGPU": WanVideoTextEncode,
+        "WanVideoTextEncodeCachedMultiGPU": WanVideoTextEncodeCached,
+        "WanVideoTextEncodeSingleMultiGPU": WanVideoTextEncodeSingle,
+        "WanVideoVAELoaderMultiGPU": WanVideoVAELoader,
+        "WanVideoTinyVAELoaderMultiGPU": WanVideoTinyVAELoader,
+        "WanVideoBlockSwapMultiGPU": WanVideoBlockSwap,
+        "WanVideoImageToVideoEncodeMultiGPU": WanVideoImageToVideoEncode,
+        "WanVideoDecodeMultiGPU": WanVideoDecode,
+        "WanVideoModelLoaderMultiGPU": WanVideoModelLoader,
+        "WanVideoSamplerMultiGPU": WanVideoSampler,
+        "WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode,
+        "WanVideoEncodeMultiGPU": WanVideoEncode,
+        "LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
+        "WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
+        "WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader,
+        "FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader,
+        "Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader,
+        "WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader,
+        "DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel,
+    }
+
 NODE_CLASS_MAPPINGS = {
     "CheckpointLoaderAdvancedMultiGPU": CheckpointLoaderAdvancedMultiGPU,
     "CheckpointLoaderAdvancedDisTorch2MultiGPU": CheckpointLoaderAdvancedDisTorch2MultiGPU,
+    "DeviceSelectorMultiGPU": DeviceSelectorMultiGPU,
     "UNetLoaderLP": UNetLoaderLP,
 }
 
@@ -342,8 +429,14 @@ def register_and_count(module_names, node_map):
 
     count = 0
     if found:
+        try:
+            resolved_node_map = node_map() if callable(node_map) else node_map
+        except Exception as exc:
+            logger.warning(f"[MultiGPU] Failed to register nodes for {module_names[0]}: {exc}")
+            resolved_node_map = {}
+
         initial_len = len(NODE_CLASS_MAPPINGS)
-        for key, value in node_map.items():
+        for key, value in resolved_node_map.items():
             NODE_CLASS_MAPPINGS[key] = value
         count = len(NODE_CLASS_MAPPINGS) - initial_len
 
@@ -401,29 +494,7 @@ def register_and_count(module_names, node_map):
 }
 register_and_count(["PuLID_ComfyUI", "pulid_comfyui"], pulid_nodes)
 
-wanvideo_nodes = {
-    "LoadWanVideoT5TextEncoderMultiGPU": LoadWanVideoT5TextEncoder,
-    "WanVideoTextEncodeMultiGPU": WanVideoTextEncode,
-    "WanVideoTextEncodeCachedMultiGPU": WanVideoTextEncodeCached,
-    "WanVideoTextEncodeSingleMultiGPU": WanVideoTextEncodeSingle,
-    "WanVideoVAELoaderMultiGPU": WanVideoVAELoader,
-    "WanVideoTinyVAELoaderMultiGPU": WanVideoTinyVAELoader,
-    "WanVideoBlockSwapMultiGPU": WanVideoBlockSwap,
-    "WanVideoImageToVideoEncodeMultiGPU": WanVideoImageToVideoEncode,
-    "WanVideoDecodeMultiGPU": WanVideoDecode,
-    "WanVideoModelLoaderMultiGPU": WanVideoModelLoader,
-    "WanVideoSamplerMultiGPU": WanVideoSampler,
-    "WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode,
-    "WanVideoEncodeMultiGPU": WanVideoEncode,
-    "LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
-    "WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
-    "WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader,
-    "FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader,
-    "Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader,
-    "WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader,
-    "DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel,
-}
-register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes)
+register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], _load_wanvideo_nodes)
 
 for item in registration_data:
     logger.info(fmt_reg.format(item['name'], item['found'], str(item['count'])))
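Passing the callable _load_wanvideo_nodes instead of an eagerly built dict means the wanvideo wrapper is only imported once the ComfyUI-WanVideoWrapper directory has actually been found, and any import failure is downgraded to a warning instead of aborting startup. A self-contained sketch of the same deferred-registration idea (hypothetical names, not part of the commit):

    NODE_CLASS_MAPPINGS = {}

    def register_sketch(found, node_map):
        # node_map may be a plain dict or a zero-argument callable returning one
        if not found:
            return 0
        try:
            resolved = node_map() if callable(node_map) else node_map
        except Exception as exc:
            print(f"registration skipped: {exc}")
            resolved = {}
        before = len(NODE_CLASS_MAPPINGS)
        NODE_CLASS_MAPPINGS.update(resolved)
        return len(NODE_CLASS_MAPPINGS) - before

    def load_optional_nodes():
        import missing_wrapper  # raises ImportError when the optional dependency is absent
        return {"ExampleNodeMultiGPU": missing_wrapper.ExampleNode}

    register_sketch(True, load_optional_nodes)  # prints the warning, registers 0 nodes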

device_utils.py

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ def soft_empty_cache_multigpu():
         logger.mgpu_mm_log(f"Clearing CUDA cache on {device_str} (idx={device_idx})")
         multigpu_memory_log("general", f"pre-empty:{device_str}")
         with torch.cuda.device(device_idx):
+            torch.cuda.synchronize()
             torch.cuda.empty_cache()
             if hasattr(torch.cuda, "ipc_collect"):
                 torch.cuda.ipc_collect()
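Synchronizing before empty_cache() waits for any in-flight kernels on that device, so their allocations can actually be returned by the CUDA caching allocator. The per-device pattern, sketched standalone (illustration only, not part of the commit):

    import torch

    def clear_all_cuda_caches():
        # Sketch of the per-device clear used in soft_empty_cache_multigpu
        for idx in range(torch.cuda.device_count()):
            with torch.cuda.device(idx):
                torch.cuda.synchronize()       # finish pending work first
                torch.cuda.empty_cache()       # release cached blocks
                if hasattr(torch.cuda, "ipc_collect"):
                    torch.cuda.ipc_collect()   # reclaim IPC-shared memory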

nodes.py

Lines changed: 16 additions & 0 deletions
@@ -5,6 +5,22 @@
 from .device_utils import get_device_list
 from .model_management_mgpu import force_full_system_cleanup
 
+class DeviceSelectorMultiGPU:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = get_device_list()
+        return {"required": {"device": (devices,)}}
+
+    RETURN_TYPES = ("MULTIGPUDEVICE",)
+    RETURN_NAMES = ("device",)
+    FUNCTION = "select_device"
+    CATEGORY = "multigpu"
+    TITLE = "Device Selector (MultiGPU)"
+
+    def select_device(self, device):
+        """Return the selected device label without side effects."""
+        return (device,)
+
 class UnetLoaderGGUF:
     @classmethod
     def INPUT_TYPES(s):
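DeviceSelectorMultiGPU follows the standard ComfyUI custom-node contract (an INPUT_TYPES classmethod plus RETURN_TYPES, FUNCTION, and CATEGORY) and simply passes the chosen device label through on a MULTIGPUDEVICE output. A quick smoke test outside ComfyUI, assuming get_device_list() returns labels such as "cpu" and "cuda:0" (illustration only):

    node = DeviceSelectorMultiGPU()
    print(DeviceSelectorMultiGPU.INPUT_TYPES())  # {"required": {"device": ([...available devices...],)}}
    print(node.select_device("cuda:0"))          # ("cuda:0",)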

wanvideo.py

Lines changed: 6 additions & 3 deletions
@@ -11,9 +11,7 @@
 from comfy.utils import load_torch_file, ProgressBar
 import gc
 import numpy as np
-from accelerate import init_empty_weights
 import os
-import importlib.util
 
 logger = logging.getLogger("MultiGPU")
 
@@ -64,6 +62,7 @@ def INPUT_TYPES(s):
                 "lora": ("WANVIDLORA", {"default": None}),
                 "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}),
                 "extra_model": ("VACEPATH", {"default": None, "tooltip": "Extra model to add to the main model, ie. VACE or MTV Crafter"}),
+                "vace_model": ("VACEPATH", {"default": None, "tooltip": "Backward-compatible alias for extra_model"}),
                 "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}),
                 "multitalk_model": ("MULTITALKMODEL", {"default": None, "tooltip": "Multitalk model"}),
                 "fantasyportrait_model": ("FANTASYPORTRAITMODEL", {"default": None, "tooltip": "FantasyPortrait model"}),
@@ -83,6 +82,10 @@ def loadmodel(self, model, base_precision, compute_device, quantization, load_de
         loader_module = inspect.getmodule(original_loader)
         original_module_device = loader_module.device
 
+        vace_model = kwargs.pop("vace_model", None)
+        if kwargs.get("extra_model") is None and vace_model is not None:
+            kwargs["extra_model"] = vace_model
+
         set_current_device(compute_device)
         compute_device_to_be_patched = mm.get_torch_device()
 
@@ -863,4 +866,4 @@ def loadmodel(self, model, base_precision, load_device, device, quantization, at
         set_current_device(device)
 
         original_loader = NODE_CLASS_MAPPINGS["WanVideoUni3C_ControlnetLoader"]()
-        return original_loader.loadmodel(model, base_precision, load_device, quantization, attention_mode, compile_args)
+        return original_loader.loadmodel(model, base_precision, load_device, quantization, attention_mode, compile_args)
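The vace_model input added in the loadmodel hunk above is only a backward-compatibility alias: it is popped from kwargs and copied into extra_model when extra_model is unset, so older workflows that still wire a VACE model into vace_model keep working, while an explicit extra_model always wins. A minimal illustration of that precedence (hypothetical values, not part of the commit):

    def resolve_extra_model(kwargs):
        # Mirrors the alias handling in loadmodel
        vace_model = kwargs.pop("vace_model", None)
        if kwargs.get("extra_model") is None and vace_model is not None:
            kwargs["extra_model"] = vace_model
        return kwargs

    print(resolve_extra_model({"vace_model": "vace.safetensors"}))
    # {'extra_model': 'vace.safetensors'}
    print(resolve_extra_model({"extra_model": "mtv.safetensors", "vace_model": "vace.safetensors"}))
    # {'extra_model': 'mtv.safetensors'}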
