Skip to content

Commit 24e4e17

Browse files
committed
feat: enhance MultiGPU support with device guards and runtime management
1 parent 7af256a commit 24e4e17

2 files changed

Lines changed: 170 additions & 19 deletions

File tree

__init__.py

Lines changed: 156 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
import copy
66
import json
77
import importlib
8+
from contextlib import contextmanager
89
from datetime import datetime
910
from pathlib import Path
1011
import folder_paths
1112
import comfy.model_management as mm
13+
import comfy.memory_management
1214
import comfy.model_patcher
15+
import comfy.sample as comfy_sample
1316
from nodes import NODE_CLASS_MAPPINGS as GLOBAL_NODE_CLASS_MAPPINGS
1417
from .device_utils import (
1518
get_device_list,
@@ -171,6 +174,9 @@ def check_module_exists(module_path):
171174
current_device = mm.get_torch_device()
172175
current_text_encoder_device = mm.text_encoder_device()
173176
current_unet_offload_device = mm.unet_offload_device()
177+
_aimdo_initialized_devices = set()
178+
if isinstance(current_device, torch.device) and current_device.type == "cuda" and current_device.index is not None:
179+
_aimdo_initialized_devices.add(current_device.index)
174180

175181
def set_current_device(device):
176182
"""Set the current device context for MultiGPU operations."""
@@ -205,6 +211,73 @@ def get_current_unet_offload_device():
205211
"""Get the current UNet offload device context at runtime."""
206212
return current_unet_offload_device
207213

214+
def _coerce_torch_device(device):
215+
"""Best-effort conversion to torch.device for guard and patch helpers."""
216+
if device is None:
217+
return None
218+
if isinstance(device, torch.device):
219+
return device
220+
try:
221+
return torch.device(device)
222+
except (TypeError, RuntimeError, ValueError):
223+
return None
224+
225+
@contextmanager
def cuda_device_guard(device, reason="runtime"):
    """Temporarily switch the real CUDA current device for non-primary execution paths.

    Yields the coerced torch.device (or None). The switch only happens for a
    fully-qualified CUDA device (explicit index) on a CUDA-capable host, and
    the previous current device is always restored on exit.
    """
    target_device = _coerce_torch_device(device)

    # Decide up front whether the real CUDA state needs touching at all.
    needs_cuda_switch = (
        target_device is not None
        and target_device.type == "cuda"
        and target_device.index is not None
        and torch.cuda.is_available()
    )

    previous_device_index = None
    switched_device = False
    if needs_cuda_switch:
        previous_device_index = torch.cuda.current_device()
        if previous_device_index != target_device.index:
            logger.info(
                f"[MultiGPU CUDA Guard] Switching CUDA current device {previous_device_index} -> {target_device.index} ({reason})"
            )
            torch.cuda.set_device(target_device.index)
            switched_device = True

    try:
        yield target_device
    finally:
        # Restore only if we actually moved off the previous device.
        if switched_device and previous_device_index is not None:
            torch.cuda.set_device(previous_device_index)
            logger.info(
                f"[MultiGPU CUDA Guard] Restored CUDA current device {target_device.index} -> {previous_device_index} ({reason})"
            )
254+
255+
def _get_runtime_device_from_model(model):
256+
"""Resolve the actual execution device from a model or patcher wrapper."""
257+
if hasattr(model, "load_device"):
258+
return getattr(model, "load_device")
259+
patcher = getattr(model, "patcher", None)
260+
if patcher is not None and hasattr(patcher, "load_device"):
261+
return patcher.load_device
262+
inner_model = getattr(model, "model", None)
263+
if inner_model is not None and hasattr(inner_model, "load_device"):
264+
return inner_model.load_device
265+
return None
266+
267+
@contextmanager
def multigpu_runtime_device_guard(device, reason="runtime"):
    """Align MultiGPU logical device state with the real runtime device for inference.

    Sets the MultiGPU current-device context (and the real CUDA device via
    cuda_device_guard) for the duration of the block, then restores the
    original logical device on exit. Yields the coerced torch.device or None.
    """
    original_device = get_current_device()
    coerced_device = _coerce_torch_device(device)
    # Fall back to the raw value so non-coercible device objects still reach
    # set_current_device unchanged.
    target_device = device if coerced_device is None else coerced_device
    if target_device is not None:
        set_current_device(target_device)
        logger.info(f"[MultiGPU Runtime] Using runtime device {target_device} ({reason})")
    try:
        with cuda_device_guard(target_device, reason=reason):
            yield coerced_device
    finally:
        set_current_device(original_device)
280+
208281
def get_torch_device_patched():
209282
"""Return MultiGPU-aware device selection for patched mm.get_torch_device."""
210283
device = None
@@ -238,6 +311,85 @@ def unet_offload_device_patched():
238311
logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
239312
return device
240313

314+
def _patch_model_management_current_stream():
    """Make ComfyUI stream lookup honor the requested CUDA device.

    Returns True when the wrapper is (or already was) installed, False when
    comfy.model_management exposes no current_stream to patch. Idempotent:
    an already-installed wrapper is detected via its marker attribute.
    """
    original_stream = getattr(mm, "current_stream", None)
    if original_stream is None:
        return False
    if getattr(original_stream, "_multigpu_cuda_device_aware", False):
        # Already patched by a previous call.
        return True

    def current_stream_device_aware(device):
        resolved = _coerce_torch_device(device)
        if resolved is None or resolved.type != "cuda":
            # Non-CUDA (or unresolvable) devices keep the stock behavior.
            return original_stream(device)
        return torch.cuda.current_stream(device=resolved)

    current_stream_device_aware._multigpu_cuda_device_aware = True
    current_stream_device_aware._multigpu_original = original_stream
    mm.current_stream = current_stream_device_aware
    logger.info("[MultiGPU] Patched comfy.model_management.current_stream to honor CUDA device arguments")
    return True
333+
334+
def _initialize_aimdo_visible_cuda_devices():
    """Ensure DynamicVRAM initializes every visible CUDA device once when enabled.

    Returns True when at least one previously-uninitialized device was
    successfully initialized; False when the feature is disabled, CUDA or
    comfy_aimdo is unavailable, or nothing new was initialized. Successfully
    initialized indices are recorded in the module-level set so repeat calls
    skip them.
    """
    if not getattr(comfy.memory_management, "aimdo_enabled", False):
        logger.info("[MultiGPU] DynamicVRAM not enabled; skipping multi-device aimdo initialization")
        return False
    if not torch.cuda.is_available():
        logger.info("[MultiGPU] CUDA unavailable; skipping multi-device aimdo initialization")
        return False

    try:
        from comfy_aimdo import control as aimdo_control
    except ImportError:
        logger.warning("[MultiGPU] comfy_aimdo unavailable during multi-device initialization")
        return False

    init_device = getattr(aimdo_control, "init_device", None)
    if not callable(init_device):
        logger.warning("[MultiGPU] comfy_aimdo.control.init_device missing; skipping multi-device initialization")
        return False

    # Only devices not already initialized in an earlier call need work.
    pending_indices = [
        idx for idx in range(torch.cuda.device_count())
        if idx not in _aimdo_initialized_devices
    ]

    initialized_any = False
    for device_index in pending_indices:
        logger.info(f"[MultiGPU] Initializing comfy_aimdo for CUDA device {device_index}")
        initialized = bool(init_device(device_index))
        logger.info(f"[MultiGPU] comfy_aimdo init_device({device_index}) -> {initialized}")
        if initialized:
            _aimdo_initialized_devices.add(device_index)
            initialized_any = True

    return initialized_any
366+
367+
def _wrap_comfy_sample_entrypoint(attr_name):
    """Install a runtime-device guard around comfy_sample.<attr_name>.

    No-op when the attribute is missing, not callable, or already wrapped
    (detected via the _multigpu_runtime_device_guard marker). The original
    callable is kept on the wrapper as _multigpu_original for unpatching
    or inspection.
    """
    original_fn = getattr(comfy_sample, attr_name, None)
    if not callable(original_fn) or getattr(original_fn, "_multigpu_runtime_device_guard", False):
        return

    def wrapped_sample(model, *args, **kwargs):
        # Resolve the model's real load device so logical and CUDA device
        # state match for the duration of the sampling call.
        runtime_device = _get_runtime_device_from_model(model)
        reason = f"comfy.sample.{attr_name}:{type(model).__name__}"
        with multigpu_runtime_device_guard(runtime_device, reason=reason):
            return original_fn(model, *args, **kwargs)

    wrapped_sample._multigpu_runtime_device_guard = True
    wrapped_sample._multigpu_original = original_fn
    setattr(comfy_sample, attr_name, wrapped_sample)
    logger.info(f"[MultiGPU] Patched comfy.sample.{attr_name} with runtime device guard")


def _patch_comfy_sample_runtime_device():
    """Wrap Comfy sampling entrypoints so runtime device state matches the model load device.

    Applies the same guard to both comfy.sample.sample and
    comfy.sample.sample_custom via a shared helper, removing the previous
    copy-pasted wrapping logic.
    """
    for attr_name in ("sample", "sample_custom"):
        _wrap_comfy_sample_entrypoint(attr_name)
392+
241393
def _patch_comfy_kitchen_dlpack_device_guard():
242394
"""Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device."""
243395
try:
@@ -256,20 +408,8 @@ def _patch_comfy_kitchen_dlpack_device_guard():
256408

257409
def wrap_for_dlpack_with_device_guard(*args, **kwargs):
258410
tensor = args[0] if args else kwargs.get("tensor")
259-
previous_device_index = None
260-
switched_device = False
261-
262-
if isinstance(tensor, torch.Tensor) and tensor.is_cuda and tensor.device.index is not None:
263-
previous_device_index = torch.cuda.current_device()
264-
if previous_device_index != tensor.device.index:
265-
torch.cuda.set_device(tensor.device.index)
266-
switched_device = True
267-
268-
try:
411+
with cuda_device_guard(getattr(tensor, "device", None), reason="comfy_kitchen._wrap_for_dlpack"):
269412
return wrap_for_dlpack(*args, **kwargs)
270-
finally:
271-
if switched_device and previous_device_index is not None:
272-
torch.cuda.set_device(previous_device_index)
273413

274414
wrap_for_dlpack_with_device_guard._multigpu_cuda_device_guard = True
275415
comfy_kitchen_cuda._wrap_for_dlpack = wrap_for_dlpack_with_device_guard
@@ -284,7 +424,10 @@ def wrap_for_dlpack_with_device_guard(*args, **kwargs):
284424
mm.get_torch_device = get_torch_device_patched
285425
mm.text_encoder_device = text_encoder_device_patched
286426
mm.unet_offload_device = unet_offload_device_patched
427+
_patch_model_management_current_stream()
428+
_patch_comfy_sample_runtime_device()
287429
_patch_comfy_kitchen_dlpack_device_guard()
430+
_initialize_aimdo_visible_cuda_devices()
288431

289432
from .nodes import (
290433
DeviceSelectorMultiGPU,

wrappers.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
447447

448448
def override_class(cls):
449449
"""Standard MultiGPU device override for UNet/VAE models"""
450-
from . import set_current_device, get_current_device
450+
from . import set_current_device, get_current_device, cuda_device_guard
451451

452452
class NodeOverride(cls):
453453
@classmethod
@@ -466,18 +466,25 @@ def override(self, *args, device=None, **kwargs):
466466
original_device = get_current_device()
467467
if device is not None:
468468
set_current_device(device)
469+
target_device = device if device is not None else get_current_device()
469470
fn = getattr(super(), cls.FUNCTION)
470-
out = fn(*args, **kwargs)
471471
try:
472-
return out
472+
with cuda_device_guard(target_device, reason=f"{type(self).__name__}.{cls.FUNCTION}"):
473+
return fn(*args, **kwargs)
473474
finally:
474475
set_current_device(original_device)
475476

476477
return NodeOverride
477478

478479
def override_class_offload(cls):
479480
"""Standard MultiGPU device override for UNet/VAE models"""
480-
from . import set_current_device, set_current_unet_offload_device, get_current_device, get_current_unet_offload_device
481+
from . import (
482+
set_current_device,
483+
set_current_unet_offload_device,
484+
get_current_device,
485+
get_current_unet_offload_device,
486+
cuda_device_guard,
487+
)
481488

482489
class NodeOverride(cls):
483490
@classmethod
@@ -500,10 +507,11 @@ def override(self, *args, device=None, offload_device=None, **kwargs):
500507
set_current_device(device)
501508
if offload_device is not None:
502509
set_current_unet_offload_device(offload_device)
510+
target_device = device if device is not None else get_current_device()
503511
fn = getattr(super(), cls.FUNCTION)
504-
out = fn(*args, **kwargs)
505512
try:
506-
return out
513+
with cuda_device_guard(target_device, reason=f"{type(self).__name__}.{cls.FUNCTION}"):
514+
return fn(*args, **kwargs)
507515
finally:
508516
set_current_device(original_device)
509517
set_current_unet_offload_device(original_offload_device)

0 commit comments

Comments
 (0)