Skip to content

Commit edc8a4d

Browse files
committed
Identified a long-standing bug where fully-allocated CLIP (for example 99G of VirtualVRAM = 100% of major blocks no matter the model) proceeded to execute on the donor device (e.g. cpu) instead of the indicated compute device. Turns out, it only happens when *all* blocks are identified to go onto the donor card. In the case of the donor being the cpu this was irritatingly slow.
On a 4x PCIe bus, swapping a normal CLIP-sized number of layers once/twice (for neg) into compute should be the optimal solution: Reside on `cpu`, use the optimized cuda kernels for computation JiT on `compute`, discard layers once used (residing permanently on `cpu`), then move efficiently to the main UNet computation.
1 parent d34a32f commit edc8a4d

1 file changed

Lines changed: 214 additions & 2 deletions

File tree

distorch_2.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,14 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
8686

8787
mem_counter = 0
8888

89-
logger.info(f"[MultiGPU_DisTorch2] Using static allocation for model {debug_hash[:8]}")
90-
device_assignments = analyze_safetensor_loading(self, allocations)
89+
is_clip_model = getattr(self, 'is_clip', False)
90+
if is_clip_model:
91+
logger.info(f"[MultiGPU_DisTorch2] Using CLIP-specific allocation for model {debug_hash[:8]} (HEAD PRESERVATION ENABLED)")
92+
device_assignments = analyze_safetensor_loading_clip(self, allocations)
93+
else:
94+
logger.debug(f"[MultiGPU_DisTorch2] Using standard allocation for model {debug_hash[:8]} (UNET/VAE - UNTOUCHED)")
95+
device_assignments = analyze_safetensor_loading(self, allocations)
96+
9197
model_original_dtype = comfy.utils.weight_dtype(self.model.state_dict())
9298
high_precision_loras = self.model._distorch_high_precision_loras
9399
loading = self._load_list()
@@ -164,6 +170,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
164170
def analyze_safetensor_loading(model_patcher, allocations_string):
165171
"""
166172
Analyze and distribute safetensor model blocks across devices
173+
Target for refactor back into one function once stability for CLIP is established.
167174
"""
168175
DEVICE_RATIOS_DISTORCH = {}
169176
device_table = {}
@@ -366,6 +373,211 @@ def analyze_safetensor_loading(model_patcher, allocations_string):
366373
"block_assignments": block_assignments
367374
}
368375

376+
377+
def analyze_safetensor_loading_clip(model_patcher, allocations_string):
    """
    CLIP-SPECIFIC: A 1:1 clone of the working UNET allocation logic with the
    single required modification to preserve head-blocks (embedding layers) on
    the compute device, so a fully-donated CLIP never ends up executing
    entirely on the donor device (e.g. cpu).
    Target for refactor back into analyze_safetensor_loading once stability
    for CLIP is established.

    Args:
        model_patcher: patcher exposing ``_load_list()``, which yields
            ``(module_size, module_name, module_object, params)`` tuples.
            # assumes _load_list() returns a re-iterable sequence (it is
            # iterated more than once below) -- TODO confirm against caller.
        allocations_string: ``"<expert_alloc>#<virtual_vram_spec>"`` where the
            virtual-vram spec leads with the compute device name.

    Returns:
        dict with:
            "device_assignments": {device: [(name, module, type_name, bytes), ...]}
            "block_assignments":  {block_name: device_name}
    """
    DEVICE_RATIOS_DISTORCH = {}
    device_table = {}

    # "<expert string>#<vvram string>"; the vvram string leads with the compute device.
    distorch_alloc, virtual_vram_str = allocations_string.split('#')
    compute_device = virtual_vram_str.split(';')[0]

    logger.info(f"[MultiGPU_DisTorch2_CLIP] CLIP Compute Device: {compute_device}")

    # Normalize whatever form the allocation came in as (empty / byte expert /
    # ratio expert) down to per-device fractions.
    if not distorch_alloc:
        logger.info("[MultiGPU_DisTorch2_CLIP] Expert String Examples:")
        logger.info(" Direct(byte) Mode - cuda:0,500mb;cuda:1,3.0g;cpu,5gb* -> '*' cpu = over/underflow device, put 0.50gb on cuda0, 3.00gb on cuda1, and 5.00gb (or the rest) on cpu")
        logger.info(" Ratio(%) Mode - cuda:0,8%;cuda:1,8%;cpu,4% -> 8:8:4 ratio, put 40% on cuda0, 40% on cuda1, and 20% on cpu")
        distorch_alloc = calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str)
    elif any(c in distorch_alloc.lower() for c in ['g', 'm', 'k', 'b']):
        distorch_alloc = calculate_fraction_from_byte_expert_string(model_patcher, distorch_alloc)
    elif "%" in distorch_alloc:
        distorch_alloc = calculate_fraction_from_ratio_expert_string(model_patcher, distorch_alloc)

    # Every known device must appear in the string so quota/assignment dicts are complete.
    all_devices = get_device_list()
    present_devices = {item.split(',')[0] for item in distorch_alloc.split(';') if ',' in item}
    for device in all_devices:
        if device not in present_devices:
            distorch_alloc += f";{device},0.0"

    logger.info(f"[MultiGPU_DisTorch2_CLIP] Final CLIP Allocation String: {distorch_alloc}")

    eq_line = "=" * 50
    dash_line = "-" * 50
    fmt_assign = "{:<18}{:>7}{:>14}{:>10}"

    # Convert per-device fractions into absolute GB budgets.
    for allocation in distorch_alloc.split(';'):
        if ',' not in allocation:
            continue
        dev_name, fraction = allocation.split(',')
        fraction = float(fraction)
        total_mem_bytes = mm.get_total_memory(torch.device(dev_name))
        alloc_gb = (total_mem_bytes * fraction) / (1024**3)
        DEVICE_RATIOS_DISTORCH[dev_name] = alloc_gb
        device_table[dev_name] = {
            "fraction": fraction,
            "total_gb": total_mem_bytes / (1024**3),
            "alloc_gb": alloc_gb
        }

    logger.info(eq_line)
    logger.info(" DisTorch2 CLIP Model Device Allocations")
    logger.info(eq_line)

    fmt_rosetta = "{:<8}{:>9}{:>9}{:>11}{:>10}"
    logger.info(fmt_rosetta.format("Device", "VRAM GB", "Dev %", "Model GB", "Dist %"))
    logger.info(dash_line)

    # cpu always sorts last; donors are consumed in this order below.
    sorted_devices = sorted(device_table.keys(), key=lambda d: (d == "cpu", d))

    total_allocated_model_bytes = sum(d["alloc_gb"] * (1024**3) for d in device_table.values())

    for dev in sorted_devices:
        total_dev_gb = device_table[dev]["total_gb"]
        alloc_fraction = device_table[dev]["fraction"]
        alloc_gb = device_table[dev]["alloc_gb"]

        dist_ratio_percent = (alloc_gb * (1024**3) / total_allocated_model_bytes) * 100 if total_allocated_model_bytes > 0 else 0

        logger.info(fmt_rosetta.format(
            dev,
            f"{total_dev_gb:.2f}",
            f"{alloc_fraction*100:.1f}%",
            f"{alloc_gb:.2f}",
            f"{dist_ratio_percent:.1f}%"
        ))

    logger.info(dash_line)

    block_summary = defaultdict(int)
    memory_by_type = defaultdict(int)

    raw_block_list = model_patcher._load_list()
    total_memory = sum(module_size for module_size, _, _, _ in raw_block_list)

    # Split the model into head and distributable parts. Embedding ("head")
    # layers are pinned to the compute device -- this is the one divergence
    # from the UNET path, and the fix for the all-blocks-on-donor bug.
    head_keywords = ['embed', 'wte', 'wpe', 'token_embedding', 'position_embedding']
    head_blocks = []
    distributable_blocks_raw = []
    head_memory = 0

    for module_size, module_name, module_object, _params in raw_block_list:
        if any(keyword in module_name.lower() for keyword in head_keywords):
            head_blocks.append((module_size, module_name, module_object, _params))
        else:
            distributable_blocks_raw.append((module_size, module_name, module_object, _params))

    # Blocks below 0.01% of the model are too small to be worth distributing.
    MIN_BLOCK_THRESHOLD = total_memory * 0.0001
    all_blocks = []

    for module_size, module_name, module_object, _params in raw_block_list:
        block_type = type(module_object).__name__
        block_summary[block_type] += 1
        memory_by_type[block_type] += module_size
        all_blocks.append((module_name, module_object, block_type, module_size))

    # Use only the distributable part for the actual allocation logic.
    distributable_all_blocks = [
        (module_name, module_object, type(module_object).__name__, module_size)
        for module_size, module_name, module_object, _params in distributable_blocks_raw
    ]

    block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD]
    tiny_block_list = [b for b in distributable_all_blocks if b[3] < MIN_BLOCK_THRESHOLD]

    logger.info(" DisTorch2 CLIP Model Layer Distribution")
    logger.info(dash_line)
    fmt_layer = "{:<18}{:>7}{:>14}{:>10}"
    logger.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total"))
    logger.info(dash_line)

    for layer_type, count in block_summary.items():
        mem_mb = memory_by_type[layer_type] / (1024 * 1024)
        mem_percent = (memory_by_type[layer_type] / total_memory) * 100 if total_memory > 0 else 0
        logger.info(fmt_layer.format(layer_type[:18], str(count), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))

    logger.info(dash_line)

    block_assignments = {}

    # Pre-assign head blocks to the compute device and total their memory.
    for module_size, module_name, module_object, _params in head_blocks:
        block_assignments[module_name] = compute_device
        head_memory += module_size
    if head_blocks:
        logger.info(f"[MultiGPU_DisTorch2_CLIP] Preserving {len(head_blocks)} head layer(s) ({head_memory / (1024*1024):.2f} MB) on compute device: {compute_device}")

    donor_devices = list(sorted_devices)
    donor_quotas = {
        dev: device_table[dev]["alloc_gb"] * (1024**3)
        for dev in donor_devices
    }
    # Adjust the compute_device quota to account for the locked head.
    if compute_device in donor_quotas:
        donor_quotas[compute_device] = max(0, donor_quotas[compute_device] - head_memory)

    # Greedy fill: walk blocks from the tail of the model, placing each on the
    # first donor with remaining quota; overflow lands on the compute device.
    for block_name, module, block_type, block_memory in reversed(block_list):
        assigned_to_donor = False
        for donor in donor_devices:
            if donor_quotas[donor] >= block_memory:
                block_assignments[block_name] = donor
                donor_quotas[donor] -= block_memory
                assigned_to_donor = True
                break  # Move to the next block

        if not assigned_to_donor:
            block_assignments[block_name] = compute_device

    # Tiny blocks always live on the compute device.
    for block_name, module, block_type, block_memory in tiny_block_list:
        block_assignments[block_name] = compute_device

    # O(n) first-occurrence index by block name (replaces a nested O(n^2) scan).
    block_info_by_name = {}
    for b_name, b_module, b_type, b_mem in all_blocks:
        block_info_by_name.setdefault(b_name, (b_name, b_module, b_type, b_mem))

    device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH}
    for block_name, device in block_assignments.items():
        info = block_info_by_name.get(block_name)
        if info is not None:
            device_assignments[device].append(info)

    logger.info("DisTorch2 CLIP Model Final Device/Layer Assignments")
    logger.info(dash_line)
    logger.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total"))
    logger.info(dash_line)

    device_memories = defaultdict(int)
    device_counts = defaultdict(int)
    for device, blocks in device_assignments.items():
        for b_name, b_module, b_type, b_mem in blocks:
            device_memories[device] += b_mem
            device_counts[device] += 1

    sorted_assignments = sorted(device_memories.keys(), key=lambda d: (d == "cpu", d))

    for dev in sorted_assignments:
        if device_counts[dev] == 0:
            continue
        mem_mb = device_memories[dev] / (1024 * 1024)
        mem_percent = (device_memories[dev] / total_memory) * 100 if total_memory > 0 else 0
        logger.info(fmt_assign.format(dev, str(device_counts[dev]), f"{mem_mb:.2f}", f"{mem_percent:.1f}%"))

    logger.info(dash_line)

    return {
        "device_assignments": device_assignments,
        "block_assignments": block_assignments
    }
369581
def parse_memory_string(mem_str):
370582
"""Parses a memory string (e.g., '4.0g', '512M') and returns bytes."""
371583
mem_str = mem_str.strip().lower()

0 commit comments

Comments
 (0)