Skip to content

Commit e64cdf7

Browse files
authored
Merge pull request #175 from pollockjj/dvram
fix: restore MultiGPU compatibility on current ComfyUI with basic DynamicVRAM support
2 parents ac3df4e + 82e6931 commit e64cdf7

14 files changed

Lines changed: 1435 additions & 353 deletions

__init__.py

Lines changed: 283 additions & 70 deletions
Large diffs are not rendered by default.

checkpoint_multigpu.py

Lines changed: 67 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -21,28 +21,39 @@
2121
def patch_load_state_dict_guess_config():
2222
"""Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading."""
2323
global original_load_state_dict_guess_config
24-
24+
2525
if original_load_state_dict_guess_config is not None:
2626
logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
2727
return
28-
28+
2929
logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
3030
original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
3131
comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
3232

3333
def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False,
3434
embedding_directory=None, output_model=True, model_options={},
35-
te_model_options={}, metadata=None):
35+
te_model_options={}, metadata=None, disable_dynamic=False):
3636
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
3737
from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
38-
38+
3939
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
4040
config_hash = str(sd_size)
4141
device_config = checkpoint_device_config.get(config_hash)
4242
distorch_config = checkpoint_distorch_config.get(config_hash)
4343

4444
if not device_config and not distorch_config:
45-
return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata)
45+
return original_load_state_dict_guess_config(
46+
sd,
47+
output_vae=output_vae,
48+
output_clip=output_clip,
49+
output_clipvision=output_clipvision,
50+
embedding_directory=embedding_directory,
51+
output_model=output_model,
52+
model_options=model_options,
53+
te_model_options=te_model_options,
54+
metadata=metadata,
55+
disable_dynamic=disable_dynamic,
56+
)
4657

4758
logger.debug("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader")
4859
logger.debug(f"[MultiGPU Checkpoint] Received Device Config: {device_config}")
@@ -53,7 +64,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5364
vae = None
5465
model = None
5566
model_patcher = None
56-
67+
5768
# Capture the current devices at runtime so we can restore them after loading
5869
original_main_device = get_current_device()
5970
original_clip_device = get_current_text_encoder_device()
@@ -68,12 +79,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
6879
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
6980

7081
model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
71-
82+
7283
if model_config is None:
7384
logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.")
7485
# Simplified fallback for non-checkpoints
7586
set_current_device(device_config.get('unet_device', original_main_device))
76-
diffusion_model = comfy.sd.load_diffusion_model_state_dict(sd, model_options={})
87+
diffusion_model = comfy.sd.load_diffusion_model_state_dict(
88+
sd,
89+
model_options={},
90+
metadata=metadata,
91+
disable_dynamic=disable_dynamic,
92+
)
7793
if diffusion_model is None:
7894
return None
7995
return (diffusion_model, None, VAE(sd={}), None)
@@ -83,18 +99,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
8399
unet_weight_dtype = list(model_config.supported_inference_dtypes)
84100
if model_config.scaled_fp8 is not None:
85101
weight_dtype = None
86-
102+
87103
if custom_operations is not None:
88104
model_config.custom_operations = custom_operations
89105
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
90106
if unet_dtype is None:
91107
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
92-
93-
unet_compute_device = device_config.get('unet_device', original_main_device)
108+
109+
unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
94110
if model_config.scaled_fp8 is not None:
95-
manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
111+
manual_cast_dtype = mm.unet_manual_cast(None, unet_compute_device, model_config.supported_inference_dtypes)
96112
else:
97-
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
113+
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, unet_compute_device, model_config.supported_inference_dtypes)
98114
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
99115
logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")
100116

@@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
103119
clipvision = comfy.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
104120

105121
if output_model:
106-
unet_compute_device = device_config.get('unet_device', original_main_device)
107-
set_current_device(unet_compute_device)
122+
unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
123+
set_current_device(unet_compute_device)
108124
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
109125

110126
multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
111127

112128
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
113-
model.load_model_weights(sd, diffusion_model_prefix)
129+
model_patcher_class = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
130+
model_patcher = model_patcher_class(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
131+
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
114132
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
115133

116134
logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
117135
soft_empty_cache_multigpu()
118-
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
119136
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model")
120137

121138
if distorch_config and 'unet_allocation' in distorch_config:
@@ -131,7 +148,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
131148
vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
132149
set_current_device(vae_target_device) # Use main device context for VAE
133150
multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load")
134-
151+
135152
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
136153
vae_sd = model_config.process_vae_state_dict(vae_sd)
137154
vae = VAE(sd=vae_sd, metadata=metadata)
@@ -151,17 +168,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
151168
for pref in scaled_fp8_list:
152169
skip = skip or k.startswith(pref)
153170
if not skip:
154-
out_sd[k] = sd[k]
171+
out_sd[k] = sd[k]
155172

156173
for pref in scaled_fp8_list:
157174
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
158175
for k in quant_sd:
159176
out_sd[k] = quant_sd[k]
160177
sd = out_sd
161178

162-
clip_target_device = device_config.get('clip_device', original_clip_device)
179+
clip_target_device = torch.device(device_config.get('clip_device', original_clip_device))
163180
set_current_text_encoder_device(clip_target_device)
164-
181+
165182
clip_target = model_config.clip_target(state_dict=sd)
166183
if clip_target is not None:
167184
clip_sd = model_config.process_clip_state_dict(sd)
@@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
170187
multigpu_memory_log(f"clip:{config_hash[:8]}", "pre-load")
171188
soft_empty_cache_multigpu()
172189
clip_params = comfy.utils.calculate_parameters(clip_sd)
173-
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
190+
clip = CLIP(
191+
clip_target,
192+
embedding_directory=embedding_directory,
193+
tokenizer_data=clip_sd,
194+
parameters=clip_params,
195+
state_dict=clip_sd,
196+
model_options=te_model_options,
197+
disable_dynamic=disable_dynamic,
198+
)
174199

175200
if distorch_config and 'clip_allocation' in distorch_config:
176201
clip_alloc = distorch_config['clip_allocation']
@@ -181,16 +206,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
181206
logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}")
182207
clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
183208

184-
m, u = clip.load_sd(clip_sd, full_model=True) # This respects the patched text_encoder_device
185-
if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")
186-
if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}")
187209
logger.info("CLIP Loaded.")
188210
multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
189211
else:
190212
logger.warning("No CLIP/text encoder weights in checkpoint.")
191213
else:
192214
logger.warning("CLIP target not found in model config.")
193-
215+
194216
finally:
195217
set_current_device(original_main_device)
196218
set_current_text_encoder_device(original_clip_device)
@@ -206,7 +228,7 @@ def INPUT_TYPES(s):
206228
import folder_paths
207229
devices = get_device_list()
208230
default_device = devices[1] if len(devices) > 1 else devices[0]
209-
231+
210232
return {
211233
"required": {
212234
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -215,27 +237,27 @@ def INPUT_TYPES(s):
215237
"vae_device": (devices, {"default": default_device}),
216238
}
217239
}
218-
240+
219241
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
220242
FUNCTION = "load_checkpoint"
221243
CATEGORY = "multigpu"
222244
TITLE = "Checkpoint Loader Advanced (MultiGPU)"
223-
245+
224246
def load_checkpoint(self, ckpt_name, unet_device, clip_device, vae_device):
225247
patch_load_state_dict_guess_config()
226-
248+
227249
import folder_paths
228250
import comfy.utils
229-
251+
230252
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
231253
sd = comfy.utils.load_torch_file(ckpt_path)
232254
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
233255
config_hash = str(sd_size)
234-
256+
235257
checkpoint_device_config[config_hash] = {
236258
'unet_device': unet_device, 'clip_device': clip_device, 'vae_device': vae_device
237259
}
238-
260+
239261
# Load using standard loader, our patch will intercept
240262
from nodes import CheckpointLoaderSimple
241263
return CheckpointLoaderSimple().load_checkpoint(ckpt_name)
@@ -247,7 +269,7 @@ def INPUT_TYPES(s):
247269
import folder_paths
248270
devices = get_device_list()
249271
compute_device = devices[1] if len(devices) > 1 else devices[0]
250-
272+
251273
return {
252274
"required": {
253275
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -265,18 +287,18 @@ def INPUT_TYPES(s):
265287
"eject_models": ("BOOLEAN", {"default": True}),
266288
}
267289
}
268-
290+
269291
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
270292
FUNCTION = "load_checkpoint"
271293
CATEGORY = "multigpu/distorch_2"
272294
TITLE = "Checkpoint Loader Advanced (DisTorch2)"
273-
295+
274296
def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
275297
clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device,
276298
unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True):
277-
299+
278300
if eject_models:
279-
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
301+
logger.mgpu_mm_log("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
280302
ejection_count = 0
281303
for i, lm in enumerate(mm.current_loaded_models):
282304
model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown'
@@ -289,17 +311,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
289311
logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
290312
ejection_count += 1
291313
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
292-
293-
patch_load_state_dict_guess_config()
294-
314+
315+
patch_load_state_dict_guess_config()
316+
295317
import folder_paths
296318
import comfy.utils
297-
319+
298320
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
299321
sd = comfy.utils.load_torch_file(ckpt_path)
300322
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
301323
config_hash = str(sd_size)
302-
324+
303325
checkpoint_device_config[config_hash] = {
304326
'unet_device': unet_compute_device,
305327
'clip_device': clip_compute_device,
@@ -312,7 +334,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
312334
elif unet_expert_mode_allocations:
313335
unet_vram_str = unet_compute_device
314336
unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else ""
315-
337+
316338
clip_vram_str = ""
317339
if clip_virtual_vram_gb > 0:
318340
clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
@@ -327,6 +349,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
327349
'unet_settings': hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest(),
328350
'clip_settings': hashlib.sha256(f"{clip_alloc}{high_precision_loras}".encode()).hexdigest(),
329351
}
330-
352+
331353
from nodes import CheckpointLoaderSimple
332354
return CheckpointLoaderSimple().load_checkpoint(ckpt_name)

ci/extract_allocation.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,16 @@
33

44
import argparse
55
import json
6+
import sys
67
from pathlib import Path
78
from typing import Iterable, Iterator, Dict, Any
89

910

11+
def _write_stdout(message: str = "") -> None:
12+
sys.stdout.write(f"{message}\n")
13+
sys.stdout.flush()
14+
15+
1016
def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]:
1117
with path.open("r", encoding="utf-8") as handle:
1218
for line in handle:
@@ -37,12 +43,12 @@ def main() -> int:
3743

3844
entries = list(load_json_lines(args.logfile))
3945
if not entries:
40-
print("No entries found in log file.")
46+
_write_stdout("No entries found in log file.")
4147
return 0
4248

4349
matched = [entry for entry in entries if is_allocation_event(entry, args.keywords)]
4450
if not matched:
45-
print("No allocation events matched provided keywords.")
51+
_write_stdout("No allocation events matched provided keywords.")
4652
return 0
4753

4854
for entry in matched:
@@ -51,9 +57,9 @@ def main() -> int:
5157
component = entry.get("component", "")
5258
header_bits = [bit for bit in (timestamp, category, component) if bit]
5359
header = " | ".join(header_bits) if header_bits else "allocation"
54-
print(f"## {header}")
55-
print(entry.get("message", ""))
56-
print()
60+
_write_stdout(f"## {header}")
61+
_write_stdout(entry.get("message", ""))
62+
_write_stdout()
5763

5864
return 0
5965

0 commit comments

Comments (0)