
Commit d1c88a7

feat(distorch): Add universal .safetensors support & memory-based distribution
This commit introduces DisTorch v2.0.0, a major overhaul that extends multi-device model distribution to standard `.safetensors` models. Key changes include:

- **Universal `.safetensors` Support:** The core distribution logic is no longer limited to GGUF models. It now fully supports `.safetensors`, allowing any UNet supported by native Comfy loaders to have its layers distributed across multiple devices (GPUs and CPU/RAM).
1 parent fb6e2e6 commit d1c88a7

5 files changed: 114 additions & 102 deletions


README.md

Lines changed: 6 additions & 6 deletions
@@ -1,4 +1,4 @@
-# ComfyUI-MultiGPU: Tools to free up your primary GPU’s VRAM by using your CPU or additional GPUs, now with tighter integration into kijai's WanVideoWrapper[^1]
+# ComfyUI-MultiGPU v2.0.0: Universal `.safetensors` and GGUF Multi-GPU Distribution with DisTorch
 <p align="center">
 <img src="https://raw.githubusercontent.com/pollockjj/ComfyUI-MultiGPU/main/assets/distorch_average.png" width="600">
 <br>
@@ -18,10 +18,10 @@ ComfyUI-MultiGPU now includes a custom, tightly integrated implementation for Wa
 ## The Core of ComfyUI-MultiGPU:
 [^1]: This **enhances memory management,** not parallel processing. Workflow steps still execute sequentially, but with components (in full or in part) loaded across your specified devices. *Performance gains* come from avoiding repeated model loading/unloading when VRAM is constrained. *Capability gains* come from offloading as much of the model (VAE/CLIP/UNet) off of your main **compute** device as possible—allowing you to maximize latent space for actual computation.

-1. **DisTorch Virtual VRAM for UNet Loaders**: Move UNet layers off your compute GPU
-   - Automatic distribution to RAM or other GPUs
-   - One-number control of VRAM usage
-   - Support for all GGUF models
+1. **DisTorch Virtual VRAM for `.safetensors` and GGUF Models**: Move model layers off your compute GPU
+   - Automatic, memory-size based distribution to RAM or other GPUs
+   - One-number control of VRAM usage
+   - Universal support for `.safetensors` and GGUF models

 2. **CLIP Offloading**: Two solutions for LLM-based and standard CLIP models:
    - **MultiGPU CLIP**: Full offload to CPU or secondary GPU
@@ -83,7 +83,7 @@ With a 12GB GPU running an 8GB model:
 - Your GPU now has extra VRAM for larger batches, higher resolutions, or longer video

 ## 🚀 Compatibility
-Works with all GGUF-quantized ComfyUI/ComfyUI-GGUF-supported UNet/CLIP models.
+Works with all `.safetensors` and GGUF-quantized models.

 ⚙️ Expert users: For those of you who were here for the 1.0 release of DisTorch, manual allocation strings are still available for advanced configurations. Each log will contain the allocation string for the run so it can be easily recreated and/or manipulated for more sophisticated setups.
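
The "one-number control" highlighted above maps directly onto the arithmetic that appears in `calculate_vvram_allocation_string` later in this commit. A minimal sketch of that bookkeeping, reusing the README's 12GB GPU / 8GB model example (all sizes hypothetical):

```python
# Sketch of DisTorch's virtual VRAM bookkeeping, mirroring the arithmetic
# in calculate_vvram_allocation_string below. All sizes are hypothetical.
recipient_vram = 12.0    # GB physically on the compute GPU (e.g. cuda:0)
model_size_gb = 8.0      # GB of UNet weights to place
virtual_vram_gb = 4.0    # the single number the user dials in

recipient_virtual = recipient_vram + virtual_vram_gb        # 16.00 GB virtual budget
on_compute_gpu = max(0, model_size_gb - virtual_vram_gb)    # 4.00 GB stays resident
offloaded = model_size_gb - on_compute_gpu                  # 4.00 GB goes to donors

print(f"{on_compute_gpu:.2f}GB on cuda:0, {offloaded:.2f}GB on RAM/other GPUs")
```

Every gigabyte offloaded is a gigabyte of VRAM freed on the compute device for latents, which is where the "larger batches, higher resolutions, or longer video" claim comes from.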

__init__.py

Lines changed: 12 additions & 4 deletions
@@ -9,12 +9,20 @@

 # --- DisTorch V2 Logging Configuration ---
 # Set to "E" for Engineering (DEBUG) or "P" for Production (INFO)
-LOG_LEVEL = "E"
+LOG_LEVEL = "P"

 # Configure logger
-log_level = logging.DEBUG if LOG_LEVEL == "E" else logging.INFO
-logging.basicConfig(level=log_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("MultiGPU")
+logger.propagate = False
+
+if not logger.handlers:
+    log_level = logging.DEBUG if LOG_LEVEL == "E" else logging.INFO
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(log_level)
+
 # --- End Logging Configuration ---

 # Global device state management
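
The switch from `logging.basicConfig` to a dedicated "MultiGPU" logger is what lets every module in the package share one output configuration. A minimal sketch of how a sibling module picks it up (file name hypothetical):

```python
# some_node.py (hypothetical sibling module in the same package)
import logging

# Same named logger configured once in __init__.py: same handler, same level.
logger = logging.getLogger("MultiGPU")

def example():
    logger.debug("engineering-only detail")  # visible only when LOG_LEVEL = "E"
    logger.info("shown in Production mode too")
```

Because `logger.propagate = False` and the `if not logger.handlers:` guard run at import time, repeated imports or ComfyUI reloads will not stack duplicate handlers or double-print log lines.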

distorch.py

Lines changed: 35 additions & 35 deletions
@@ -7,6 +7,8 @@
 import torch
 import logging
 import hashlib
+
+logger = logging.getLogger("MultiGPU")
 import copy
 from collections import defaultdict
 import comfy.model_management as mm
@@ -22,7 +24,7 @@ def create_model_hash(model, caller):
     first_layers = str(list(model.model_state_dict().keys())[:3])
     identifier = f"{model_type}_{model_size}_{first_layers}"
     final_hash = hashlib.sha256(identifier.encode()).hexdigest()
-    logging.debug(f"[MultiGPU_DisTorch_HASH] Created hash for {caller}: {final_hash[:8]}...")
+    logger.debug(f"[MultiGPU_DisTorch_HASH] Created hash for {caller}: {final_hash[:8]}...")
     return final_hash
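
`create_model_hash` fingerprints a model by its type, byte size, and first three state-dict keys, so identical weights resolve to the same identity regardless of which loader produced them. A standalone sketch of the scheme (the example inputs are invented):

```python
# Sketch of the fingerprinting used by create_model_hash above; model_type,
# model_size, and the layer keys stand in for values read from the real
# ComfyUI model object.
import hashlib

def sketch_model_hash(model_type: str, model_size: int, layer_keys: list) -> str:
    first_layers = str(layer_keys[:3])
    identifier = f"{model_type}_{model_size}_{first_layers}"
    return hashlib.sha256(identifier.encode()).hexdigest()

digest = sketch_model_hash("flux", 11_900_000_000,
                           ["img_in.weight", "time_in.weight", "vector_in.weight"])
print(digest[:8])  # only the short prefix is logged, as in the debug line above
```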

@@ -99,22 +101,21 @@ def analyze_ggml_loading(model, allocations_str):
             "alloc_gb": alloc_gb
         }

-    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-    logging.info(eq_line)
-    logging.info(" DisTorch Model Device Allocations")
-    logging.info(eq_line)
-    logging.info(fmt_assign.format("Device", "Alloc %", "Total (GB)", " Alloc (GB)"))
-    logging.info(dash_line)
+    logger.info(eq_line)
+    logger.info(" DisTorch Model Device Allocations")
+    logger.info(eq_line)
+    logger.info(fmt_assign.format("Device", "Alloc %", "Total (GB)", " Alloc (GB)"))
+    logger.info(dash_line)

     sorted_devices = sorted(device_table.keys(), key=lambda d: (d == "cpu", d))

     for dev in sorted_devices:
         frac = device_table[dev]["fraction"]
         tot_gb = device_table[dev]["total_gb"]
         alloc_gb = device_table[dev]["alloc_gb"]
-        logging.info(fmt_assign.format(dev,f"{int(frac * 100)}%",f"{tot_gb:.2f}",f"{alloc_gb:.2f}"))
+        logger.info(fmt_assign.format(dev,f"{int(frac * 100)}%",f"{tot_gb:.2f}",f"{alloc_gb:.2f}"))

-    logging.info(dash_line)
+    logger.info(dash_line)

     layer_summary = {}
     layer_list = []
@@ -134,16 +135,16 @@ def analyze_ggml_loading(model, allocations_str):
         memory_by_type[layer_type] += layer_memory
         total_memory += layer_memory

-    logging.info(" DisTorch Model Layer Distribution")
-    logging.info(dash_line)
+    logger.info(" DisTorch Model Layer Distribution")
+    logger.info(dash_line)
     fmt_layer = "{:<12}{:>10}{:>14}{:>10}"
-    logging.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total"))
-    logging.info(dash_line)
+    logger.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total"))
+    logger.info(dash_line)
     for layer_type, count in layer_summary.items():
         mem_mb = memory_by_type[layer_type] / (1024 * 1024)
         mem_percent = (memory_by_type[layer_type] / total_memory) * 100 if total_memory > 0 else 0
-        logging.info(fmt_layer.format(layer_type,str(count),f"{mem_mb:.2f}",f"{mem_percent:.1f}%"))
-    logging.info(dash_line)
+        logger.info(fmt_layer.format(layer_type,str(count),f"{mem_mb:.2f}",f"{mem_percent:.1f}%"))
+    logger.info(dash_line)

     nonzero_devices = [d for d, r in DEVICE_RATIOS_DISTORCH.items() if r > 0]
     nonzero_total_ratio = sum(DEVICE_RATIOS_DISTORCH[d] for d in nonzero_devices)
@@ -162,11 +163,11 @@ def analyze_ggml_loading(model, allocations_str):
         device_assignments[device] = layer_list[start_idx:end_idx]
         current_layer += device_layer_count

-    logging.info("DisTorch Model Final Device/Layer Assignments")
-    logging.info(dash_line)
+    logger.info("DisTorch Model Final Device/Layer Assignments")
+    logger.info(dash_line)
     fmt_assign = "{:<12}{:>10}{:>14}{:>10}"
-    logging.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total"))
-    logging.info(dash_line)
+    logger.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total"))
+    logger.info(dash_line)
     total_assigned_memory = 0
     device_memories = {}
     for device, layers in device_assignments.items():
@@ -185,8 +186,8 @@ def analyze_ggml_loading(model, allocations_str):
         layers = device_assignments[dev]
         mem_mb = device_memories[dev] / (1024 * 1024)
         mem_percent = (device_memories[dev] / total_memory) * 100 if total_memory > 0 else 0
-        logging.info(fmt_assign.format(dev,str(len(layers)),f"{mem_mb:.2f}",f"{mem_percent:.1f}%"))
-    logging.info(dash_line)
+        logger.info(fmt_assign.format(dev,str(len(layers)),f"{mem_mb:.2f}",f"{mem_percent:.1f}%"))
+    logger.info(dash_line)

     return {"device_assignments": device_assignments}
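
The assignment step above slices one flat layer list into contiguous runs sized by each device's share of the total ratio. A self-contained sketch of that splitting (the remainder handling in distorch.py may differ slightly; inputs are made up):

```python
# Proportional layer splitting in the style of analyze_ggml_loading:
# each device with a nonzero ratio gets a contiguous slice of the layers.
def split_layers(layer_list, device_ratios):
    nonzero = {d: r for d, r in device_ratios.items() if r > 0}
    total_ratio = sum(nonzero.values())
    assignments, current = {}, 0
    for i, (device, ratio) in enumerate(nonzero.items()):
        if i == len(nonzero) - 1:
            count = len(layer_list) - current  # last device takes the remainder
        else:
            count = int(len(layer_list) * ratio / total_ratio)
        assignments[device] = layer_list[current:current + count]
        current += count
    return assignments

layers = [f"block_{i}" for i in range(40)]
print({d: len(v) for d, v in split_layers(layers, {"cuda:0": 0.25, "cpu": 0.75}).items()})
# {'cuda:0': 10, 'cpu': 30}
```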

@@ -200,17 +201,16 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):
     dash_line = "-" * 47
     fmt_assign = "{:<8} {:<6} {:>11} {:>9} {:>9}"

-    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-    logging.info(eq_line)
-    logging.info(" DisTorch Model Virtual VRAM Analysis")
-    logging.info(eq_line)
-    logging.info(fmt_assign.format("Object", "Role", "Original(GB)", "Total(GB)", "Virt(GB)"))
-    logging.info(dash_line)
+    logger.info(eq_line)
+    logger.info(" DisTorch Model Virtual VRAM Analysis")
+    logger.info(eq_line)
+    logger.info(fmt_assign.format("Object", "Role", "Original(GB)", "Total(GB)", "Virt(GB)"))
+    logger.info(dash_line)

     recipient_vram = mm.get_total_memory(torch.device(recipient_device)) / (1024**3)
     recipient_virtual = recipient_vram + virtual_vram_gb

-    logging.info(fmt_assign.format(recipient_device, 'recip', f"{recipient_vram:.2f}GB",f"{recipient_virtual:.2f}GB", f"+{virtual_vram_gb:.2f}GB"))
+    logger.info(fmt_assign.format(recipient_device, 'recip', f"{recipient_vram:.2f}GB",f"{recipient_virtual:.2f}GB", f"+{virtual_vram_gb:.2f}GB"))

     ram_donors = [d for d in donors.split(',') if d != 'cpu']
     remaining_vram_needed = virtual_vram_gb
@@ -228,15 +228,15 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):
         donor_allocations[donor] = donation

         donor_device_info[donor] = (donor_vram, donor_virtual)
-        logging.info(fmt_assign.format(donor, 'donor', f"{donor_vram:.2f}GB", f"{donor_virtual:.2f}GB", f"-{donation:.2f}GB"))
+        logger.info(fmt_assign.format(donor, 'donor', f"{donor_vram:.2f}GB", f"{donor_virtual:.2f}GB", f"-{donation:.2f}GB"))

     system_dram_gb = mm.get_total_memory(torch.device('cpu')) / (1024**3)
     cpu_donation = remaining_vram_needed
     cpu_virtual = system_dram_gb - cpu_donation
     donor_allocations['cpu'] = cpu_donation
-    logging.info(fmt_assign.format('cpu', 'donor', f"{system_dram_gb:.2f}GB", f"{cpu_virtual:.2f}GB", f"-{cpu_donation:.2f}GB"))
+    logger.info(fmt_assign.format('cpu', 'donor', f"{system_dram_gb:.2f}GB", f"{cpu_virtual:.2f}GB", f"-{cpu_donation:.2f}GB"))

-    logging.info(dash_line)
+    logger.info(dash_line)

     layer_summary = {}
     layer_list = []
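
In plain numbers: GPU donors listed ahead of `cpu` give up VRAM first, and the CPU absorbs whatever is still unmet, which is exactly what `cpu_donation = remaining_vram_needed` does above. An illustrative trace (the per-donor cap is computed in code outside this hunk, so the 4GB figure is an assumption):

```python
# Hypothetical donor trace: 6GB of virtual VRAM requested, one GPU donor
# (cuda:1) able to give 4GB in this example, CPU covering the remainder.
virtual_vram_gb = 6.0
remaining_vram_needed = virtual_vram_gb

donation = 4.0                          # assumed cap for cuda:1
remaining_vram_needed -= donation       # 2.0 GB still unmet
cpu_donation = remaining_vram_needed    # cpu takes the rest, as in the code above

print(f"cuda:1 donates {donation:.2f}GB, cpu donates {cpu_donation:.2f}GB")
```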
@@ -259,12 +259,12 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):
     model_size_gb = total_memory / (1024**3)
     new_model_size_gb = max(0, model_size_gb - virtual_vram_gb)

-    logging.info(fmt_assign.format('model', 'model', f"{model_size_gb:.2f}GB",f"{new_model_size_gb:.2f}GB", f"-{virtual_vram_gb:.2f}GB"))
+    logger.info(fmt_assign.format('model', 'model', f"{model_size_gb:.2f}GB",f"{new_model_size_gb:.2f}GB", f"-{virtual_vram_gb:.2f}GB"))

     if model_size_gb > (recipient_vram * 0.9):
         on_recipient = recipient_vram * 0.9
         on_virtuals = model_size_gb - on_recipient
-        logging.info(f"\nWarning: Model size is greater than 90% of recipient VRAM. {on_virtuals:.2f} GB of GGML Layers Offloaded Automatically to Virtual VRAM.\n")
+        logger.info(f"\nWarning: Model size is greater than 90% of recipient VRAM. {on_virtuals:.2f} GB of GGML Layers Offloaded Automatically to Virtual VRAM.\n")
     else:
         on_recipient = model_size_gb
         on_virtuals = 0
@@ -285,7 +285,7 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):

     allocation_string = ";".join(allocation_parts)
     fmt_mem = "{:<20}{:>20}"
-    logging.info(fmt_mem.format("\n v1 Expert String", allocation_string))
+    logger.info(fmt_mem.format("\n v1 Expert String", allocation_string))

     return allocation_string
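
The 90% guard above pushes layers to virtual VRAM even when the user asked for less offload than the model actually needs. Worked numbers (hypothetical sizes):

```python
# Worked example of the 90% guard: a 14GB model on a 12GB recipient GPU.
recipient_vram = 12.0
model_size_gb = 14.0

if model_size_gb > (recipient_vram * 0.9):      # 14.0 > 10.8
    on_recipient = recipient_vram * 0.9         # 10.80 GB kept on the GPU
    on_virtuals = model_size_gb - on_recipient  # 3.20 GB auto-offloaded
else:
    on_recipient, on_virtuals = model_size_gb, 0

print(f"{on_virtuals:.2f} GB of layers offloaded automatically")
```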

@@ -390,7 +390,7 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0,
         full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""

-        logging.info(f"[MultiGPU_DisTorch] Full allocation string: {full_allocation}")
+        logger.info(f"[MultiGPU_DisTorch] Full allocation string: {full_allocation}")

         if hasattr(out[0], 'model'):
             model_hash = create_model_hash(out[0], "override")
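
Logging `full_allocation` on every run is what the README's expert-user note relies on: the printed string can be pasted back in to reproduce or hand-tune a distribution. A sketch of how the string is composed (the `vram_string` value is a placeholder, not a documented format):

```python
# The override joins an optional v1 expert string and the computed vram
# string with '#'. Both values below are illustrative placeholders.
expert_mode_allocations = ""             # empty unless the user supplies one
vram_string = "cuda:0,0.33;cpu,0.67"     # hypothetical computed allocation

full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""
print(full_allocation)  # "#cuda:0,0.33;cpu,0.67"
```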
