2020from .model_management_mgpu import multigpu_memory_log , force_full_system_cleanup
2121
2222
23+
def unpack_load_item(item):
    """Normalize a ComfyUI ``_load_list`` entry to a 4-tuple.

    ComfyUI 0.6.0+ emits 5-tuples of
    ``(module_offload_mem, module_mem, module_name, module_object, params)``;
    earlier releases emit 4-tuples of
    ``(module_mem, module_name, module_object, params)``.

    Args:
        item: A 4- or 5-element sequence from ``ModelPatcher._load_list()``.

    Returns:
        tuple: ``(module_mem, module_name, module_object, params)``.

    Raises:
        ValueError: If *item* has neither 4 nor 5 elements, so a future
            format change fails loudly instead of silently mis-unpacking.
    """
    if len(item) == 5:
        # Drop the leading offload-memory field added in ComfyUI 0.6.0.
        return item[1], item[2], item[3], item[4]
    if len(item) == 4:
        return item[0], item[1], item[2], item[3]
    raise ValueError(
        f"Unexpected _load_list item of length {len(item)}; "
        "expected 4 (legacy) or 5 (ComfyUI 0.6.0+) fields"
    )
31+
32+
33+
34+
35+
36+
2337def register_patched_safetensor_modelpatcher ():
2438 """Register and patch the ModelPatcher for distributed safetensor loading"""
2539 from comfy .model_patcher import wipe_lowvram_weight , move_weight_functions
@@ -53,7 +67,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False
5367 models_temp .add (m )
5468 model_type = type (m ).__name__
5569
56- if ("GGUF" in model_type or "ModelPatcher" in model_type ) and hasattr (m , "model_patches_to" ):
70+ if ("GGUF" in model_type or "ModelPatcher" in model_type ) and hasattr (m , "model_patches_to" ) and not hasattr ( m , "model_patches_models" ) :
5771 logger .info (f"[MultiGPU DisTorch V2] { type (m ).__name__ } missing 'model_patches_models' attribute, using 'model_patches_to' fallback." )
5872 target_device = m .load_device
5973 logger .debug (f"[MultiGPU DisTorch V2] Target device: { target_device } " )
@@ -236,13 +250,26 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
236250 mem_counter = 0
237251
238252 is_clip_model = getattr (self , 'is_clip' , False )
239- device_assignments = analyze_safetensor_loading (self , allocations , is_clip = is_clip_model )
253+ ## TODO - I do not believe this code is needed and needs to be flagged for proof it is needed
254+ # Check for valid cache
255+ allocations_match = hasattr (self , '_distorch_last_allocations' ) and self ._distorch_last_allocations == allocations
256+ cache_exists = hasattr (self , '_distorch_cached_assignments' )
257+
258+ if cache_exists and allocations_match and not unpatch_weights and not force_patch_weights :
259+ device_assignments = self ._distorch_cached_assignments
260+ logger .debug (f"[MultiGPU DisTorch V2] Reusing cached analysis for { type (inner_model ).__name__ } " )
261+ else :
262+ device_assignments = analyze_safetensor_loading (self , allocations , is_clip = is_clip_model ) ## This should be the only required line - that is how it worked previous release so if it doesn't it is Comfy changes
263+ self ._distorch_cached_assignments = device_assignments
264+ self ._distorch_last_allocations = allocations
240265
241266 model_original_dtype = comfy .utils .weight_dtype (self .model .state_dict ())
242267 high_precision_loras = getattr (self .model , "_distorch_high_precision_loras" , True )
268+ # Use standard ComfyUI load list - the device comparison fix ensures we don't crash
243269 loading = self ._load_list ()
244270 loading .sort (reverse = True )
245- for module_size , module_name , module_object , params in loading :
271+ for item in loading :
272+ module_size , module_name , module_object , params = unpack_load_item (item )
246273 if not unpatch_weights and hasattr (module_object , "comfy_patched_weights" ) and module_object .comfy_patched_weights == True :
247274 block_target_device = device_assignments ['block_assignments' ].get (module_name , device_to )
248275 current_module_device = None
@@ -290,7 +317,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
290317 logger .debug (f"[MultiGPU DisTorch V2] Cast { module_name } .{ param_name } to FP8 for CPU storage" )
291318
292319 # Step 4: Move to ultimate destination based on DisTorch assignment
293- if block_target_device != device_to :
320+ if str ( block_target_device ) != str ( device_to ) :
294321 logger .debug (f"[MultiGPU DisTorch V2] Moving { module_name } from { device_to } to { block_target_device } " )
295322 module_object .to (block_target_device )
296323 module_object .comfy_cast_weights = True
@@ -321,7 +348,10 @@ def _extract_clip_head_blocks(raw_block_list, compute_device):
321348 head_memory = 0
322349 block_assignments = {}
323350
324- for module_size , module_name , module_object , params in raw_block_list :
351+ block_assignments = {}
352+
353+ for item in raw_block_list :
354+ module_size , module_name , module_object , params = unpack_load_item (item )
325355 if any (kw in module_name .lower () for kw in head_keywords ):
326356 head_blocks .append ((module_size , module_name , module_object , params ))
327357 block_assignments [module_name ] = compute_device
@@ -423,7 +453,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
423453 total_memory = 0
424454
425455 raw_block_list = model_patcher ._load_list ()
426- total_memory = sum (module_size for module_size , _ , _ , _ in raw_block_list )
456+ total_memory = sum (unpack_load_item ( x )[ 0 ] for x in raw_block_list )
427457
428458 MIN_BLOCK_THRESHOLD = total_memory * 0.0001
429459 logger .debug (f"[MultiGPU DisTorch V2] Total model memory: { total_memory } bytes" )
@@ -441,7 +471,8 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
441471
442472 # Build all_blocks list for summary (using full raw_block_list)
443473 all_blocks = []
444- for module_size , module_name , module_object , params in raw_block_list :
474+ for item in raw_block_list :
475+ module_size , module_name , module_object , params = unpack_load_item (item )
445476 block_type = type (module_object ).__name__
446477 # Populate summary dictionaries
447478 block_summary [block_type ] = block_summary .get (block_type , 0 ) + 1
@@ -450,11 +481,12 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
450481
451482 # Use distributable blocks for actual allocation (for CLIP, this excludes heads)
452483 distributable_all_blocks = []
453- for module_size , module_name , module_object , params in distributable_raw :
484+ for item in distributable_raw :
485+ module_size , module_name , module_object , params = unpack_load_item (item )
454486 distributable_all_blocks .append ((module_name , module_object , type (module_object ).__name__ , module_size ))
455487
456- block_list = [b for b in distributable_all_blocks if b [3 ] >= MIN_BLOCK_THRESHOLD ]
457- tiny_block_list = [b for b in distributable_all_blocks if b [ 3 ] < MIN_BLOCK_THRESHOLD ]
488+ block_list = [b for b in distributable_all_blocks if ( b [3 ] >= MIN_BLOCK_THRESHOLD and hasattr ( b [ 1 ], "bias" ))]
489+ tiny_block_list = [b for b in distributable_all_blocks if b not in block_list ]
458490
459491 logger .debug (f"[MultiGPU DisTorch V2] Total blocks: { len (all_blocks )} " )
460492 logger .debug (f"[MultiGPU DisTorch V2] Distributable blocks: { len (block_list )} " )
@@ -476,8 +508,6 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
476508 # Distribute blocks sequentially from the tail of the model
477509
478510 device_assignments = {device : [] for device in DEVICE_RATIOS_DISTORCH .keys ()}
479- block_assignments = {}
480-
481511 # Create a memory quota for each donor device based on its calculated allocation.
482512 donor_devices = [d for d in sorted_devices ]
483513 donor_quotas = {
@@ -581,7 +611,7 @@ def parse_memory_string(mem_str):
581611def calculate_fraction_from_byte_expert_string (model_patcher , byte_str ):
582612 """Convert byte allocation string (e.g. 'cuda:1,4gb;cpu,*') to fractional VRAM allocation string respecting device order and byte quotas."""
583613 raw_block_list = model_patcher ._load_list ()
584- total_model_memory = sum (module_size for module_size , _ , _ , _ in raw_block_list )
614+ total_model_memory = sum (unpack_load_item ( x )[ 0 ] for x in raw_block_list )
585615 remaining_model_bytes = total_model_memory
586616
587617 # Use a list of tuples to preserve the user-defined order
@@ -640,7 +670,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
640670def calculate_fraction_from_ratio_expert_string (model_patcher , ratio_str ):
641671 """Convert ratio allocation string (e.g. 'cuda:0,25%;cpu,75%') describing model split to fractional VRAM allocation string."""
642672 raw_block_list = model_patcher ._load_list ()
643- total_model_memory = sum (module_size for module_size , _ , _ , _ in raw_block_list )
673+ total_model_memory = sum (unpack_load_item ( x )[ 0 ] for x in raw_block_list )
644674
645675 raw_ratios = {}
646676 for allocation in ratio_str .split (';' ):
0 commit comments