Skip to content

Commit 7d24e15

Browse files
authored
Merge pull request #183 from pollockjj/issue-178
Fix Windows dynamic clip loading stack overflow (#178)
2 parents 8c4034f + e507f85 commit 7d24e15

3 files changed

Lines changed: 144 additions & 1 deletion

File tree

clip_dynamic_load_list_guard.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import logging
2+
3+
import comfy.model_management
4+
import comfy.model_patcher
5+
from comfy.model_patcher import QuantizedTensor, get_key_weight, low_vram_patch_estimate_vram
6+
7+
8+
# Shared logger for all MultiGPU guard messages.
logger = logging.getLogger("MultiGPU")

# Attribute stamped onto the replacement _load_list so repeated calls to
# register_clip_dynamic_load_list_guard() detect the patch and become no-ops.
_PATCH_MARKER = "_mgpu_issue21_clip_dynamic_load_list_guard"
# The non-recursive guard activates once the model graph exceeds either limit.
# NOTE(review): 200 presumably leaves headroom under CPython's default
# recursion limit on Windows (issue #178) — confirm against the original report.
_MODULE_THRESHOLD = 200
_DEPTH_THRESHOLD = 200
15+
def _iter_named_modules_nonrecursive(module):
16+
stack = [("", module)]
17+
seen = set()
18+
while stack:
19+
prefix, current = stack.pop()
20+
current_id = id(current)
21+
if current_id in seen:
22+
continue
23+
seen.add(current_id)
24+
yield prefix, current
25+
children = list(current._modules.items())
26+
for child_name, child in reversed(children):
27+
if child is None:
28+
continue
29+
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
30+
stack.append((child_prefix, child))
31+
32+
33+
def _iter_named_parameters_nonrecursive(module):
34+
stack = [("", module)]
35+
seen = set()
36+
while stack:
37+
prefix, current = stack.pop()
38+
for name, param in current._parameters.items():
39+
if param is None:
40+
continue
41+
param_id = id(param)
42+
if param_id in seen:
43+
continue
44+
seen.add(param_id)
45+
full_name = f"{prefix}.{name}" if prefix else name
46+
yield full_name, param
47+
children = list(current._modules.items())
48+
for child_name, child in reversed(children):
49+
if child is None:
50+
continue
51+
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
52+
stack.append((child_prefix, child))
53+
54+
55+
def _graph_requires_guard(module):
56+
stack = [(module, 0)]
57+
seen = set()
58+
module_count = 0
59+
max_depth = 0
60+
61+
while stack:
62+
current, depth = stack.pop()
63+
current_id = id(current)
64+
if current_id in seen:
65+
continue
66+
seen.add(current_id)
67+
module_count += 1
68+
max_depth = max(max_depth, depth)
69+
if module_count > _MODULE_THRESHOLD or max_depth > _DEPTH_THRESHOLD:
70+
return True
71+
for child in current._modules.values():
72+
if child is not None:
73+
stack.append((child, depth + 1))
74+
75+
return False
76+
77+
78+
def _safe_dynamic_load_list(self, default_device=None):
    """Non-recursive replacement body for ``ModelPatcherDynamic._load_list``.

    Builds the list of (sort-criteria, size, name, module, params) tuples the
    dynamic loader consumes, using the iterative module/parameter walkers so
    deep CLIP graphs cannot overflow the call stack (issue #178).

    NOTE(review): this mirrors comfy's recursive implementation — verify
    against the upstream ``_load_list`` when comfy updates it.

    Args:
        self: a ModelPatcherDynamic; ``self.model`` and ``self.patches`` are read.
        default_device: device to move "default"-handled module parameters to;
            no move happens when None.

    Returns:
        list of tuples ``(offload_flag, -offload_mem, module_mem, name, module, params)``.
    """
    loading = []
    for n, m in _iter_named_modules_nonrecursive(self.model):
        default = False
        # Only the module's own (non-recursive) parameters.
        params = dict(m.named_parameters(recurse=False))
        if params:
            # If the full subtree exposes any parameter name that is not a
            # direct parameter of m, treat m as "default"-handled.
            # NOTE(review): presumably marks modules whose parameters are
            # really owned by descendants — confirm against comfy upstream.
            for name, _ in _iter_named_parameters_nonrecursive(m):
                if name not in params:
                    default = True
                    break

        if default and default_device is not None:
            # Move the module's own parameters to the requested device,
            # restoring the dtype recorded by comfy (attribute may be absent,
            # in which case dtype=None leaves the dtype unchanged).
            for param_name, param in params.items():
                param.data = param.data.to(
                    device=default_device,
                    dtype=getattr(m, param_name + "_comfy_model_dtype", None),
                )

        # Only non-default modules that either cast weights or own parameters
        # participate in dynamic load ordering.
        if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
            module_mem = comfy.model_management.module_size(m)
            module_offload_mem = module_mem
            if hasattr(m, "comfy_cast_weights"):

                def check_module_offload_mem(key):
                    # Extra VRAM needed for key: patched keys use comfy's
                    # estimate; otherwise account for a dtype cast (or
                    # dequantization) into the model's manual_cast_dtype.
                    if key in self.patches:
                        return low_vram_patch_estimate_vram(self.model, key)
                    model_dtype = getattr(self.model, "manual_cast_dtype", None)
                    weight, _, _ = get_key_weight(self.model, key)
                    if model_dtype is None or weight is None:
                        return 0
                    if weight.dtype != model_dtype or isinstance(weight, QuantizedTensor):
                        return weight.numel() * model_dtype.itemsize
                    return 0

                module_offload_mem += check_module_offload_mem(f"{n}.weight")
                module_offload_mem += check_module_offload_mem(f"{n}.bias")

            # Sort key: modules >= 64 KiB first (True sorts after False when
            # ascending — NOTE(review): confirm the consumer's sort direction),
            # then by descending offload cost.
            sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
            loading.append(sort_criteria + (module_mem, n, m, params))

    return loading
121+
def register_clip_dynamic_load_list_guard():
    """Install the stack-overflow guard on ``ModelPatcherDynamic._load_list``.

    Wraps the stock method so dynamic loads over oversized or overly deep
    module graphs are routed through the non-recursive
    ``_safe_dynamic_load_list``.  Idempotent: a second call finds the marker
    attribute on the installed wrapper and returns False; a fresh install
    returns True.
    """
    stock = comfy.model_patcher.ModelPatcherDynamic._load_list
    if getattr(stock, _PATCH_MARKER, False):
        # Already wrapped — nothing to do.
        return False

    def guarded_load_list(self, for_dynamic=False, default_device=None):
        # Only dynamic loads over large/deep graphs take the guarded path;
        # everything else falls through to the stock implementation.
        if for_dynamic and _graph_requires_guard(self.model):
            logger.info("[MultiGPU Issue21] Using non-recursive ModelPatcherDynamic._load_list guard")
            return _safe_dynamic_load_list(self, default_device=default_device)
        return stock(self, for_dynamic=for_dynamic, default_device=default_device)

    setattr(guarded_load_list, _PATCH_MARKER, True)
    comfy.model_patcher.ModelPatcherDynamic._load_list = guarded_load_list
    logger.info("[MultiGPU Issue21] Registered ModelPatcherDynamic._load_list guard")
    return True

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-multigpu"
33
description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes."
4-
version = "2.6.1"
4+
version = "2.6.2"
55
license = {file = "LICENSE"}
66

77
[project.urls]

wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ def override(self, *args, device=None, offload_device=None, **kwargs):
522522
def override_class_clip(cls):
523523
"""Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""
524524
from . import set_current_text_encoder_device, get_current_text_encoder_device
525+
from .clip_dynamic_load_list_guard import register_clip_dynamic_load_list_guard
525526

526527
class NodeOverride(cls):
527528
@classmethod
@@ -540,6 +541,7 @@ def override(self, *args, device=None, **kwargs):
540541
original_text_device = get_current_text_encoder_device()
541542
if device is not None:
542543
set_current_text_encoder_device(device)
544+
register_clip_dynamic_load_list_guard()
543545
kwargs['device'] = 'default'
544546
fn = getattr(super(), cls.FUNCTION)
545547
out = fn(*args, **kwargs)
@@ -554,6 +556,7 @@ def override(self, *args, device=None, **kwargs):
554556
def override_class_clip_no_device(cls):
555557
"""Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)"""
556558
from . import set_current_text_encoder_device, get_current_text_encoder_device
559+
from .clip_dynamic_load_list_guard import register_clip_dynamic_load_list_guard
557560

558561
class NodeOverride(cls):
559562
@classmethod
@@ -572,6 +575,7 @@ def override(self, *args, device=None, **kwargs):
572575
original_text_device = get_current_text_encoder_device()
573576
if device is not None:
574577
set_current_text_encoder_device(device)
578+
register_clip_dynamic_load_list_guard()
575579
fn = getattr(super(), cls.FUNCTION)
576580
out = fn(*args, **kwargs)
577581
try:

0 commit comments

Comments
 (0)