Commit 4d0d4a6

fix for issue #87: ComfyUI-MultiGPU not supporting all device types currently supported by Comfy Core.
Refactor device detection into dedicated utility module
- Extract device enumeration and compatibility checks to device_utils.py
- Add support for additional device types (NPU, MLU, DirectML, CoreX)
- Update all modules to use centralized device utilities
- Implement caching for device list to improve performance
- Reduce code duplication across distorch, nodes, and wanvideo modules
1 parent 06bc2c3 commit 4d0d4a6

7 files changed

Lines changed: 242 additions & 62 deletions
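
For orientation, a minimal sketch of how a consuming module calls the centralized utilities after this refactor. The pick_compute_device wrapper is hypothetical and not part of the commit; get_device_list and is_accelerator_available are the actual entry points added in device_utils.py:

    # Hypothetical usage sketch. Inside the package the import is relative:
    # from .device_utils import get_device_list, is_accelerator_available
    from device_utils import get_device_list, is_accelerator_available

    def pick_compute_device():
        # Mirrors the CPU fallback in the patched functions in __init__.py:
        # no accelerator of any supported type means everything stays on CPU.
        if not is_accelerator_available():
            return "cpu"
        # get_device_list() is computed once and cached; it might return
        # e.g. ["cpu", "cuda:0", "cuda:1"] on a dual-GPU NVIDIA machine.
        return next((d for d in get_device_list() if d != "cpu"), "cpu")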

__init__.py

Lines changed: 4 additions & 28 deletions
@@ -6,6 +6,7 @@
 import folder_paths
 import comfy.model_management as mm
 from nodes import NODE_CLASS_MAPPINGS as GLOBAL_NODE_CLASS_MAPPINGS
+from .device_utils import get_device_list, is_accelerator_available

 # --- DisTorch V2 Logging Configuration ---
 # Set to "E" for Engineering (DEBUG) or "P" for Production (INFO)
@@ -29,31 +30,6 @@
 current_device = mm.get_torch_device()
 current_text_encoder_device = mm.text_encoder_device()

-def _has_xpu():
-    try:
-        return hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available()
-    except Exception:
-        return False
-
-def get_device_list():
-    devs = ["cpu"]
-    try:
-        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
-            devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
-    except Exception:
-        pass
-    try:
-        if _has_xpu():
-            devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())]
-    except Exception:
-        pass
-    try:
-        if torch.backends.mps.is_available():
-            devs += ["mps"]
-    except Exception:
-        pass
-    return devs
-
 def set_current_device(device):
     global current_device
     current_device = device
@@ -119,7 +95,7 @@ def override(self, *args, device=None, **kwargs):

 def get_torch_device_patched():
     device = None
-    if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()):
+    if (not is_accelerator_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()):
         device = torch.device("cpu")
     else:
         devs = set(get_device_list())
@@ -129,7 +105,7 @@ def get_torch_device_patched():

 def text_encoder_device_patched():
     device = None
-    if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()):
+    if (not is_accelerator_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()):
         device = torch.device("cpu")
     else:
         devs = set(get_device_list())
@@ -367,4 +343,4 @@ def register_and_count(module_names, node_map):
 logger.info(dash_line)


-logger.info(f"[MultiGPU] Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")
+logger.info(f"[MultiGPU] Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")

device_utils.py

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+"""
+Device detection and management utilities for ComfyUI-MultiGPU.
+Single source of truth for all device enumeration and compatibility checks.
+Handles all device types supported by ComfyUI core.
+"""
+
+import torch
+import logging
+
+logger = logging.getLogger("MultiGPU")
+
+# Module-level cache for device list (populated once on first call)
+_DEVICE_LIST_CACHE = None
+
+def get_device_list():
+    """
+    Enumerate ALL physically available devices that can store torch tensors.
+    This includes all device types supported by ComfyUI core.
+    Results are cached after first call since devices don't change during runtime.
+
+    Returns a comprehensive list of all available devices across all types:
+    CPU (always available)
+    CUDA devices (NVIDIA GPUs)
+    XPU devices (Intel GPUs)
+    NPU devices (Ascend NPUs from Huawei)
+    MLU devices (Cambricon MLUs)
+    MPS device (Apple Metal)
+    DirectML devices (Windows DirectML)
+    CoreX/IXUCA devices
+    """
+    global _DEVICE_LIST_CACHE
+
+    # Return cached result if already populated
+    if _DEVICE_LIST_CACHE is not None:
+        return _DEVICE_LIST_CACHE
+
+    # First time - do the actual detection
+    devs = []
+
+    # CPU is always physically present and can store tensors
+    devs.append("cpu")
+
+    # CUDA devices (NVIDIA GPUs)
+    try:
+        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
+            device_count = torch.cuda.device_count()
+            devs += [f"cuda:{i}" for i in range(device_count)]
+            logger.debug(f"[MultiGPU] Found {device_count} CUDA device(s)")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] CUDA detection failed: {e}")
+
+    # XPU devices (Intel GPUs)
+    try:
+        # Try to import intel extension first (may be required for XPU support)
+        import intel_extension_for_pytorch as ipex
+    except ImportError:
+        pass
+    try:
+        if hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available():
+            device_count = torch.xpu.device_count()
+            devs += [f"xpu:{i}" for i in range(device_count)]
+            logger.debug(f"[MultiGPU] Found {device_count} XPU device(s)")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] XPU detection failed: {e}")
+
+    # NPU devices (Ascend NPUs from Huawei)
+    try:
+        import torch_npu
+        if hasattr(torch, "npu") and hasattr(torch.npu, "is_available") and torch.npu.is_available():
+            device_count = torch.npu.device_count()
+            devs += [f"npu:{i}" for i in range(device_count)]
+            logger.debug(f"[MultiGPU] Found {device_count} NPU device(s)")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] NPU detection failed: {e}")
+
+    # MLU devices (Cambricon MLUs)
+    try:
+        import torch_mlu
+        if hasattr(torch, "mlu") and hasattr(torch.mlu, "is_available") and torch.mlu.is_available():
+            device_count = torch.mlu.device_count()
+            devs += [f"mlu:{i}" for i in range(device_count)]
+            logger.debug(f"[MultiGPU] Found {device_count} MLU device(s)")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] MLU detection failed: {e}")
+
+    # MPS device (Apple Metal - single device only)
+    try:
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            devs.append("mps")
+            logger.debug("[MultiGPU] Found MPS device")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] MPS detection failed: {e}")
+
+    # DirectML devices (Windows DirectML for AMD/Intel/NVIDIA)
+    try:
+        import torch_directml
+        adapter_count = torch_directml.device_count()
+        if adapter_count > 0:
+            devs += [f"directml:{i}" for i in range(adapter_count)]
+            logger.debug(f"[MultiGPU] Found {adapter_count} DirectML adapter(s)")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] DirectML detection failed: {e}")
+
+    # IXUCA/CoreX devices (special accelerator)
+    try:
+        if hasattr(torch, "corex"):
+            # CoreX typically exposes single device, but check if there's a count method
+            if hasattr(torch.corex, "device_count"):
+                device_count = torch.corex.device_count()
+                devs += [f"corex:{i}" for i in range(device_count)]
+                logger.debug(f"[MultiGPU] Found {device_count} CoreX device(s)")
+            else:
+                devs.append("corex:0")
+                logger.debug("[MultiGPU] Found CoreX device")
+    except Exception as e:
+        logger.debug(f"[MultiGPU] CoreX detection failed: {e}")
+
+    # Cache the result for future calls
+    _DEVICE_LIST_CACHE = devs
+
+    # Log only once when initially populated
+    logger.info(f"[MultiGPU] Device list initialized: {devs}")
+
+    return devs
+
+
+def is_accelerator_available():
+    """
+    Check if any accelerator device is available.
+    Used by patched functions to determine CPU fallback.
+
+    Returns True if any GPU/accelerator is available, False otherwise.
+    """
+    # Check CUDA
+    try:
+        if torch.cuda.is_available():
+            return True
+    except:
+        pass
+
+    # Check XPU (Intel GPU)
+    try:
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            return True
+    except:
+        pass
+
+    # Check NPU (Ascend)
+    try:
+        import torch_npu
+        if hasattr(torch, "npu") and torch.npu.is_available():
+            return True
+    except:
+        pass
+
+    # Check MLU (Cambricon)
+    try:
+        import torch_mlu
+        if hasattr(torch, "mlu") and torch.mlu.is_available():
+            return True
+    except:
+        pass
+
+    # Check MPS (Apple Metal)
+    try:
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return True
+    except:
+        pass
+
+    # Check DirectML
+    try:
+        import torch_directml
+        if torch_directml.device_count() > 0:
+            return True
+    except:
+        pass
+
+    # Check CoreX/IXUCA
+    try:
+        if hasattr(torch, "corex"):
+            return True
+    except:
+        pass
+
+    return False
+
+
+def is_device_compatible(device_string):
+    """
+    Check if a device string represents a valid, available device.
+
+    Args:
+        device_string: Device identifier like "cuda:0", "cpu", "xpu:1", etc.
+
+    Returns:
+        True if the device is available, False otherwise.
+    """
+    available_devices = get_device_list()
+    return device_string in available_devices
+
+
+def get_device_type(device_string):
+    """
+    Extract the device type from a device string.
+
+    Args:
+        device_string: Device identifier like "cuda:0", "cpu", "xpu:1", etc.
+
+    Returns:
+        Device type string (e.g., "cuda", "cpu", "xpu", "npu", "mlu", "mps", "directml", "corex")
+    """
+    if ":" in device_string:
+        return device_string.split(":")[0]
+    return device_string
+
+
+def parse_device_string(device_string):
+    """
+    Parse a device string into type and index.
+
+    Args:
+        device_string: Device identifier like "cuda:0", "cpu", "xpu:1", etc.
+
+    Returns:
+        Tuple of (device_type, device_index) where index is None for non-indexed devices
+    """
+    if ":" in device_string:
+        parts = device_string.split(":")
+        return parts[0], int(parts[1])
+    return device_string, None
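
To show how the helper functions above compose, a short sketch; the printed values assume a machine whose only accelerator is a single NVIDIA GPU:

    # Assumes get_device_list() enumerates ["cpu", "cuda:0"] on this machine.
    from device_utils import (get_device_list, is_device_compatible,
                              get_device_type, parse_device_string)

    devices = get_device_list()             # ["cpu", "cuda:0"], cached after first call
    print(is_device_compatible("cuda:0"))   # True  - present in the enumerated list
    print(is_device_compatible("xpu:1"))    # False - never detected on this machine
    print(get_device_type("cuda:0"))        # "cuda"
    print(parse_device_string("cuda:0"))    # ("cuda", 0)
    print(parse_device_string("mps"))       # ("mps", None) - non-indexed device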

distorch.py

Lines changed: 3 additions & 5 deletions
@@ -12,6 +12,7 @@
 import copy
 from collections import defaultdict
 import comfy.model_management as mm
+from .device_utils import get_device_list

 # Global store for model allocations
 model_allocation_store = {}
@@ -292,7 +293,6 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):

 def override_class_with_distorch_gguf(cls):
     """Legacy DisTorch wrapper for GGUF models for backward compatibility."""
-    from .nodes import get_device_list
     from . import current_device

     class NodeOverrideDisTorchGGUFLegacy(cls):
@@ -330,7 +330,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v
             vram_string = ""
             if virtual_vram_gb > 0:
                 if use_other_vram:
-                    available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))]
+                    available_devices = [d for d in get_device_list() if d != "cpu"]
                     other_devices = [d for d in available_devices if d != device]
                     other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
                     device_string = ','.join(other_devices + ['cpu'])
@@ -354,7 +354,6 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v

 def override_class_with_distorch_gguf_v2(cls):
     """DisTorch 2.0 wrapper for GGUF models."""
-    from .nodes import get_device_list
     from . import current_device

     class NodeOverrideDisTorchGGUFv2(cls):
@@ -406,7 +405,6 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0,

 def override_class_with_distorch_clip(cls):
     """DisTorch wrapper for CLIP models with GGUF support"""
-    from .nodes import get_device_list
     from . import current_text_encoder_device

     class NodeOverrideDisTorch(cls):
@@ -441,7 +439,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v
             vram_string = ""
             if virtual_vram_gb > 0:
                 if use_other_vram:
-                    available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))]
+                    available_devices = [d for d in get_device_list() if d != "cpu"]
                     other_devices = [d for d in available_devices if d != device]
                     other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
                     device_string = ','.join(other_devices + ['cpu'])
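
The use_other_vram change above is the behavioral core of the fix: the old filter admitted only CUDA and XPU devices as virtual VRAM donors, so on NPU- or MLU-only machines no donor was ever found. A sketch with a hypothetical mixed device list:

    # Hypothetical device list for a machine with one NVIDIA GPU and one
    # Ascend NPU; "device" is the compute device selected on the node.
    device_list = ["cpu", "cuda:0", "npu:0"]
    device = "cuda:0"

    old = [d for d in device_list if d.startswith(("cuda", "xpu"))]  # ["cuda:0"] - NPU ignored
    new = [d for d in device_list if d != "cpu"]                     # ["cuda:0", "npu:0"]

    other_devices = [d for d in new if d != device]                  # ["npu:0"]
    device_string = ','.join(other_devices + ['cpu'])                # "npu:0,cpu"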

distorch_2.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,7 @@
 import comfy.model_management as mm
 import comfy.model_patcher
 from . import current_device
-from .nodes import get_device_list
+from .device_utils import get_device_list

 safetensor_allocation_store = {}
 safetensor_settings_store = {}
@@ -549,7 +549,6 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):

 def override_class_with_distorch_safetensor_v2(cls):
     """DisTorch 2.0 wrapper for safetensor models"""
-    from .nodes import get_device_list
     from . import current_device

     class NodeOverrideDisTorchSafetensorV2(cls):

nodes.py

Lines changed: 1 addition & 20 deletions
@@ -2,26 +2,7 @@
 import folder_paths
 from pathlib import Path
 from nodes import NODE_CLASS_MAPPINGS
-
-def _has_xpu():
-    try:
-        return hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available()
-    except Exception:
-        return False
-
-def get_device_list():
-    devs = ["cpu"]
-    try:
-        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
-            devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
-    except Exception:
-        pass
-    try:
-        if _has_xpu():
-            devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())]
-    except Exception:
-        pass
-    return devs
+from .device_utils import get_device_list

 class DeviceSelectorMultiGPU:
     @classmethod
