
Commit afd8fec

Refactor DisTorch model patching logic for improved device assignment and FP8 casting
1 parent 4367c89

1 file changed: distorch_2.py

Lines changed: 46 additions & 88 deletions
@@ -83,94 +83,52 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     # Parse allocation string and apply static assignment
     device_assignments = analyze_safetensor_loading(self, allocations)

-    model_type = type(self.model).__name__
-
-    if model_type == "SDXL" or model_type == "LTXV":
-        on_compute_patching = False
-    else:
-        on_compute_patching = True
-
-    if on_compute_patching:
-        high_precision_loras = self.model._distorch_high_precision_loras
-        loading = self._load_list()
-        loading.sort(reverse=True)
-        for module_size, module_name, module_object, params in loading:
-            # Step 1: Write block/tensor to compute device first
-            module_object.to(device_to)
-
-            # Step 2: Apply LoRa patches while on compute device
-            weight_key = "{}.weight".format(module_name)
-            bias_key = "{}.bias".format(module_name)
-
-            if weight_key in self.patches:
-                self.patch_weight_to_device(weight_key, device_to=device_to)
-            if weight_key in self.weight_wrapper_patches:
-                module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
-
-            if bias_key in self.patches:
-                self.patch_weight_to_device(bias_key, device_to=device_to)
-            if bias_key in self.weight_wrapper_patches:
-                module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
-
-            # Step 3: FP8 casting for CPU storage (if enabled)
-            block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
-            has_patches = weight_key in self.patches or bias_key in self.patches
-
-            logger.info(f"[MultiGPU_DisTorch2] Patch-on-Compute: Processing {module_name} -> block_target_device={block_target_device}")
-
-            if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
-                for param_name, param in module_object.named_parameters():
-                    if param.dtype.is_floating_point:
-                        cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
-                        new_param = torch.nn.Parameter(cast_data.to(torch.float8_e4m3fn))
-                        new_param.requires_grad = param.requires_grad
-                        setattr(module_object, param_name, new_param)
-                        logger.debug(f"[MultiGPU_DisTorch2] Cast {module_name}.{param_name} to FP8 for CPU storage")
-
-            # Step 4: Move to ultimate destination based on DisTorch assignment
-            if block_target_device != device_to:
-                logger.debug(f"[MultiGPU_DisTorch2] Moving {module_name} from {device_to} to {block_target_device}")
-                module_object.to(block_target_device)
-
-            # Mark as patched and update memory counter
-            module_object.comfy_patched_weights = True
-            mem_counter += module_size
-
-        logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f}MB")
-
-    else:
-        # Apply our static assignments instead of ComfyUI's dynamic ones
-        for block_name, target_device in device_assignments['block_assignments'].items():
-            # Find the module by name
-            parts = block_name.split('.')
-            module = self.model
-            for part in parts:
-                if hasattr(module, part):
-                    module = getattr(module, part)
-                else:
-                    break
-
-            if hasattr(module, 'weight') or hasattr(module, 'comfy_cast_weights'):
-                # Move to our assigned device
-                logger.info(f"[MultiGPU_DisTorch2] Patch-on-Device: Moving {block_name} to {target_device}")
-                module.to(target_device)
-                # Mark for ComfyUI's cast system if not already marked
-                if hasattr(module, 'comfy_cast_weights'):
-                    module.comfy_cast_weights = True
-
-                weight_key = "{}.weight".format(block_name)
-                bias_key = "{}.bias".format(block_name)
-
-                if weight_key in self.patches:
-                    self.patch_weight_to_device(weight_key, device_to=target_device)
-                if weight_key in self.weight_wrapper_patches:
-                    module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
-
-                if bias_key in self.patches:
-                    self.patch_weight_to_device(bias_key, device_to=target_device)
-                if bias_key in self.weight_wrapper_patches:
-                    module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+    high_precision_loras = self.model._distorch_high_precision_loras
+    loading = self._load_list()
+    loading.sort(reverse=True)
+    for module_size, module_name, module_object, params in loading:
+        # Step 1: Write block/tensor to compute device first
+        module_object.to(device_to)
+
+        # Step 2: Apply LoRa patches while on compute device
+        weight_key = "{}.weight".format(module_name)
+        bias_key = "{}.bias".format(module_name)
+
+        if weight_key in self.patches:
+            self.patch_weight_to_device(weight_key, device_to=device_to)
+        if weight_key in self.weight_wrapper_patches:
+            module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+        if bias_key in self.patches:
+            self.patch_weight_to_device(bias_key, device_to=device_to)
+        if bias_key in self.weight_wrapper_patches:
+            module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+        # Step 3: FP8 casting for CPU storage (if enabled)
+        block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
+        has_patches = weight_key in self.patches or bias_key in self.patches
+
+        if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
+            for param_name, param in module_object.named_parameters():
+                if param.dtype.is_floating_point:
+                    cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
+                    new_param = torch.nn.Parameter(cast_data.to(torch.float8_e4m3fn))
+                    new_param.requires_grad = param.requires_grad
+                    setattr(module_object, param_name, new_param)
+                    logger.debug(f"[MultiGPU_DisTorch2] Cast {module_name}.{param_name} to FP8 for CPU storage")
+
+        # Step 4: Move to ultimate destination based on DisTorch assignment
+        if block_target_device != device_to:
+            logger.debug(f"[MultiGPU_DisTorch2] Moving {module_name} from {device_to} to {block_target_device}")
+            module_object.to(block_target_device)
+            module_object.comfy_cast_weights = True
+
+        # Mark as patched and update memory counter
+        module_object.comfy_patched_weights = True
+        mem_counter += module_size
+
+    logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f}MB")

     return 0

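The Step 3 block above re-registers a module's parameters as FP8 copies before they are parked on the CPU. Below is a minimal sketch of that re-registration pattern, assuming only PyTorch >= 2.1 (for the float8 dtypes): the helper name and toy Linear are illustrative, and plain round-to-nearest .to(...) stands in for comfy.float.stochastic_rounding, which instead randomises the rounding direction so the quantisation error stays zero-mean across parameters rather than biased toward the nearest representable value.

import torch

def cast_params_to_fp8(module: torch.nn.Module,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> None:
    # Re-register each floating-point parameter of this leaf module as an
    # FP8 copy, mirroring the setattr() swap in the diff. list() snapshots
    # the iterator so the module can be mutated while looping.
    for name, param in list(module.named_parameters(recurse=False)):
        if param.dtype.is_floating_point:
            new_param = torch.nn.Parameter(param.data.to(fp8_dtype))
            new_param.requires_grad = param.requires_grad
            setattr(module, name, new_param)

block = torch.nn.Linear(8, 8)
cast_params_to_fp8(block)
print(block.weight.dtype)  # torch.float8_e4m3fn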
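Step 4 then parks each patched block on the device chosen in device_assignments['block_assignments'] and flags it for ComfyUI's cast-on-use path. A hedged sketch of that pass follows: TinyModel, the assignment table, and the device strings are assumptions for illustration, and nn.Module.get_submodule stands in for the per-part getattr() walk of the deleted Patch-on-Device branch.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

model = TinyModel()
compute_device = "cpu"  # stand-in for the real compute GPU, e.g. "cuda:0"
block_assignments = {"blocks.0": "cpu", "blocks.1": compute_device}  # assumed table

for name, target in block_assignments.items():
    module = model.get_submodule(name)    # dotted-name lookup by the runtime API
    if target != compute_device:
        module.to(target)                 # park the block on its donor device
        module.comfy_cast_weights = True  # mark it for the cast system, as in the diff

One side effect of get_submodule worth noting: a bad block name raises AttributeError, whereas the deleted branch silently broke out of its getattr() walk and could move the wrong module.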