Skip to content

Commit 17a50f1

Browse files
committed
Fix Windows issue 178 dynamic clip loading
1 parent 8c4034f commit 17a50f1

2 files changed

Lines changed: 143 additions & 0 deletions

File tree

clip_dynamic_load_list_guard.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import logging
2+
3+
import comfy.model_management
4+
import comfy.model_patcher
5+
from comfy.model_patcher import QuantizedTensor, get_key_weight, low_vram_patch_estimate_vram
6+
7+
8+
logger = logging.getLogger("MultiGPU")
9+
10+
_PATCH_MARKER = "_mgpu_issue21_clip_dynamic_load_list_guard"
11+
_MODULE_THRESHOLD = 200
12+
_DEPTH_THRESHOLD = 200
13+
14+
15+
def _iter_named_modules_nonrecursive(module):
16+
stack = [("", module)]
17+
seen = set()
18+
while stack:
19+
prefix, current = stack.pop()
20+
current_id = id(current)
21+
if current_id in seen:
22+
continue
23+
seen.add(current_id)
24+
yield prefix, current
25+
children = list(current._modules.items())
26+
for child_name, child in reversed(children):
27+
if child is None:
28+
continue
29+
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
30+
stack.append((child_prefix, child))
31+
32+
33+
def _iter_named_parameters_nonrecursive(module):
34+
stack = [("", module)]
35+
seen = set()
36+
while stack:
37+
prefix, current = stack.pop()
38+
for name, param in current._parameters.items():
39+
if param is None:
40+
continue
41+
param_id = id(param)
42+
if param_id in seen:
43+
continue
44+
seen.add(param_id)
45+
full_name = f"{prefix}.{name}" if prefix else name
46+
yield full_name, param
47+
children = list(current._modules.items())
48+
for child_name, child in reversed(children):
49+
if child is None:
50+
continue
51+
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
52+
stack.append((child_prefix, child))
53+
54+
55+
def _graph_requires_guard(module):
    """Return ``True`` when the module graph is large or deep enough that
    the non-recursive ``_load_list`` guard should be used.

    Walks the child-module graph iteratively (identity-deduplicated),
    counting distinct modules and the maximum depth reached, and trips as
    soon as the count exceeds ``_MODULE_THRESHOLD`` or the depth exceeds
    ``_DEPTH_THRESHOLD``. Returns ``False`` if the full walk stays within
    both limits.
    """
    visited = set()
    pending = [(module, 0)]
    total = 0
    deepest = 0

    while pending:
        node, level = pending.pop()
        key = id(node)
        if key in visited:
            continue
        visited.add(key)
        total += 1
        if level > deepest:
            deepest = level
        if total > _MODULE_THRESHOLD or deepest > _DEPTH_THRESHOLD:
            return True
        pending.extend(
            (sub, level + 1) for sub in node._modules.values() if sub is not None
        )

    return False
76+
77+
78+
def _safe_dynamic_load_list(self, default_device=None):
    """Non-recursive replacement for ModelPatcherDynamic._load_list.

    Builds the list of modules to load for a dynamic load, walking
    ``self.model`` with the iterative helpers above instead of torch's
    recursive traversal, so very deep module graphs don't hit the Python
    recursion limit.

    Returns a list of tuples
    ``(offload_flag, -module_offload_mem, module_mem, name, module, params)``
    whose leading elements act as sort criteria for the caller.
    NOTE(review): assumed to mirror the tuple layout comfy's original
    ``_load_list`` produces — confirm against the comfy version in use.
    """
    loading = []
    for n, m in _iter_named_modules_nonrecursive(self.model):
        default = False
        # Parameters owned directly by this module (non-recursive view).
        params = dict(m.named_parameters(recurse=False))
        if params:
            # If the full (recursive) walk yields any parameter name not in
            # the direct set, this module also has nested parameters below
            # it; mark it to take the "default device" path instead.
            for name, _ in _iter_named_parameters_nonrecursive(m):
                if name not in params:
                    default = True
                    break

        if default and default_device is not None:
            # Move this module's own parameters to the default device.
            # NOTE(review): dtype comes from a per-parameter
            # "<name>_comfy_model_dtype" attribute when present; ``None``
            # leaves the dtype unchanged in ``Tensor.to`` — presumably
            # intentional, verify against comfy's casting convention.
            for param_name, param in params.items():
                param.data = param.data.to(
                    device=default_device,
                    dtype=getattr(m, param_name + "_comfy_model_dtype", None),
                )

        # Only leaf-ish modules (no nested params) that either support
        # comfy weight casting or own parameters are queued for loading.
        if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
            module_mem = comfy.model_management.module_size(m)
            module_offload_mem = module_mem
            if hasattr(m, "comfy_cast_weights"):

                def check_module_offload_mem(key):
                    # Extra VRAM needed for this key: patch estimate if the
                    # key is patched, otherwise the cost of casting the
                    # weight to the model's manual_cast_dtype (quantized
                    # tensors always count as needing a cast).
                    if key in self.patches:
                        return low_vram_patch_estimate_vram(self.model, key)
                    model_dtype = getattr(self.model, "manual_cast_dtype", None)
                    weight, _, _ = get_key_weight(self.model, key)
                    if model_dtype is None or weight is None:
                        return 0
                    if weight.dtype != model_dtype or isinstance(weight, QuantizedTensor):
                        return weight.numel() * model_dtype.itemsize
                    return 0

                module_offload_mem += check_module_offload_mem(f"{n}.weight")
                module_offload_mem += check_module_offload_mem(f"{n}.bias")

            # Sort key: modules at/above 64 KiB first (True sorts after
            # False ascending — NOTE(review): ordering semantics depend on
            # how the caller sorts; confirm), then by descending offload
            # size via the negated value.
            sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
            loading.append(sort_criteria + (module_mem, n, m, params))

    return loading
119+
120+
121+
def register_clip_dynamic_load_list_guard():
    """Install the non-recursive ``_load_list`` guard on ModelPatcherDynamic.

    Wraps the original method so that dynamic loads on module graphs that
    trip ``_graph_requires_guard`` are served by ``_safe_dynamic_load_list``;
    everything else is delegated to the original implementation. Idempotent:
    returns ``False`` when the guard is already installed (detected via the
    marker attribute), ``True`` once it has been registered.
    """
    current = comfy.model_patcher.ModelPatcherDynamic._load_list
    if getattr(current, _PATCH_MARKER, False):
        # A previous call already wrapped the method; nothing to do.
        return False

    def guarded_load_list(self, for_dynamic=False, default_device=None):
        if for_dynamic and _graph_requires_guard(self.model):
            logger.info("[MultiGPU Issue21] Using non-recursive ModelPatcherDynamic._load_list guard")
            return _safe_dynamic_load_list(self, default_device=default_device)
        return current(self, for_dynamic=for_dynamic, default_device=default_device)

    # Mark the wrapper so repeated registration is a no-op.
    setattr(guarded_load_list, _PATCH_MARKER, True)
    comfy.model_patcher.ModelPatcherDynamic._load_list = guarded_load_list
    logger.info("[MultiGPU Issue21] Registered ModelPatcherDynamic._load_list guard")
    return True

wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ def override(self, *args, device=None, offload_device=None, **kwargs):
522522
def override_class_clip(cls):
523523
"""Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""
524524
from . import set_current_text_encoder_device, get_current_text_encoder_device
525+
from .clip_dynamic_load_list_guard import register_clip_dynamic_load_list_guard
525526

526527
class NodeOverride(cls):
527528
@classmethod
@@ -540,6 +541,7 @@ def override(self, *args, device=None, **kwargs):
540541
original_text_device = get_current_text_encoder_device()
541542
if device is not None:
542543
set_current_text_encoder_device(device)
544+
register_clip_dynamic_load_list_guard()
543545
kwargs['device'] = 'default'
544546
fn = getattr(super(), cls.FUNCTION)
545547
out = fn(*args, **kwargs)
@@ -554,6 +556,7 @@ def override(self, *args, device=None, **kwargs):
554556
def override_class_clip_no_device(cls):
555557
"""Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)"""
556558
from . import set_current_text_encoder_device, get_current_text_encoder_device
559+
from .clip_dynamic_load_list_guard import register_clip_dynamic_load_list_guard
557560

558561
class NodeOverride(cls):
559562
@classmethod
@@ -572,6 +575,7 @@ def override(self, *args, device=None, **kwargs):
572575
original_text_device = get_current_text_encoder_device()
573576
if device is not None:
574577
set_current_text_encoder_device(device)
578+
register_clip_dynamic_load_list_guard()
575579
fn = getattr(super(), cls.FUNCTION)
576580
out = fn(*args, **kwargs)
577581
try:

0 commit comments

Comments
 (0)