Skip to content

Commit a24f0a6

Browse files
committed
hotfix: Corrects a corner case when DisTorch VirtualVRAM=0.0 GB (the previous refactor shunted this to the standard loader). This replicates that required logic across all nodes using DisTorch2 for allocations. Next time I will wait for the final test workflow to finish VAE conversion (which is the only place I test this).
1 parent e40de4c commit a24f0a6

4 files changed

Lines changed: 36 additions & 20 deletions

File tree

checkpoint_multigpu.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
107107
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model")
108108

109109
if distorch_config and 'unet_allocation' in distorch_config:
110-
register_patched_safetensor_modelpatcher()
111-
inner_model = model_patcher.model
112-
inner_model._distorch_v2_meta = {"full_allocation": distorch_config['unet_allocation']}
113-
logger.info(f"[CHECKPOINT_META] UNET inner_model id=0x{id(inner_model):x}")
114-
model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
110+
unet_alloc = distorch_config['unet_allocation']
111+
if unet_alloc:
112+
register_patched_safetensor_modelpatcher()
113+
inner_model = model_patcher.model
114+
inner_model._distorch_v2_meta = {"full_allocation": unet_alloc}
115+
logger.info(f"[CHECKPOINT_META] UNET inner_model id=0x{id(inner_model):x}")
116+
model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
115117

116118
model.load_model_weights(sd, diffusion_model_prefix)
117119
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
@@ -141,10 +143,11 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
141143
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
142144

143145
if distorch_config and 'clip_allocation' in distorch_config:
144-
if hasattr(clip, 'patcher'):
146+
clip_alloc = distorch_config['clip_allocation']
147+
if clip_alloc and hasattr(clip, 'patcher'):
145148
register_patched_safetensor_modelpatcher()
146149
inner_clip = clip.patcher.model
147-
inner_clip._distorch_v2_meta = {"full_allocation": distorch_config['clip_allocation']}
150+
inner_clip._distorch_v2_meta = {"full_allocation": clip_alloc}
148151
logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}")
149152
clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
150153

@@ -257,10 +260,19 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
257260
'vae_device': vae_device
258261
}
259262

260-
unet_vram_str = f"{unet_compute_device};{unet_virtual_vram_gb};{unet_donor_device}"
261-
unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}"
262-
clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
263-
clip_alloc = f"{clip_expert_mode_allocations}#{clip_vram_str}"
263+
unet_vram_str = ""
264+
if unet_virtual_vram_gb > 0:
265+
unet_vram_str = f"{unet_compute_device};{unet_virtual_vram_gb};{unet_donor_device}"
266+
elif unet_expert_mode_allocations:
267+
unet_vram_str = unet_compute_device
268+
unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else ""
269+
270+
clip_vram_str = ""
271+
if clip_virtual_vram_gb > 0:
272+
clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
273+
elif clip_expert_mode_allocations:
274+
clip_vram_str = clip_compute_device
275+
clip_alloc = f"{clip_expert_mode_allocations}#{clip_vram_str}" if clip_expert_mode_allocations or clip_vram_str else ""
264276

265277
checkpoint_distorch_config[config_hash] = {
266278
'unet_allocation': unet_alloc,

distorch_2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,16 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
322322
"""
323323
DEVICE_RATIOS_DISTORCH = {}
324324
device_table = {}
325-
distorch_alloc = allocations_string
325+
distorch_alloc = ""
326+
virtual_vram_str = ""
326327
virtual_vram_gb = 0.0
327328

328-
distorch_alloc, virtual_vram_str = allocations_string.split('#')
329+
if '#' in allocations_string:
330+
distorch_alloc, virtual_vram_str = allocations_string.split('#', 1)
331+
else:
332+
distorch_alloc = allocations_string
329333

330-
compute_device = virtual_vram_str.split(';')[0]
334+
compute_device = virtual_vram_str.split(';')[0] if virtual_vram_str else "cuda:0"
331335
logger.debug(f"[MultiGPU DisTorch V2] Compute Device: {compute_device}")
332336

333337
if not distorch_alloc:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-multigpu"
33
description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes."
4-
version = "2.5.6"
4+
version = "2.5.7"
55
license = {file = "LICENSE"}
66

77
[project.urls]

wrappers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
112112
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
113113
model_to_check = out[0].patcher
114114

115-
if model_to_check:
115+
if model_to_check and full_allocation:
116116
inner_model = model_to_check.model
117117
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
118118

@@ -217,7 +217,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
217217
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
218218
model_to_check = out[0].patcher
219219

220-
if model_to_check:
220+
if model_to_check and full_allocation:
221221
inner_model = model_to_check.model
222222
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
223223

@@ -279,7 +279,7 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device
279279
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
280280
model_to_check = out[0].patcher
281281

282-
if model_to_check:
282+
if model_to_check and full_allocation:
283283
inner_model = model_to_check.model
284284
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
285285

@@ -343,7 +343,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
343343
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
344344
model_to_check = out[0].patcher
345345

346-
if model_to_check:
346+
if model_to_check and full_allocation:
347347
inner_model = model_to_check.model
348348
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
349349

@@ -407,7 +407,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
407407
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
408408
model_to_check = out[0].patcher
409409

410-
if model_to_check:
410+
if model_to_check and full_allocation:
411411
inner_model = model_to_check.model
412412
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
413413

0 commit comments

Comments
 (0)