@@ -86,8 +86,14 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
8686
8787 mem_counter = 0
8888
89- logger .info (f"[MultiGPU_DisTorch2] Using static allocation for model { debug_hash [:8 ]} " )
90- device_assignments = analyze_safetensor_loading (self , allocations )
89+ is_clip_model = getattr (self , 'is_clip' , False )
90+ if is_clip_model :
91+ logger .info (f"[MultiGPU_DisTorch2] Using CLIP-specific allocation for model { debug_hash [:8 ]} (HEAD PRESERVATION ENABLED)" )
92+ device_assignments = analyze_safetensor_loading_clip (self , allocations )
93+ else :
94+ logger .debug (f"[MultiGPU_DisTorch2] Using standard allocation for model { debug_hash [:8 ]} (UNET/VAE - UNTOUCHED)" )
95+ device_assignments = analyze_safetensor_loading (self , allocations )
96+
9197 model_original_dtype = comfy .utils .weight_dtype (self .model .state_dict ())
9298 high_precision_loras = self .model ._distorch_high_precision_loras
9399 loading = self ._load_list ()
@@ -164,6 +170,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
164170def analyze_safetensor_loading (model_patcher , allocations_string ):
165171 """
166172 Analyze and distribute safetensor model blocks across devices
173+ Target for refactor back into one function once stability for CLIP is established.
167174 """
168175 DEVICE_RATIOS_DISTORCH = {}
169176 device_table = {}
@@ -366,6 +373,211 @@ def analyze_safetensor_loading(model_patcher, allocations_string):
366373 "block_assignments" : block_assignments
367374 }
368375
376+
def analyze_safetensor_loading_clip(model_patcher, allocations_string):
    """
    CLIP-SPECIFIC: a clone of the standard UNET/VAE allocation logic with one
    behavioral difference — embedding ("head") blocks are pinned to the
    compute device instead of being distributed across donor devices.
    Target for refactor back into analyze_safetensor_loading once stability
    for CLIP is established.

    Args:
        model_patcher: patcher exposing ``_load_list()`` yielding
            ``(module_size, module_name, module_object, params)`` tuples.
        allocations_string: expert string of the form
            ``"<device allocations>#<virtual-vram spec>"``; the spec's first
            ``;``-separated field names the compute device.

    Returns:
        dict with:
            ``"device_assignments"``: {device: [(name, module, type_name, size), ...]}
            ``"block_assignments"``: {block_name: device}
    """
    DEVICE_RATIOS_DISTORCH = {}
    device_table = {}

    # "<allocations>#<virtual-vram spec>"; the spec leads with the compute device.
    distorch_alloc, virtual_vram_str = allocations_string.split('#')
    compute_device = virtual_vram_str.split(';')[0]

    logger.info(f"[MultiGPU_DisTorch2_CLIP] CLIP Compute Device: {compute_device}")

    if not distorch_alloc:
        # Fraction mode: no explicit allocations, derive from the vvram spec.
        logger.info("[MultiGPU_DisTorch2_CLIP] Expert String Examples:")
        logger.info(" Direct(byte) Mode - cuda:0,500mb;cuda:1,3.0g;cpu,5gb* -> '*' cpu = over/underflow device, put 0.50gb on cuda0, 3.00gb on cuda1, and 5.00gb (or the rest) on cpu")
        logger.info(" Ratio(%) Mode - cuda:0,8%;cuda:1,8%;cpu,4% -> 8:8:4 ratio, put 40% on cuda0, 40% on cuda1, and 20% on cpu")
        distorch_alloc = calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str)
    elif any(c in distorch_alloc.lower() for c in ['g', 'm', 'k', 'b']):
        # Byte mode: explicit sizes such as "cuda:0,500mb;cpu,5gb".
        distorch_alloc = calculate_fraction_from_byte_expert_string(model_patcher, distorch_alloc)
    elif "%" in distorch_alloc:
        # Ratio mode: percentages normalized across devices.
        distorch_alloc = calculate_fraction_from_ratio_expert_string(model_patcher, distorch_alloc)

    # Pad the allocation string so every known device appears (zero share for
    # devices the user did not mention).
    all_devices = get_device_list()
    present_devices = {item.split(',')[0] for item in distorch_alloc.split(';') if ',' in item}
    for device in all_devices:
        if device not in present_devices:
            distorch_alloc += f";{device},0.0"

    logger.info(f"[MultiGPU_DisTorch2_CLIP] Final CLIP Allocation String: {distorch_alloc}")

    eq_line = "=" * 50
    dash_line = "-" * 50
    fmt_assign = "{:<18}{:>7}{:>14}{:>10}"

    # Convert each "device,fraction" pair into an absolute GB budget.
    for allocation in distorch_alloc.split(';'):
        if ',' not in allocation:
            continue
        dev_name, fraction = allocation.split(',')
        fraction = float(fraction)
        total_mem_bytes = mm.get_total_memory(torch.device(dev_name))
        alloc_gb = (total_mem_bytes * fraction) / (1024**3)
        DEVICE_RATIOS_DISTORCH[dev_name] = alloc_gb
        device_table[dev_name] = {
            "fraction": fraction,
            "total_gb": total_mem_bytes / (1024**3),
            "alloc_gb": alloc_gb,
        }

    logger.info(eq_line)
    logger.info(" DisTorch2 CLIP Model Device Allocations")
    logger.info(eq_line)

    fmt_rosetta = "{:<8}{:>9}{:>9}{:>11}{:>10}"
    logger.info(fmt_rosetta.format("Device", "VRAM GB", "Dev %", "Model GB", "Dist %"))
    logger.info(dash_line)

    # CPU sorts last so GPU rows lead every table below.
    sorted_devices = sorted(device_table.keys(), key=lambda d: (d == "cpu", d))

    total_allocated_model_bytes = sum(d["alloc_gb"] * (1024**3) for d in device_table.values())

    for dev in sorted_devices:
        total_dev_gb = device_table[dev]["total_gb"]
        alloc_fraction = device_table[dev]["fraction"]
        alloc_gb = device_table[dev]["alloc_gb"]

        dist_ratio_percent = (alloc_gb * (1024**3) / total_allocated_model_bytes) * 100 if total_allocated_model_bytes > 0 else 0

        logger.info(fmt_rosetta.format(
            dev,
            f"{total_dev_gb:.2f}",
            f"{alloc_fraction * 100:.1f}%",
            f"{alloc_gb:.2f}",
            f"{dist_ratio_percent:.1f}%",
        ))

    logger.info(dash_line)

    block_summary = {}
    memory_by_type = defaultdict(int)

    raw_block_list = model_patcher._load_list()
    total_memory = sum(module_size for module_size, _, _, _ in raw_block_list)

    # Split the model into head (embedding) blocks, which must stay on the
    # compute device, and distributable blocks.
    head_keywords = ['embed', 'wte', 'wpe', 'token_embedding', 'position_embedding']
    head_blocks = []
    distributable_blocks_raw = []
    for entry in raw_block_list:
        _, module_name, _, _ = entry
        if any(keyword in module_name.lower() for keyword in head_keywords):
            head_blocks.append(entry)
        else:
            distributable_blocks_raw.append(entry)
    head_memory = sum(module_size for module_size, _, _, _ in head_blocks)

    # Blocks below 0.01% of total model memory are "tiny" and always land on
    # the compute device.
    MIN_BLOCK_THRESHOLD = total_memory * 0.0001
    all_blocks = []

    for module_size, module_name, module_object, params in raw_block_list:
        block_type = type(module_object).__name__
        block_summary[block_type] = block_summary.get(block_type, 0) + 1
        memory_by_type[block_type] += module_size
        all_blocks.append((module_name, module_object, block_type, module_size))

    # Only the distributable part participates in donor allocation.
    distributable_all_blocks = [
        (module_name, module_object, type(module_object).__name__, module_size)
        for module_size, module_name, module_object, params in distributable_blocks_raw
    ]
    block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD]
    tiny_block_list = [b for b in distributable_all_blocks if b[3] < MIN_BLOCK_THRESHOLD]

    logger.info(" DisTorch2 CLIP Model Layer Distribution")
    logger.info(dash_line)
    fmt_layer = "{:<18}{:>7}{:>14}{:>10}"
    logger.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total"))
    logger.info(dash_line)

    for layer_type, count in block_summary.items():
        mem_mb = memory_by_type[layer_type] / (1024 * 1024)
        mem_percent = (memory_by_type[layer_type] / total_memory) * 100 if total_memory > 0 else 0
        logger.info(fmt_layer.format(layer_type[:18], str(count), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))

    logger.info(dash_line)

    block_assignments = {}

    # Pin head blocks to the compute device (the CLIP-specific behavior).
    for _, module_name, _, _ in head_blocks:
        block_assignments[module_name] = compute_device
    if head_blocks:
        logger.info(f"[MultiGPU_DisTorch2_CLIP] Preserving {len(head_blocks)} head layer(s) ({head_memory / (1024 * 1024):.2f}MB) on compute device: {compute_device}")

    donor_devices = list(sorted_devices)
    donor_quotas = {
        dev: device_table[dev]["alloc_gb"] * (1024**3)
        for dev in donor_devices
    }
    # The locked head consumes part of the compute device's quota.
    if compute_device in donor_quotas:
        donor_quotas[compute_device] = max(0, donor_quotas[compute_device] - head_memory)

    # Fill donors back-to-front; a block that fits no donor falls back to the
    # compute device.
    for block_name, module, block_type, block_memory in reversed(block_list):
        for donor in donor_devices:
            if donor_quotas[donor] >= block_memory:
                block_assignments[block_name] = donor
                donor_quotas[donor] -= block_memory
                break  # Move to the next block
        else:
            block_assignments[block_name] = compute_device

    for block_name, _, _, _ in tiny_block_list:
        block_assignments[block_name] = compute_device

    # Resolve each assigned name back to its full record. A first-occurrence
    # dict lookup replaces the original O(n^2) per-name linear scan while
    # matching its first-match semantics.
    block_info = {}
    for entry in all_blocks:
        block_info.setdefault(entry[0], entry)

    device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH}
    for block_name, device in block_assignments.items():
        entry = block_info.get(block_name)
        if entry is not None:
            device_assignments[device].append(entry)

    logger.info("DisTorch2 CLIP Model Final Device/Layer Assignments")
    logger.info(dash_line)
    logger.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total"))
    logger.info(dash_line)

    device_memories = defaultdict(int)
    device_counts = defaultdict(int)
    for device, blocks in device_assignments.items():
        for _, _, _, b_mem in blocks:
            device_memories[device] += b_mem
            device_counts[device] += 1

    sorted_assignments = sorted(device_memories.keys(), key=lambda d: (d == "cpu", d))

    for dev in sorted_assignments:
        if device_counts[dev] == 0:
            continue
        mem_mb = device_memories[dev] / (1024 * 1024)
        mem_percent = (device_memories[dev] / total_memory) * 100 if total_memory > 0 else 0
        logger.info(fmt_assign.format(dev, str(device_counts[dev]), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))

    logger.info(dash_line)

    return {
        "device_assignments": device_assignments,
        "block_assignments": block_assignments,
    }
579+
580+
369581def parse_memory_string (mem_str ):
370582 """Parses a memory string (e.g., '4.0g', '512M') and returns bytes."""
371583 mem_str = mem_str .strip ().lower ()
0 commit comments