@@ -377,152 +377,6 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
377377 "device_assignments" : device_assignments ,
378378 "block_assignments" : block_assignments
379379 }
380- def analyze_safetensor_loading_comfy (model_patcher , allocations_str ):
381- """
382- Analyze and distribute safetensor model blocks across devices utilizing model_patcher._load_list().sort(reverse=True) method like Comfy
383- """
384- DEVICE_RATIOS_DISTORCH = {}
385- device_table = {}
386- distorch_alloc = allocations_str
387- virtual_vram_gb = 0.0
388-
389- # Parse allocation string EXACTLY like GGML
390- if '#' in allocations_str :
391- distorch_alloc , virtual_vram_str = allocations_str .split ('#' )
392- if not distorch_alloc :
393- distorch_alloc = calculate_safetensor_vvram_allocation (model_patcher , virtual_vram_str )
394-
395- # EXACT SAME FORMATTING AS GGML
396- eq_line = "=" * 50
397- dash_line = "-" * 50
398- fmt_assign = "{:<18}{:>7}{:>14}{:>10}"
399-
400- # Parse device allocations
401- for allocation in distorch_alloc .split (';' ):
402- if ',' not in allocation :
403- continue
404- dev_name , fraction = allocation .split (',' )
405- fraction = float (fraction )
406- total_mem_bytes = mm .get_total_memory (torch .device (dev_name ))
407- alloc_gb = (total_mem_bytes * fraction ) / (1024 ** 3 )
408- DEVICE_RATIOS_DISTORCH [dev_name ] = alloc_gb
409- device_table [dev_name ] = {
410- "fraction" : fraction ,
411- "total_gb" : total_mem_bytes / (1024 ** 3 ),
412- "alloc_gb" : alloc_gb
413- }
414-
415- # IDENTICAL LOGGING TO DISTORCH
416- logger .info (eq_line )
417- logger .info (" DisTorch2 Model Device Allocations" )
418- logger .info (eq_line )
419- logger .info (fmt_assign .format ("Device" , "Alloc %" , "Total (GB)" , " Alloc (GB)" ))
420- logger .info (dash_line )
421-
422- sorted_devices = sorted (device_table .keys (), key = lambda d : (d == "cpu" , d ))
423-
424- for dev in sorted_devices :
425- frac = device_table [dev ]["fraction" ]
426- tot_gb = device_table [dev ]["total_gb" ]
427- alloc_gb = device_table [dev ]["alloc_gb" ]
428- logger .info (fmt_assign .format (dev ,f"{ int (frac * 100 )} %" ,f"{ tot_gb :.2f} " ,f"{ alloc_gb :.2f} " ))
429-
430- logger .info (dash_line )
431-
432- # Get the model blocks using ComfyUI's method
433- block_list = model_patcher ._load_list ()
434- block_list .sort (reverse = True )
435-
436- # Log layer distribution
437- total_memory = sum (b [0 ] for b in block_list )
438- memory_by_type = defaultdict (int )
439- block_summary = defaultdict (int )
440- for module_size , module_name , module_object , params in block_list :
441- block_type = module_object .__class__ .__name__
442- block_summary [block_type ] += 1
443- memory_by_type [block_type ] += module_size
444-
445- # Log layer distribution - IDENTICAL FORMAT TO GGML
446- logger .info (" DisTorch2 Model Layer Distribution" )
447- logger .info (dash_line )
448- fmt_layer = "{:<18}{:>7}{:>14}{:>10}"
449- logger .info (fmt_layer .format ("Layer Type" , "Layers" , "Memory (MB)" , "% Total" ))
450- logger .info (dash_line )
451-
452- for layer_type , count in block_summary .items ():
453- mem_mb = memory_by_type [layer_type ] / (1024 * 1024 )
454- mem_percent = (memory_by_type [layer_type ] / total_memory ) * 100 if total_memory > 0 else 0
455- logger .info (fmt_layer .format (layer_type [:18 ], str (count ), f"{ mem_mb :.2f} " , f"{ mem_percent :.1f} %" ))
456-
457- logger .info (dash_line )
458-
459- # Distribute blocks sequentially
460- device_assignments = {device : [] for device in DEVICE_RATIOS_DISTORCH .keys ()}
461- block_assignments = {}
462-
463- compute_device = str (current_device )
464- # Calculate total memory to be offloaded to donor devices
465- total_offload_gb = sum (DEVICE_RATIOS_DISTORCH .get (d , 0 ) for d in sorted_devices if d != compute_device )
466- total_offload_bytes = total_offload_gb * (1024 ** 3 )
467-
468- offloaded_bytes = 0
469-
470- # Iterate through the sorted list (largest blocks first)
471- for module_size , module_name , module_object , params in block_list :
472- # Assign to donor device until target is met
473- if offloaded_bytes < total_offload_bytes :
474- # For now, simple offload to CPU, will expand for multi-donor
475- donor_device = "cpu"
476- for dev in sorted_devices :
477- if dev != compute_device :
478- donor_device = dev
479- break # Use first available donor
480-
481- block_assignments [module_name ] = donor_device
482- setattr (module_object , 'distorch2_cpu_offload' , True ) # Attach the attribute here
483- offloaded_bytes += module_size
484- else :
485- # Assign remaining blocks to the primary compute device
486- block_assignments [module_name ] = compute_device
487-
488- # Populate device_assignments from the final block_assignments
489- for module_size , module_name , module_object , params in block_list :
490- device = block_assignments [module_name ]
491- if device not in device_assignments :
492- device_assignments [device ] = []
493- device_assignments [device ].append ((module_name , module_object , module_object .__class__ .__name__ , module_size ))
494-
495- # Log final assignments - IDENTICAL FORMAT TO GGML
496- logger .info ("DisTorch2 Model Final Device/Layer Assignments" )
497- logger .info (dash_line )
498- logger .info (fmt_assign .format ("Device" , "Layers" , "Memory (MB)" , "% Total" ))
499- logger .info (dash_line )
500-
501- # Log distributed blocks
502- total_assigned_memory = 0
503- device_memories = {}
504-
505- for device , blocks in device_assignments .items ():
506- device_memory = sum (b [3 ] for b in blocks )
507- device_memories [device ] = device_memory
508- total_assigned_memory += device_memory
509-
510- sorted_assignments = sorted (device_memories .keys (), key = lambda d : (d == "cpu" , d ))
511-
512- for dev in sorted_assignments :
513- if dev not in device_memories :
514- continue
515- mem_mb = device_memories [dev ] / (1024 * 1024 )
516- mem_percent = (device_memories [dev ] / total_memory ) * 100 if total_memory > 0 else 0
517- logger .info (fmt_assign .format (dev , str (len (device_assignments [dev ])), f"{ mem_mb :.2f} " , f"{ mem_percent :.1f} %" ))
518-
519- logger .info (dash_line )
520-
521- return {
522- "device_assignments" : device_assignments ,
523- "block_assignments" : block_assignments ,
524- "lowvram_model_memory" : total_assigned_memory ,
525- }
526380
527381def calculate_safetensor_vvram_allocation (model_patcher , virtual_vram_str ):
528382 """Calculate virtual VRAM allocation string for distributed safetensor loading"""
@@ -609,7 +463,6 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):
609463
610464 return allocation_string
611465
612-
613466def override_class_with_distorch_safetensor_v2 (cls ):
614467 """DisTorch 2.0 wrapper for safetensor models"""
615468 from .nodes import get_device_list
0 commit comments