
Commit 03e5368

Update Distorch for ComfyUI 0.6.0+ load list with device assignment caching.
1 parent 62f98ed commit 03e5368

2 files changed

Lines changed: 96 additions & 21 deletions
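For orientation: the "load list" change referred to in the commit message is that ComfyUI 0.6.0+ ModelPatcher._load_list() entries gained a leading offload-memory field, turning the legacy 4-tuple into a 5-tuple. A minimal sketch of the two shapes (the module name and sizes below are made up; only the field order is taken from the unpack_load_item helper added in this commit):

import torch.nn as nn

block = nn.Linear(8, 8)
legacy_item = (4096, "input_blocks.0", block, [])     # pre-0.6.0: (module_mem, module_name, module_object, params)
new_item = (0, 4096, "input_blocks.0", block, [])     # 0.6.0+: (module_offload_mem, module_mem, module_name, module_object, params)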


checkpoint_multigpu.py

Lines changed: 53 additions & 8 deletions
@@ -62,6 +62,11 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
+
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
     model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
 
     if model_config is None:
@@ -79,13 +84,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
 
-    model_config.custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
     if unet_dtype is None:
         unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
 
     unet_compute_device = device_config.get('unet_device', original_main_device)
-    manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
+    if model_config.scaled_fp8 is not None:
+        manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")
 
@@ -101,6 +110,8 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
 
     model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
+    model.load_model_weights(sd, diffusion_model_prefix)
+    multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
 
     logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
     soft_empty_cache_multigpu()
@@ -116,9 +127,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     logger.info(f"[CHECKPOINT_META] UNET inner_model id=0x{id(inner_model):x}")
     model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
 
-    model.load_model_weights(sd, diffusion_model_prefix)
-    multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
-
     if output_vae:
         vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
         set_current_device(vae_target_device)  # Use main device context for VAE
@@ -130,6 +138,27 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     multigpu_memory_log(f"vae:{config_hash[:8]}", "post-load")
 
     if output_clip:
+        if te_model_options.get("custom_operations", None) is None:
+            scaled_fp8_list = []
+            for k in list(sd.keys()):  # Convert scaled fp8 to mixed ops
+                if k.endswith(".scaled_fp8"):
+                    scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+            if len(scaled_fp8_list) > 0:
+                out_sd = {}
+                for k in sd:
+                    skip = False
+                    for pref in scaled_fp8_list:
+                        skip = skip or k.startswith(pref)
+                    if not skip:
+                        out_sd[k] = sd[k]
+
+                for pref in scaled_fp8_list:
+                    quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                    for k in quant_sd:
+                        out_sd[k] = quant_sd[k]
+                sd = out_sd
+
         clip_target_device = device_config.get('clip_device', original_clip_device)
         set_current_text_encoder_device(clip_target_device)
 
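Worked example of the scaled-fp8 scan added above, with hypothetical key names: a marker tensor ending in ".scaled_fp8" yields one prefix per quantized text encoder, keys under that prefix are rebuilt through comfy.utils.convert_old_quants, and everything else is copied into out_sd unchanged.

sd_keys = [
    "text_encoders.t5xxl.scaled_fp8",                                      # hypothetical marker tensor
    "text_encoders.t5xxl.encoder.block.0.layer.0.SelfAttention.q.weight",  # hypothetical quantized key
    "text_encoders.clip_l.text_model.embeddings.token_embedding.weight",   # hypothetical unquantized key
]
scaled_fp8_list = [k[:-len("scaled_fp8")] for k in sd_keys if k.endswith(".scaled_fp8")]
print(scaled_fp8_list)  # ['text_encoders.t5xxl.'] -> keys under this prefix go through convert_old_quants; clip_l keys pass through as-is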
@@ -224,15 +253,16 @@ def INPUT_TYPES(s):
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                 "unet_compute_device": (devices, {"default": compute_device}),
                 "unet_virtual_vram_gb": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 128.0, "step": 0.1}),
-                "unet_donor_device": ("STRING", {"default": "cpu"}),
+                "unet_donor_device": (devices, {"default": "cpu"}),
                 "clip_compute_device": (devices, {"default": "cpu"}),
                 "clip_virtual_vram_gb": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 128.0, "step": 0.1}),
-                "clip_donor_device": ("STRING", {"default": "cpu"}),
+                "clip_donor_device": (devices, {"default": "cpu"}),
                 "vae_device": (devices, {"default": compute_device}),
             }, "optional": {
                 "unet_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
                 "clip_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
                 "high_precision_loras": ("BOOLEAN", {"default": True}),
+                "eject_models": ("BOOLEAN", {"default": True}),
             }
         }
 
@@ -243,7 +273,22 @@ def INPUT_TYPES(s):
 
     def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
                         clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device,
-                        unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True):
+                        unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True):
+
+        if eject_models:
+            logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
+            ejection_count = 0
+            for i, lm in enumerate(mm.current_loaded_models):
+                model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown'
+                if hasattr(lm.model, 'model') and lm.model.model is not None:
+                    lm.model.model._mgpu_unload_distorch_model = True
+                    logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (id=0x{id(lm):x}) → marked for eviction")
+                    ejection_count += 1
+                elif lm.model is not None:
+                    lm.model._mgpu_unload_distorch_model = True
+                    logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
+                    ejection_count += 1
+            logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
 
         patch_load_state_dict_guess_config()
 
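Note that eject_models only sets a sentinel attribute here; per the [EJECT_MODELS_SETUP_COMPLETE] log line, the eviction itself is expected to happen later, during load_models_gpu. Purely as an illustration of how such a sentinel could be consumed (this is not code from this commit or this repository):

def wants_distorch_eviction(loaded_model):
    """Hypothetical check for the _mgpu_unload_distorch_model sentinel set above."""
    patcher = loaded_model.model
    inner = getattr(patcher, "model", None) if patcher is not None else None
    return bool(getattr(inner, "_mgpu_unload_distorch_model", False)
                or getattr(patcher, "_mgpu_unload_distorch_model", False))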
distorch_2.py

Lines changed: 43 additions & 13 deletions
@@ -20,6 +20,20 @@
 from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup
 
 
+
+def unpack_load_item(item):
+    """Handle ComfyUI 0.6.0+ 5-tuple vs legacy 4-tuple"""
+    if len(item) == 5:
+        # (module_offload_mem, module_mem, module_name, module_object, params)
+        return item[1], item[2], item[3], item[4]
+    # (module_mem, module_name, module_object, params)
+    return item[0], item[1], item[2], item[3]
+
+
+
+
+
+
 def register_patched_safetensor_modelpatcher():
     """Register and patch the ModelPatcher for distributed safetensor loading"""
     from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions
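Quick usage check for unpack_load_item (placeholder sizes and names): both entry shapes collapse to the same four fields, with the 5-tuple's leading offload-memory value dropped.

import torch.nn as nn

block = nn.Linear(4, 4)
legacy = (2048, "blocks.0", block, [])
new = (0, 2048, "blocks.0", block, [])
assert unpack_load_item(legacy) == unpack_load_item(new) == (2048, "blocks.0", block, [])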
@@ -236,13 +250,26 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     mem_counter = 0
 
     is_clip_model = getattr(self, 'is_clip', False)
-    device_assignments = analyze_safetensor_loading(self, allocations, is_clip=is_clip_model)
+    ## TODO - I do not believe this code is needed and needs to be flagged for proof it is needed
+    # Check for valid cache
+    allocations_match = hasattr(self, '_distorch_last_allocations') and self._distorch_last_allocations == allocations
+    cache_exists = hasattr(self, '_distorch_cached_assignments')
+
+    if cache_exists and allocations_match and not unpatch_weights and not force_patch_weights:
+        device_assignments = self._distorch_cached_assignments
+        logger.debug(f"[MultiGPU DisTorch V2] Reusing cached analysis for {type(inner_model).__name__}")
+    else:
+        device_assignments = analyze_safetensor_loading(self, allocations, is_clip=is_clip_model)  ## This should be the only required line - that is how it worked previous release so if it doesn't it is Comfy changes
+        self._distorch_cached_assignments = device_assignments
+        self._distorch_last_allocations = allocations
 
     model_original_dtype = comfy.utils.weight_dtype(self.model.state_dict())
     high_precision_loras = getattr(self.model, "_distorch_high_precision_loras", True)
+    # Use standard ComfyUI load list - the device comparison fix ensures we don't crash
     loading = self._load_list()
     loading.sort(reverse=True)
-    for module_size, module_name, module_object, params in loading:
+    for item in loading:
+        module_size, module_name, module_object, params = unpack_load_item(item)
         if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True:
             block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
             current_module_device = None
@@ -290,7 +317,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
                 logger.debug(f"[MultiGPU DisTorch V2] Cast {module_name}.{param_name} to FP8 for CPU storage")
 
         # Step 4: Move to ultimate destination based on DisTorch assignment
-        if block_target_device != device_to:
+        if str(block_target_device) != str(device_to):
             logger.debug(f"[MultiGPU DisTorch V2] Moving {module_name} from {device_to} to {block_target_device}")
             module_object.to(block_target_device)
             module_object.comfy_cast_weights = True
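The str() normalization above matters because a torch.device never compares equal to a plain device string, so the old comparison could treat "cuda:0" and torch.device("cuda:0") as different devices. A minimal check, assuming standard torch.device semantics:

import torch

dev = torch.device("cuda:0")
print(dev == "cuda:0")       # False - device objects never equal plain strings
print(str(dev) == "cuda:0")  # True  - comparing the string forms matches as intended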
@@ -321,7 +348,10 @@ def _extract_clip_head_blocks(raw_block_list, compute_device):
     head_memory = 0
     block_assignments = {}
 
-    for module_size, module_name, module_object, params in raw_block_list:
+    block_assignments = {}
+
+    for item in raw_block_list:
+        module_size, module_name, module_object, params = unpack_load_item(item)
         if any(kw in module_name.lower() for kw in head_keywords):
             head_blocks.append((module_size, module_name, module_object, params))
             block_assignments[module_name] = compute_device
@@ -423,7 +453,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
     total_memory = 0
 
     raw_block_list = model_patcher._load_list()
-    total_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
 
     MIN_BLOCK_THRESHOLD = total_memory * 0.0001
     logger.debug(f"[MultiGPU DisTorch V2] Total model memory: {total_memory} bytes")
@@ -441,7 +471,8 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
 
     # Build all_blocks list for summary (using full raw_block_list)
     all_blocks = []
-    for module_size, module_name, module_object, params in raw_block_list:
+    for item in raw_block_list:
+        module_size, module_name, module_object, params = unpack_load_item(item)
         block_type = type(module_object).__name__
         # Populate summary dictionaries
         block_summary[block_type] = block_summary.get(block_type, 0) + 1
@@ -450,11 +481,12 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
 
     # Use distributable blocks for actual allocation (for CLIP, this excludes heads)
     distributable_all_blocks = []
-    for module_size, module_name, module_object, params in distributable_raw:
+    for item in distributable_raw:
+        module_size, module_name, module_object, params = unpack_load_item(item)
         distributable_all_blocks.append((module_name, module_object, type(module_object).__name__, module_size))
 
-    block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD]
-    tiny_block_list = [b for b in distributable_all_blocks if b[3] < MIN_BLOCK_THRESHOLD]
+    block_list = [b for b in distributable_all_blocks if (b[3] >= MIN_BLOCK_THRESHOLD and hasattr(b[1], "bias"))]
+    tiny_block_list = [b for b in distributable_all_blocks if b not in block_list]
 
     logger.debug(f"[MultiGPU DisTorch V2] Total blocks: {len(all_blocks)}")
     logger.debug(f"[MultiGPU DisTorch V2] Distributable blocks: {len(block_list)}")
@@ -476,8 +508,6 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
     # Distribute blocks sequentially from the tail of the model
 
     device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH.keys()}
-    block_assignments = {}
-
     # Create a memory quota for each donor device based on its calculated allocation.
     donor_devices = [d for d in sorted_devices]
     donor_quotas = {
@@ -581,7 +611,7 @@ def parse_memory_string(mem_str):
 def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
     """Convert byte allocation string (e.g. 'cuda:1,4gb;cpu,*') to fractional VRAM allocation string respecting device order and byte quotas."""
     raw_block_list = model_patcher._load_list()
-    total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_model_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
     remaining_model_bytes = total_model_memory
 
     # Use a list of tuples to preserve the user-defined order
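The docstring's 'cuda:1,4gb;cpu,*' format is an ordered, semicolon-separated list of device,quota pairs, with '*' meaning "whatever is left". A rough standalone parsing sketch (not the repository's parse_memory_string, whose exact behaviour is not shown in this commit):

def parse_byte_expert_string(byte_str):
    """Illustrative parser: returns ordered (device, byte_quota_or_None) pairs."""
    units = {"kb": 1024, "mb": 1024 ** 2, "gb": 1024 ** 3}
    allocations = []
    for chunk in byte_str.split(";"):
        device, quota = (part.strip() for part in chunk.split(","))
        if quota == "*":
            allocations.append((device, None))  # None = remainder of the model
        else:
            allocations.append((device, int(float(quota[:-2]) * units[quota[-2:].lower()])))
    return allocations

print(parse_byte_expert_string("cuda:1,4gb;cpu,*"))  # [('cuda:1', 4294967296), ('cpu', None)]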
@@ -640,7 +670,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
 def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str):
     """Convert ratio allocation string (e.g. 'cuda:0,25%;cpu,75%') describing model split to fractional VRAM allocation string."""
     raw_block_list = model_patcher._load_list()
-    total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_model_memory = sum(unpack_load_item(x)[0] for x in raw_block_list)
 
     raw_ratios = {}
     for allocation in ratio_str.split(';'):
