Skip to content

Commit 8c4034f

Browse files
authored
Merge pull request #180 from pollockjj/fix-dlpack-p2p-cu130
Fix DLPack P2P cross-device transfer (cu130). I believe this fixes the issue: verified on my own mixed-mode 5090/3090 setup, where all four device-pair transfer directions (0→0, 1→1, 0→1, 1→0) passed.
2 parents e64cdf7 + b526e11 commit 8c4034f

3 files changed

Lines changed: 109 additions & 5 deletions

File tree

__init__.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def sample_custom_with_runtime_device(model, *args, **kwargs):
390390
logger.info("[MultiGPU] Patched comfy.sample.sample_custom with runtime device guard")
391391

392392
def _patch_comfy_kitchen_dlpack_device_guard():
393-
"""Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device."""
393+
"""Guard comfy_kitchen DLPack export with P2P-aware CPU-staging fallback."""
394394
try:
395395
comfy_kitchen_cuda = importlib.import_module("comfy_kitchen.backends.cuda")
396396
except ImportError:
@@ -405,14 +405,43 @@ def _patch_comfy_kitchen_dlpack_device_guard():
405405
if getattr(wrap_for_dlpack, "_multigpu_cuda_device_guard", False):
406406
return True
407407

408+
from .p2p_registry import p2p_registry
409+
408410
def wrap_for_dlpack_with_device_guard(*args, **kwargs):
409411
tensor = args[0] if args else kwargs.get("tensor")
410-
with cuda_device_guard(getattr(tensor, "device", None), reason="comfy_kitchen._wrap_for_dlpack"):
411-
return wrap_for_dlpack(*args, **kwargs)
412+
tensor_device = getattr(tensor, "device", None)
413+
exec_device = get_current_device()
414+
exec_device = _coerce_torch_device(exec_device)
415+
416+
# Determine if cross-device staging is needed
417+
needs_staging = False
418+
def _valid_cuda(d):
419+
return d is not None and d.type == "cuda" and d.index is not None
420+
421+
if _valid_cuda(tensor_device) and _valid_cuda(exec_device):
422+
if tensor_device.index != exec_device.index and not p2p_registry.can_access_peer(tensor_device.index, exec_device.index):
423+
needs_staging = True
424+
425+
if needs_staging:
426+
logger.info(
427+
f"[MultiGPU DLPack] CPU-staging tensor from cuda:{tensor_device.index} "
428+
f"to cuda:{exec_device.index} (P2P unavailable)"
429+
)
430+
staged_tensor = tensor.to("cpu").to(exec_device)
431+
wrap_for_dlpack_with_device_guard._dlpack_staging_count += 1
432+
with cuda_device_guard(exec_device, reason="comfy_kitchen._wrap_for_dlpack(staged)"):
433+
if args:
434+
return wrap_for_dlpack(staged_tensor, *args[1:], **kwargs)
435+
else:
436+
return wrap_for_dlpack(staged_tensor, **kwargs)
437+
else:
438+
with cuda_device_guard(tensor_device, reason="comfy_kitchen._wrap_for_dlpack"):
439+
return wrap_for_dlpack(*args, **kwargs)
412440

413441
wrap_for_dlpack_with_device_guard._multigpu_cuda_device_guard = True
442+
wrap_for_dlpack_with_device_guard._dlpack_staging_count = 0
414443
comfy_kitchen_cuda._wrap_for_dlpack = wrap_for_dlpack_with_device_guard
415-
logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch")
444+
logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch (P2P-aware)")
416445
return True
417446

418447
logger.info("[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")

p2p_registry.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""P2P accessibility registry for multi-GPU DLPack operations.
2+
3+
Caches cudaDeviceCanAccessPeer results per GPU pair to avoid
4+
repeated CUDA runtime API calls.
5+
"""
6+
7+
import ctypes
8+
import logging
9+
import torch
10+
11+
logger = logging.getLogger("MultiGPU")
12+
13+
_libcudart = None
14+
15+
16+
def _get_libcudart():
    """Load the CUDA runtime library once and cache the handle.

    Tries the unversioned soname first, then common versioned sonames:
    runtime-only CUDA installs typically ship only ``libcudart.so.<major>``,
    while the bare ``libcudart.so`` symlink comes with the dev toolkit.

    Returns:
        ctypes.CDLL: handle to the CUDA runtime library.

    Raises:
        OSError: if no CUDA runtime library can be loaded (last loader
            error is chained as the cause).
    """
    global _libcudart
    if _libcudart is None:
        last_err = None
        for soname in ("libcudart.so", "libcudart.so.13",
                       "libcudart.so.12", "libcudart.so.11.0"):
            try:
                _libcudart = ctypes.CDLL(soname)
                break
            except OSError as err:
                last_err = err
        if _libcudart is None:
            raise OSError("Could not load the CUDA runtime library (libcudart)") from last_err
    return _libcudart
22+
23+
24+
class MultiGPUP2PRegistry:
    """Cached registry for CUDA peer-to-peer accessibility between GPU pairs.

    Uses the CUDA runtime cudaDeviceCanAccessPeer API directly via ctypes
    because torch.cuda.can_access_peer does not exist in PyTorch 2.10.0+.
    Results are cached per (src, dst) pair for the lifetime of the registry.
    """

    def __init__(self):
        # (src_index, dst_index) -> bool. Keys are directional: CUDA does
        # not guarantee that peer accessibility is symmetric.
        self._cache = {}

    @staticmethod
    def _raw_can_access_peer(device_a: int, device_b: int) -> bool:
        """Call cudaDeviceCanAccessPeer via ctypes. Returns True if P2P is available.

        Any failure — libcudart missing/unloadable, or a non-zero CUDA
        error code — is treated as "no P2P" rather than raised, so the
        DLPack wrapper can always fall back to CPU staging.
        """
        try:
            lib = _get_libcudart()
        except OSError as err:
            # Without the runtime library we cannot query P2P; assume the
            # safe answer (no P2P -> CPU staging) instead of crashing.
            logger.warning(
                f"[MultiGPU P2P] Could not load libcudart ({err}), assuming no P2P"
            )
            return False
        can_access = ctypes.c_int(0)
        result = lib.cudaDeviceCanAccessPeer(ctypes.byref(can_access), device_a, device_b)
        if result != 0:
            logger.warning(
                f"[MultiGPU P2P] cudaDeviceCanAccessPeer({device_a}, {device_b}) "
                f"returned error code {result}, assuming no P2P"
            )
            return False
        return bool(can_access.value)

    def can_access_peer(self, src_device: int, dst_device: int) -> bool:
        """Check if src_device can access dst_device memory via P2P.

        Results are cached per (src, dst) pair. Same-device queries are
        always True; when CUDA is unavailable the answer is cached False.
        """
        if src_device == dst_device:
            return True

        key = (src_device, dst_device)
        if key not in self._cache:
            if not torch.cuda.is_available():
                self._cache[key] = False
            else:
                result = self._raw_can_access_peer(src_device, dst_device)
                self._cache[key] = result
                logger.info(
                    f"[MultiGPU P2P] can_access_peer({src_device}, {dst_device}) = {result}"
                )
        return self._cache[key]

    def clear_cache(self):
        """Clear the P2P cache (useful for testing)."""
        self._cache.clear()
72+
73+
74+
# Module-level singleton shared by all importers; its P2P cache therefore
# persists for the lifetime of the process.
p2p_registry = MultiGPUP2PRegistry()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-multigpu"
33
description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes."
4-
version = "2.6.0"
4+
version = "2.6.1"
55
license = {file = "LICENSE"}
66

77
[project.urls]

0 commit comments

Comments
 (0)