Skip to content

Commit ac3df4e

Browse files
authored
Merge pull request #161 from pollockjj/tmv
Fix DisTorch Engine for ComfyUI 0.6.0+
2 parents 62f98ed + c5e3e6a commit ac3df4e

3 files changed

Lines changed: 99 additions & 24 deletions

File tree

checkpoint_multigpu.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
6262
diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)
6363
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
6464
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
65+
66+
custom_operations = model_options.get("custom_operations", None)
67+
if custom_operations is None:
68+
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
69+
6570
model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
6671

6772
if model_config is None:
@@ -79,13 +84,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
7984
if model_config.scaled_fp8 is not None:
8085
weight_dtype = None
8186

82-
model_config.custom_operations = model_options.get("custom_operations", None)
87+
if custom_operations is not None:
88+
model_config.custom_operations = custom_operations
8389
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
8490
if unet_dtype is None:
8591
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
8692

8793
unet_compute_device = device_config.get('unet_device', original_main_device)
88-
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
94+
if model_config.scaled_fp8 is not None:
95+
manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
96+
else:
97+
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
8998
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
9099
logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")
91100

@@ -101,6 +110,8 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
101110
multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
102111

103112
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
113+
model.load_model_weights(sd, diffusion_model_prefix)
114+
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
104115

105116
logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
106117
soft_empty_cache_multigpu()
@@ -116,9 +127,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
116127
logger.info(f"[CHECKPOINT_META] UNET inner_model id=0x{id(inner_model):x}")
117128
model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
118129

119-
model.load_model_weights(sd, diffusion_model_prefix)
120-
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
121-
122130
if output_vae:
123131
vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
124132
set_current_device(vae_target_device) # Use main device context for VAE
@@ -130,6 +138,27 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
130138
multigpu_memory_log(f"vae:{config_hash[:8]}", "post-load")
131139

132140
if output_clip:
141+
if te_model_options.get("custom_operations", None) is None:
142+
scaled_fp8_list = []
143+
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
144+
if k.endswith(".scaled_fp8"):
145+
scaled_fp8_list.append(k[:-len("scaled_fp8")])
146+
147+
if len(scaled_fp8_list) > 0:
148+
out_sd = {}
149+
for k in sd:
150+
skip = False
151+
for pref in scaled_fp8_list:
152+
skip = skip or k.startswith(pref)
153+
if not skip:
154+
out_sd[k] = sd[k]
155+
156+
for pref in scaled_fp8_list:
157+
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
158+
for k in quant_sd:
159+
out_sd[k] = quant_sd[k]
160+
sd = out_sd
161+
133162
clip_target_device = device_config.get('clip_device', original_clip_device)
134163
set_current_text_encoder_device(clip_target_device)
135164

@@ -224,15 +253,16 @@ def INPUT_TYPES(s):
224253
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
225254
"unet_compute_device": (devices, {"default": compute_device}),
226255
"unet_virtual_vram_gb": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 128.0, "step": 0.1}),
227-
"unet_donor_device": ("STRING", {"default": "cpu"}),
256+
"unet_donor_device": (devices, {"default": "cpu"}),
228257
"clip_compute_device": (devices, {"default": "cpu"}),
229258
"clip_virtual_vram_gb": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 128.0, "step": 0.1}),
230-
"clip_donor_device": ("STRING", {"default": "cpu"}),
259+
"clip_donor_device": (devices, {"default": "cpu"}),
231260
"vae_device": (devices, {"default": compute_device}),
232261
}, "optional": {
233262
"unet_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
234263
"clip_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
235264
"high_precision_loras": ("BOOLEAN", {"default": True}),
265+
"eject_models": ("BOOLEAN", {"default": True}),
236266
}
237267
}
238268

@@ -243,7 +273,22 @@ def INPUT_TYPES(s):
243273

244274
def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
245275
clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device,
246-
unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True):
276+
unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True):
277+
278+
if eject_models:
279+
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
280+
ejection_count = 0
281+
for i, lm in enumerate(mm.current_loaded_models):
282+
model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown'
283+
if hasattr(lm.model, 'model') and lm.model.model is not None:
284+
lm.model.model._mgpu_unload_distorch_model = True
285+
logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (id=0x{id(lm):x}) → marked for eviction")
286+
ejection_count += 1
287+
elif lm.model is not None:
288+
lm.model._mgpu_unload_distorch_model = True
289+
logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
290+
ejection_count += 1
291+
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
247292

248293
patch_load_state_dict_guess_config()
249294

distorch_2.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@
2020
from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup
2121

2222

23+
24+
def unpack_load_item(item):
25+
"""Handle ComfyUI 0.6.0+ 5-tuple vs legacy 4-tuple"""
26+
if len(item) == 5:
27+
# (module_offload_mem, module_mem, module_name, module_object, params)
28+
return item[1], item[2], item[3], item[4]
29+
# (module_mem, module_name, module_object, params)
30+
return item[0], item[1], item[2], item[3]
31+
32+
33+
34+
35+
36+
2337
def register_patched_safetensor_modelpatcher():
2438
"""Register and patch the ModelPatcher for distributed safetensor loading"""
2539
from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions
@@ -53,7 +67,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False
5367
models_temp.add(m)
5468
model_type = type(m).__name__
5569

56-
if ("GGUF" in model_type or "ModelPatcher" in model_type) and hasattr(m, "model_patches_to"):
70+
if ("GGUF" in model_type or "ModelPatcher" in model_type) and hasattr(m, "model_patches_to") and not hasattr(m, "model_patches_models"):
5771
logger.info(f"[MultiGPU DisTorch V2] {type(m).__name__} missing 'model_patches_models' attribute, using 'model_patches_to' fallback.")
5872
target_device = m.load_device
5973
logger.debug(f"[MultiGPU DisTorch V2] Target device: {target_device}")
@@ -236,13 +250,26 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
236250
mem_counter = 0
237251

238252
is_clip_model = getattr(self, 'is_clip', False)
239-
device_assignments = analyze_safetensor_loading(self, allocations, is_clip=is_clip_model)
253+
## TODO - I do not believe this code is needed and needs to be flagged for proof it is needed
254+
# Check for valid cache
255+
allocations_match = hasattr(self, '_distorch_last_allocations') and self._distorch_last_allocations == allocations
256+
cache_exists = hasattr(self, '_distorch_cached_assignments')
257+
258+
if cache_exists and allocations_match and not unpatch_weights and not force_patch_weights:
259+
device_assignments = self._distorch_cached_assignments
260+
logger.debug(f"[MultiGPU DisTorch V2] Reusing cached analysis for {type(inner_model).__name__}")
261+
else:
262+
device_assignments = analyze_safetensor_loading(self, allocations, is_clip=is_clip_model) ## This should be the only required line - that is how it worked previous release so if it doesn't it is Comfy changes
263+
self._distorch_cached_assignments = device_assignments
264+
self._distorch_last_allocations = allocations
240265

241266
model_original_dtype = comfy.utils.weight_dtype(self.model.state_dict())
242267
high_precision_loras = getattr(self.model, "_distorch_high_precision_loras", True)
268+
# Use standard ComfyUI load list - the device comparison fix ensures we don't crash
243269
loading = self._load_list()
244270
loading.sort(reverse=True)
245-
for module_size, module_name, module_object, params in loading:
271+
for item in loading:
272+
module_size, module_name, module_object, params = unpack_load_item(item)
246273
if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True:
247274
block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
248275
current_module_device = None
@@ -290,7 +317,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
290317
logger.debug(f"[MultiGPU DisTorch V2] Cast {module_name}.{param_name} to FP8 for CPU storage")
291318

292319
# Step 4: Move to ultimate destination based on DisTorch assignment
293-
if block_target_device != device_to:
320+
if str(block_target_device) != str(device_to):
294321
logger.debug(f"[MultiGPU DisTorch V2] Moving {module_name} from {device_to} to {block_target_device}")
295322
module_object.to(block_target_device)
296323
module_object.comfy_cast_weights = True
@@ -321,7 +348,10 @@ def _extract_clip_head_blocks(raw_block_list, compute_device):
321348
head_memory = 0
322349
block_assignments = {}
323350

324-
for module_size, module_name, module_object, params in raw_block_list:
351+
block_assignments = {}
352+
353+
for item in raw_block_list:
354+
module_size, module_name, module_object, params = unpack_load_item(item)
325355
if any(kw in module_name.lower() for kw in head_keywords):
326356
head_blocks.append((module_size, module_name, module_object, params))
327357
block_assignments[module_name] = compute_device
@@ -423,7 +453,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
423453
total_memory = 0
424454

425455
raw_block_list = model_patcher._load_list()
426-
total_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
456+
total_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
427457

428458
MIN_BLOCK_THRESHOLD = total_memory * 0.0001
429459
logger.debug(f"[MultiGPU DisTorch V2] Total model memory: {total_memory} bytes")
@@ -441,7 +471,8 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
441471

442472
# Build all_blocks list for summary (using full raw_block_list)
443473
all_blocks = []
444-
for module_size, module_name, module_object, params in raw_block_list:
474+
for item in raw_block_list:
475+
module_size, module_name, module_object, params = unpack_load_item(item)
445476
block_type = type(module_object).__name__
446477
# Populate summary dictionaries
447478
block_summary[block_type] = block_summary.get(block_type, 0) + 1
@@ -450,11 +481,12 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
450481

451482
# Use distributable blocks for actual allocation (for CLIP, this excludes heads)
452483
distributable_all_blocks = []
453-
for module_size, module_name, module_object, params in distributable_raw:
484+
for item in distributable_raw:
485+
module_size, module_name, module_object, params = unpack_load_item(item)
454486
distributable_all_blocks.append((module_name, module_object, type(module_object).__name__, module_size))
455487

456-
block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD]
457-
tiny_block_list = [b for b in distributable_all_blocks if b[3] < MIN_BLOCK_THRESHOLD]
488+
block_list = [b for b in distributable_all_blocks if (b[3] >= MIN_BLOCK_THRESHOLD and hasattr(b[1], "bias"))]
489+
tiny_block_list = [b for b in distributable_all_blocks if b not in block_list]
458490

459491
logger.debug(f"[MultiGPU DisTorch V2] Total blocks: {len(all_blocks)}")
460492
logger.debug(f"[MultiGPU DisTorch V2] Distributable blocks: {len(block_list)}")
@@ -476,8 +508,6 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
476508
# Distribute blocks sequentially from the tail of the model
477509

478510
device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH.keys()}
479-
block_assignments = {}
480-
481511
# Create a memory quota for each donor device based on its calculated allocation.
482512
donor_devices = [d for d in sorted_devices]
483513
donor_quotas = {
@@ -581,7 +611,7 @@ def parse_memory_string(mem_str):
581611
def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
582612
"""Convert byte allocation string (e.g. 'cuda:1,4gb;cpu,*') to fractional VRAM allocation string respecting device order and byte quotas."""
583613
raw_block_list = model_patcher._load_list()
584-
total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
614+
total_model_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
585615
remaining_model_bytes = total_model_memory
586616

587617
# Use a list of tuples to preserve the user-defined order
@@ -640,7 +670,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
640670
def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str):
641671
"""Convert ratio allocation string (e.g. 'cuda:0,25%;cpu,75%') describing model split to fractional VRAM allocation string."""
642672
raw_block_list = model_patcher._load_list()
643-
total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
673+
total_model_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
644674

645675
raw_ratios = {}
646676
for allocation in ratio_str.split(';'):

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-multigpu"
33
description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes."
4-
version = "2.5.10"
4+
version = "2.5.11"
55
license = {file = "LICENSE"}
66

77
[project.urls]
@@ -11,4 +11,4 @@ Repository = "https://github.com/pollockjj/ComfyUI-MultiGPU"
1111
[tool.comfy]
1212
PublisherId = "pollockjj"
1313
DisplayName = "ComfyUI-MultiGPU"
14-
Icon = "https://raw.githubusercontent.com/pollockjj/ComfyUI-MultiGPU/main/assets/multigpu_icon.png"
14+
Icon = "https://raw.githubusercontent.com/pollockjj/ComfyUI-MultiGPU/main/assets/multigpu_icon.png"

0 commit comments

Comments
 (0)