
Commit e28b040

Refactor DisTorchV2 loader to support both on-device patching (avoids tensor mismatches on some models, but much slower) and on-compute patching (faster; highest fidelity for the combination of FP8 model + LoRAs + store-on-CPU)
1 parent dfe6612 commit e28b040

1 file changed: distorch_2.py (100 additions, 66 deletions)
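
In outline, the change gates each model between the two paths by model type. A standalone sketch of that gate (the helper name `choose_patching_mode` is ours, not the repo's):

```python
def choose_patching_mode(model) -> bool:
    """Return True for patch-on-compute, False for patch-on-device.

    Hypothetical helper paraphrasing the gate in the diff below: SDXL and
    LTXV keep the slower patch-on-device path because on-compute patching
    produced tensor mismatches for them (per the commit message).
    """
    return type(model).__name__ not in ("SDXL", "LTXV")
```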
@@ -63,86 +63,115 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     # Check if we have a device allocation for this model
     debug_hash = create_safetensor_model_hash(self, "partial_load")
     allocations = safetensor_allocation_store.get(debug_hash)
+
 
     if not hasattr(self.model, '_distorch_high_precision_loras') or not allocations:
-        result = original_partially_load(self, device_to, extra_memory, force_patch_weights)
-
+        result = original_partially_load(self, device_to, extra_memory, force_patch_weights)
         # Clean up
         if hasattr(self, '_distorch_block_assignments'):
             del self._distorch_block_assignments
-
+
+
         return result
-
-    logger.info(f"[MultiGPU_DisTorch2] high_precision_loras flag not retrieved from model. DisTorchV2 Loader not used. Reverting to normal loading behavior")
+
+    logger.info(f"[MultiGPU_DisTorch2] DisTorchV2 Loader activated")
 
     mem_counter = 0
     patch_counter = 0
 
-    loading = self._load_list()
-
-    load_completely = []
-    loading.sort(reverse=True)
-    for x in loading:
-        n = x[1]
-        m = x[2]
-        params = x[3]
-        module_mem = x[0]
-
-        weight_key = "{}.weight".format(n)
-        bias_key = "{}.bias".format(n)
-
-        cast_weight = self.force_cast_weights
-
-        if hasattr(m, "comfy_cast_weights"):
-            #logging.info(f"Unpatching weight {weight_key} for Distorch2")
-            wipe_lowvram_weight(m)
-
-        #logging.info(f"Adding {n} to 'load_completely' list")
-        mem_counter += module_mem
-        load_completely.append((module_mem, n, m, params))
-
-        if cast_weight and hasattr(m, "comfy_cast_weights"):
-            #logging.info(f"Setting cast weights for {weight_key}")
-            m.prev_comfy_cast_weights = m.comfy_cast_weights
-            m.comfy_cast_weights = True
-
-        if weight_key in self.weight_wrapper_patches:
-            logging.info(f"Patching weight wrapper {m} for Distorch2")
-            m.weight_function.extend(self.weight_wrapper_patches[weight_key])
-
-        if bias_key in self.weight_wrapper_patches:
-            logging.info(f"Patching bias wrapper {bias_key} for Distorch2")
-            m.bias_function.extend(self.weight_wrapper_patches[bias_key])
-
-        mem_counter += move_weight_functions(m, device_to)
-
-
     logger.info(f"[MultiGPU_DisTorch2] Using static allocation for model {debug_hash[:8]}")
     # Parse allocation string and apply static assignment
    device_assignments = analyze_safetensor_loading(self, allocations)
-
-
 
-    # Apply our static assignments instead of ComfyUI's dynamic ones
-    for block_name, target_device in device_assignments['block_assignments'].items():
-        # Find the module by name
-        parts = block_name.split('.')
-        module = self.model
-        for part in parts:
-            if hasattr(module, part):
-                module = getattr(module, part)
-            else:
-                break
+    model_type = type(self.model).__name__
+
+    if model_type == "SDXL" or model_type == "LTXV":
+        on_compute_patching = False
+    else:
+        on_compute_patching = True
+
+    if on_compute_patching:
+        high_precision_loras = self.model._distorch_high_precision_loras
+        loading = self._load_list()
+        loading.sort(reverse=True)
+        for module_size, module_name, module_object, params in loading:
+            # Step 1: Write block/tensor to compute device first
+            module_object.to(device_to)
+
+            # Step 2: Apply LoRa patches while on compute device
+            weight_key = "{}.weight".format(module_name)
+            bias_key = "{}.bias".format(module_name)
+
+            if weight_key in self.patches:
+                self.patch_weight_to_device(weight_key, device_to=device_to)
+            if weight_key in self.weight_wrapper_patches:
+                module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+            if bias_key in self.patches:
+                self.patch_weight_to_device(bias_key, device_to=device_to)
+            if bias_key in self.weight_wrapper_patches:
+                module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+            # Step 3: FP8 casting for CPU storage (if enabled)
+            block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
+            has_patches = weight_key in self.patches or bias_key in self.patches
+
+            logger.info(f"[MultiGPU_DisTorch2] Patch-on-Compute: Processing {module_name} -> block_target_device={block_target_device}")
+
+            if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
+                for param_name, param in module_object.named_parameters():
+                    if param.dtype.is_floating_point:
+                        cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
+                        new_param = torch.nn.Parameter(cast_data.to(torch.float8_e4m3fn))
+                        new_param.requires_grad = param.requires_grad
+                        setattr(module_object, param_name, new_param)
+                        logger.debug(f"[MultiGPU_DisTorch2] Cast {module_name}.{param_name} to FP8 for CPU storage")
+
+            # Step 4: Move to ultimate destination based on DisTorch assignment
+            if block_target_device != device_to:
+                logger.debug(f"[MultiGPU_DisTorch2] Moving {module_name} from {device_to} to {block_target_device}")
+                module_object.to(block_target_device)
+
+            # Mark as patched and update memory counter
+            module_object.comfy_patched_weights = True
+            mem_counter += module_size
+
+        logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f}MB")
+
+    else:
+        # Apply our static assignments instead of ComfyUI's dynamic ones
+        for block_name, target_device in device_assignments['block_assignments'].items():
+            # Find the module by name
+            parts = block_name.split('.')
+            module = self.model
+            for part in parts:
+                if hasattr(module, part):
+                    module = getattr(module, part)
+                else:
+                    break
+
+            if hasattr(module, 'weight') or hasattr(module, 'comfy_cast_weights'):
+                # Move to our assigned device
+                logger.info(f"[MultiGPU_DisTorch2] Patch-on-Device: Moving {block_name} to {target_device}")
+                module.to(target_device)
+                # Mark for ComfyUI's cast system if not already marked
+                if hasattr(module, 'comfy_cast_weights'):
+                    module.comfy_cast_weights = True
 
-        if hasattr(module, 'weight') or hasattr(module, 'comfy_cast_weights'):
-            # Move to our assigned device
-            logger.debug(f"[MultiGPU_DisTorch2] Moving {block_name} to {target_device}")
-            module.to(target_device)
-            # Mark for ComfyUI's cast system if not already marked
-            if hasattr(module, 'comfy_cast_weights'):
-                module.comfy_cast_weights = True
-
-    # Return 0 to indicate no additional memory used on compute device
+            weight_key = "{}.weight".format(block_name)
+            bias_key = "{}.bias".format(block_name)
+
+            if weight_key in self.patches:
+                self.patch_weight_to_device(weight_key, device_to=target_device)
+            if weight_key in self.weight_wrapper_patches:
+                module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+            if bias_key in self.patches:
+                self.patch_weight_to_device(bias_key, device_to=target_device)
+            if bias_key in self.weight_wrapper_patches:
+                module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
     return 0
 
 
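The fidelity claim in the commit message comes down to FP8 quantization granularity: float8_e4m3fn has a spacing of 0.125 between representable values in [1, 2), so a small LoRA delta applied directly to an FP8 tensor is simply rounded away. That is why Step 2 patches at the load dtype on the compute device, and only Step 3's stochastic rounding recasts CPU-bound blocks. A minimal PyTorch illustration (not repo code; assumes a PyTorch build with float8 dtypes, >= 2.1):

```python
import torch

# Illustration only: why LoRA patches are applied at high precision before
# any FP8 recast.
w_fp8 = torch.tensor([1.0]).to(torch.float8_e4m3fn)   # stored FP8 weight
delta = torch.tensor([0.02])                          # small LoRA contribution

# Patch then immediately round back to FP8: the delta is below the 0.125
# spacing of e4m3 near 1.0, so round-to-nearest discards it entirely.
naive = (w_fp8.to(torch.float32) + delta).to(torch.float8_e4m3fn)
print(naive.to(torch.float32).item())   # 1.0 -- the patch vanished

# Patch-on-compute keeps the patched weight at high precision on the compute
# device; only CPU-bound blocks are recast, and then with stochastic rounding
# (comfy.float.stochastic_rounding in the diff above), which preserves the
# delta in expectation across many weights.
patched = w_fp8.to(torch.float32) + delta
print(patched.item())                   # 1.02
```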
@@ -248,7 +277,12 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
     device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH.keys()}
     block_assignments = {}
 
-    compute_device = str(current_device)
+    # Determine the primary compute device (first non-cpu device)
+    compute_device = "cuda:0" # Fallback
+    for dev in sorted_devices:
+        if dev != "cpu":
+            compute_device = dev
+            break
 
     # Calculate total memory to be offloaded to donor devices
     total_offload_gb = sum(DEVICE_RATIOS_DISTORCH.get(d, 0) for d in sorted_devices if d != compute_device)
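
Previously the compute device was pinned to `str(current_device)`; the new loop instead takes the first non-CPU entry of the sorted device list and falls back to `"cuda:0"` when none is present. A quick standalone check of that selection logic (the helper name and inputs are ours):

```python
def pick_compute_device(sorted_devices):
    # Mirrors the selection loop in the hunk above (helper name is ours).
    compute_device = "cuda:0"  # fallback when no non-cpu device is listed
    for dev in sorted_devices:
        if dev != "cpu":
            compute_device = dev
            break
    return compute_device

assert pick_compute_device(["cuda:1", "cpu"]) == "cuda:1"
assert pick_compute_device(["cpu"]) == "cuda:0"  # fallback path
```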
