Skip to content

Commit 4367c89

Browse files
committed
Eliminate the unused safetensor loading analysis method and update the example configurations, adding a configuration that uses LoRAs to the tested set so the issue seen during the initial release is covered.
1 parent e28b040 commit 4367c89

5 files changed

Lines changed: 855 additions & 147 deletions

File tree

distorch_2.py

Lines changed: 0 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -377,152 +377,6 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
377377
"device_assignments": device_assignments,
378378
"block_assignments": block_assignments
379379
}
380-
def analyze_safetensor_loading_comfy(model_patcher, allocations_str):
381-
"""
382-
Analyze and distribute safetensor model blocks across devices utilizing model_patcher._load_list().sort(reverse=True) method like Comfy
383-
"""
384-
DEVICE_RATIOS_DISTORCH = {}
385-
device_table = {}
386-
distorch_alloc = allocations_str
387-
virtual_vram_gb = 0.0
388-
389-
# Parse allocation string EXACTLY like GGML
390-
if '#' in allocations_str:
391-
distorch_alloc, virtual_vram_str = allocations_str.split('#')
392-
if not distorch_alloc:
393-
distorch_alloc = calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str)
394-
395-
# EXACT SAME FORMATTING AS GGML
396-
eq_line = "=" * 50
397-
dash_line = "-" * 50
398-
fmt_assign = "{:<18}{:>7}{:>14}{:>10}"
399-
400-
# Parse device allocations
401-
for allocation in distorch_alloc.split(';'):
402-
if ',' not in allocation:
403-
continue
404-
dev_name, fraction = allocation.split(',')
405-
fraction = float(fraction)
406-
total_mem_bytes = mm.get_total_memory(torch.device(dev_name))
407-
alloc_gb = (total_mem_bytes * fraction) / (1024**3)
408-
DEVICE_RATIOS_DISTORCH[dev_name] = alloc_gb
409-
device_table[dev_name] = {
410-
"fraction": fraction,
411-
"total_gb": total_mem_bytes / (1024**3),
412-
"alloc_gb": alloc_gb
413-
}
414-
415-
# IDENTICAL LOGGING TO DISTORCH
416-
logger.info(eq_line)
417-
logger.info(" DisTorch2 Model Device Allocations")
418-
logger.info(eq_line)
419-
logger.info(fmt_assign.format("Device", "Alloc %", "Total (GB)", " Alloc (GB)"))
420-
logger.info(dash_line)
421-
422-
sorted_devices = sorted(device_table.keys(), key=lambda d: (d == "cpu", d))
423-
424-
for dev in sorted_devices:
425-
frac = device_table[dev]["fraction"]
426-
tot_gb = device_table[dev]["total_gb"]
427-
alloc_gb = device_table[dev]["alloc_gb"]
428-
logger.info(fmt_assign.format(dev,f"{int(frac * 100)}%",f"{tot_gb:.2f}",f"{alloc_gb:.2f}"))
429-
430-
logger.info(dash_line)
431-
432-
# Get the model blocks using ComfyUI's method
433-
block_list = model_patcher._load_list()
434-
block_list.sort(reverse=True)
435-
436-
# Log layer distribution
437-
total_memory = sum(b[0] for b in block_list)
438-
memory_by_type = defaultdict(int)
439-
block_summary = defaultdict(int)
440-
for module_size, module_name, module_object, params in block_list:
441-
block_type = module_object.__class__.__name__
442-
block_summary[block_type] += 1
443-
memory_by_type[block_type] += module_size
444-
445-
# Log layer distribution - IDENTICAL FORMAT TO GGML
446-
logger.info(" DisTorch2 Model Layer Distribution")
447-
logger.info(dash_line)
448-
fmt_layer = "{:<18}{:>7}{:>14}{:>10}"
449-
logger.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total"))
450-
logger.info(dash_line)
451-
452-
for layer_type, count in block_summary.items():
453-
mem_mb = memory_by_type[layer_type] / (1024 * 1024)
454-
mem_percent = (memory_by_type[layer_type] / total_memory) * 100 if total_memory > 0 else 0
455-
logger.info(fmt_layer.format(layer_type[:18], str(count), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))
456-
457-
logger.info(dash_line)
458-
459-
# Distribute blocks sequentially
460-
device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH.keys()}
461-
block_assignments = {}
462-
463-
compute_device = str(current_device)
464-
# Calculate total memory to be offloaded to donor devices
465-
total_offload_gb = sum(DEVICE_RATIOS_DISTORCH.get(d, 0) for d in sorted_devices if d != compute_device)
466-
total_offload_bytes = total_offload_gb * (1024**3)
467-
468-
offloaded_bytes = 0
469-
470-
# Iterate through the sorted list (largest blocks first)
471-
for module_size, module_name, module_object, params in block_list:
472-
# Assign to donor device until target is met
473-
if offloaded_bytes < total_offload_bytes:
474-
# For now, simple offload to CPU, will expand for multi-donor
475-
donor_device = "cpu"
476-
for dev in sorted_devices:
477-
if dev != compute_device:
478-
donor_device = dev
479-
break # Use first available donor
480-
481-
block_assignments[module_name] = donor_device
482-
setattr(module_object, 'distorch2_cpu_offload', True) # Attach the attribute here
483-
offloaded_bytes += module_size
484-
else:
485-
# Assign remaining blocks to the primary compute device
486-
block_assignments[module_name] = compute_device
487-
488-
# Populate device_assignments from the final block_assignments
489-
for module_size, module_name, module_object, params in block_list:
490-
device = block_assignments[module_name]
491-
if device not in device_assignments:
492-
device_assignments[device] = []
493-
device_assignments[device].append((module_name, module_object, module_object.__class__.__name__, module_size))
494-
495-
# Log final assignments - IDENTICAL FORMAT TO GGML
496-
logger.info("DisTorch2 Model Final Device/Layer Assignments")
497-
logger.info(dash_line)
498-
logger.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total"))
499-
logger.info(dash_line)
500-
501-
# Log distributed blocks
502-
total_assigned_memory = 0
503-
device_memories = {}
504-
505-
for device, blocks in device_assignments.items():
506-
device_memory = sum(b[3] for b in blocks)
507-
device_memories[device] = device_memory
508-
total_assigned_memory += device_memory
509-
510-
sorted_assignments = sorted(device_memories.keys(), key=lambda d: (d == "cpu", d))
511-
512-
for dev in sorted_assignments:
513-
if dev not in device_memories:
514-
continue
515-
mem_mb = device_memories[dev] / (1024 * 1024)
516-
mem_percent = (device_memories[dev] / total_memory) * 100 if total_memory > 0 else 0
517-
logger.info(fmt_assign.format(dev, str(len(device_assignments[dev])), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))
518-
519-
logger.info(dash_line)
520-
521-
return {
522-
"device_assignments": device_assignments,
523-
"block_assignments": block_assignments,
524-
"lowvram_model_memory": total_assigned_memory,
525-
}
526380

527381
def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):
528382
"""Calculate virtual VRAM allocation string for distributed safetensor loading"""
@@ -609,7 +463,6 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):
609463

610464
return allocation_string
611465

612-
613466
def override_class_with_distorch_safetensor_v2(cls):
614467
"""DisTorch 2.0 wrapper for safetensor models"""
615468
from .nodes import get_device_list
File renamed without changes.

examples/distorch2/qwen_image_basic_example_DisTorch2.json renamed to examples/distorch2/qwen_image_basic_DisTorch2.json

File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)