55import copy
66import json
77import importlib
8+ from contextlib import contextmanager
89from datetime import datetime
910from pathlib import Path
1011import folder_paths
1112import comfy .model_management as mm
13+ import comfy .memory_management
1214import comfy .model_patcher
15+ import comfy .sample as comfy_sample
1316from nodes import NODE_CLASS_MAPPINGS as GLOBAL_NODE_CLASS_MAPPINGS
1417from .device_utils import (
1518 get_device_list ,
@@ -171,6 +174,9 @@ def check_module_exists(module_path):
# Snapshot ComfyUI's device choices at import time; these module-level values
# are the "current" context returned by the patched mm.* device functions.
current_device = mm.get_torch_device()
current_text_encoder_device = mm.text_encoder_device()
current_unet_offload_device = mm.unet_offload_device()
# CUDA ordinals that have already gone through aimdo init_device, so
# _initialize_aimdo_visible_cuda_devices() runs init at most once per device.
_aimdo_initialized_devices = set()
if isinstance(current_device, torch.device) and current_device.type == "cuda" and current_device.index is not None:
    # NOTE(review): the startup device is pre-marked as initialized —
    # presumably ComfyUI's own startup handles it; confirm against comfy_aimdo.
    _aimdo_initialized_devices.add(current_device.index)
174180
175181def set_current_device (device ):
176182 """Set the current device context for MultiGPU operations."""
@@ -205,6 +211,73 @@ def get_current_unet_offload_device():
205211 """Get the current UNet offload device context at runtime."""
206212 return current_unet_offload_device
207213
214+ def _coerce_torch_device (device ):
215+ """Best-effort conversion to torch.device for guard and patch helpers."""
216+ if device is None :
217+ return None
218+ if isinstance (device , torch .device ):
219+ return device
220+ try :
221+ return torch .device (device )
222+ except (TypeError , RuntimeError , ValueError ):
223+ return None
224+
@contextmanager
def cuda_device_guard(device, reason="runtime"):
    """Temporarily switch the real CUDA current device for non-primary execution paths.

    Yields the coerced torch.device (None when *device* cannot be interpreted)
    and restores the previously current CUDA device on exit, even on error.
    """
    target_device = _coerce_torch_device(device)

    # Only a CUDA device with an explicit index, on a CUDA-capable host, can
    # ever require moving the process-wide current device.
    eligible = (
        target_device is not None
        and target_device.type == "cuda"
        and target_device.index is not None
        and torch.cuda.is_available()
    )

    previous_index = None
    switched = False
    if eligible:
        previous_index = torch.cuda.current_device()
        if previous_index != target_device.index:
            logger.info(
                f"[MultiGPU CUDA Guard] Switching CUDA current device {previous_index} -> {target_device.index} ({reason})"
            )
            torch.cuda.set_device(target_device.index)
            switched = True

    try:
        yield target_device
    finally:
        # Undo the switch only if we actually performed one.
        if switched and previous_index is not None:
            torch.cuda.set_device(previous_index)
            logger.info(
                f"[MultiGPU CUDA Guard] Restored CUDA current device {target_device.index} -> {previous_index} ({reason})"
            )
254+
255+ def _get_runtime_device_from_model (model ):
256+ """Resolve the actual execution device from a model or patcher wrapper."""
257+ if hasattr (model , "load_device" ):
258+ return getattr (model , "load_device" )
259+ patcher = getattr (model , "patcher" , None )
260+ if patcher is not None and hasattr (patcher , "load_device" ):
261+ return patcher .load_device
262+ inner_model = getattr (model , "model" , None )
263+ if inner_model is not None and hasattr (inner_model , "load_device" ):
264+ return inner_model .load_device
265+ return None
266+
@contextmanager
def multigpu_runtime_device_guard(device, reason="runtime"):
    """Align MultiGPU logical device state with the real runtime device for inference.

    Temporarily points the module-level current device (and, via
    cuda_device_guard, the real CUDA current device) at *device*, then
    restores the previous logical device unconditionally on exit.
    """
    saved_device = get_current_device()
    # Prefer the coerced torch.device, but fall back to the raw argument so
    # non-coercible (yet non-None) values still update the logical state.
    requested = _coerce_torch_device(device) or device
    if requested is not None:
        set_current_device(requested)
        logger.info(f"[MultiGPU Runtime] Using runtime device {requested} ({reason})")
    try:
        with cuda_device_guard(requested, reason=reason):
            yield _coerce_torch_device(requested)
    finally:
        # Restore the logical device even if the guarded body raised.
        set_current_device(saved_device)
280+
208281def get_torch_device_patched ():
209282 """Return MultiGPU-aware device selection for patched mm.get_torch_device."""
210283 device = None
@@ -238,6 +311,85 @@ def unet_offload_device_patched():
238311 logger .debug (f"[MultiGPU Core Patching] unet_offload_device_patched returning device: { device } (current_unet_offload_device={ current_unet_offload_device } )" )
239312 return device
240313
def _patch_model_management_current_stream():
    """Make ComfyUI stream lookup honor the requested CUDA device.

    Returns True when the wrapper is (or already was) installed, False when
    comfy.model_management exposes no ``current_stream`` to wrap.
    """
    original = getattr(mm, "current_stream", None)
    if original is None:
        return False
    if getattr(original, "_multigpu_cuda_device_aware", False):
        # Idempotent: a previous call already installed the wrapper.
        return True

    def device_aware_current_stream(device):
        resolved = _coerce_torch_device(device)
        if resolved is None or resolved.type != "cuda":
            # Non-CUDA (or unrecognized) devices keep the stock behavior.
            return original(device)
        # Ask CUDA for the stream of the *requested* device, not the current one.
        return torch.cuda.current_stream(device=resolved)

    # Markers let later calls detect the patch and recover the original.
    device_aware_current_stream._multigpu_cuda_device_aware = True
    device_aware_current_stream._multigpu_original = original
    mm.current_stream = device_aware_current_stream
    logger.info("[MultiGPU] Patched comfy.model_management.current_stream to honor CUDA device arguments")
    return True
333+
def _initialize_aimdo_visible_cuda_devices():
    """Ensure DynamicVRAM initializes every visible CUDA device once when enabled.

    Returns True when at least one previously-uninitialized device was
    successfully initialized on this call; False on any skip condition or
    when nothing new was initialized.
    """
    if not getattr(comfy.memory_management, "aimdo_enabled", False):
        logger.info("[MultiGPU] DynamicVRAM not enabled; skipping multi-device aimdo initialization")
        return False
    if not torch.cuda.is_available():
        logger.info("[MultiGPU] CUDA unavailable; skipping multi-device aimdo initialization")
        return False

    try:
        from comfy_aimdo import control as aimdo_control
    except ImportError:
        # Best-effort: comfy_aimdo is optional, so its absence is only a warning.
        logger.warning("[MultiGPU] comfy_aimdo unavailable during multi-device initialization")
        return False

    init_fn = getattr(aimdo_control, "init_device", None)
    if not callable(init_fn):
        logger.warning("[MultiGPU] comfy_aimdo.control.init_device missing; skipping multi-device initialization")
        return False

    any_initialized = False
    # Walk every visible CUDA ordinal; the module-level set keeps this idempotent.
    for index in range(torch.cuda.device_count()):
        if index in _aimdo_initialized_devices:
            continue
        logger.info(f"[MultiGPU] Initializing comfy_aimdo for CUDA device {index}")
        success = bool(init_fn(index))
        logger.info(f"[MultiGPU] comfy_aimdo init_device({index}) -> {success}")
        if success:
            _aimdo_initialized_devices.add(index)
            any_initialized = True

    return any_initialized
366+
def _patch_comfy_sample_runtime_device():
    """Wrap Comfy sampling entrypoints so runtime device state matches the model load device.

    Installs an identical runtime-device guard around ``comfy.sample.sample``
    and ``comfy.sample.sample_custom``. Idempotent: entrypoints already
    carrying the guard marker are left untouched. Returns None.
    """

    def _install_guard(attr_name):
        # One installer for both entrypoints — the two wrappers were
        # previously duplicated verbatim except for the attribute name.
        original_fn = getattr(comfy_sample, attr_name, None)
        if not callable(original_fn):
            return
        if getattr(original_fn, "_multigpu_runtime_device_guard", False):
            # Already patched (e.g. module re-import); avoid double-wrapping.
            return

        def guarded(model, *args, **kwargs):
            runtime_device = _get_runtime_device_from_model(model)
            with multigpu_runtime_device_guard(runtime_device, reason=f"comfy.sample.{attr_name}:{type(model).__name__}"):
                return original_fn(model, *args, **kwargs)

        # Markers let later calls detect the patch and recover the original.
        guarded._multigpu_runtime_device_guard = True
        guarded._multigpu_original = original_fn
        setattr(comfy_sample, attr_name, guarded)
        logger.info(f"[MultiGPU] Patched comfy.sample.{attr_name} with runtime device guard")

    _install_guard("sample")
    _install_guard("sample_custom")
392+
241393def _patch_comfy_kitchen_dlpack_device_guard ():
242394 """Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device."""
243395 try :
@@ -256,20 +408,8 @@ def _patch_comfy_kitchen_dlpack_device_guard():
256408
257409 def wrap_for_dlpack_with_device_guard (* args , ** kwargs ):
258410 tensor = args [0 ] if args else kwargs .get ("tensor" )
259- previous_device_index = None
260- switched_device = False
261-
262- if isinstance (tensor , torch .Tensor ) and tensor .is_cuda and tensor .device .index is not None :
263- previous_device_index = torch .cuda .current_device ()
264- if previous_device_index != tensor .device .index :
265- torch .cuda .set_device (tensor .device .index )
266- switched_device = True
267-
268- try :
411+ with cuda_device_guard (getattr (tensor , "device" , None ), reason = "comfy_kitchen._wrap_for_dlpack" ):
269412 return wrap_for_dlpack (* args , ** kwargs )
270- finally :
271- if switched_device and previous_device_index is not None :
272- torch .cuda .set_device (previous_device_index )
273413
274414 wrap_for_dlpack_with_device_guard ._multigpu_cuda_device_guard = True
275415 comfy_kitchen_cuda ._wrap_for_dlpack = wrap_for_dlpack_with_device_guard
@@ -284,7 +424,10 @@ def wrap_for_dlpack_with_device_guard(*args, **kwargs):
# Apply the MultiGPU patches at import time: device-selection overrides first,
# then the stream/sampling/DLPack wrappers, then one-time aimdo device init.
mm.get_torch_device = get_torch_device_patched
mm.text_encoder_device = text_encoder_device_patched
mm.unet_offload_device = unet_offload_device_patched
_patch_model_management_current_stream()
_patch_comfy_sample_runtime_device()
_patch_comfy_kitchen_dlpack_device_guard()
_initialize_aimdo_visible_cuda_devices()
288431
289432from .nodes import (
290433 DeviceSelectorMultiGPU ,
0 commit comments