Commit 842ee65

Browse files
committed
Optimize DisTorchV2 loader and FP8 casting logic
- Remove redundant logging and counters in safetensor model patcher
- Add model original dtype detection for better precision handling
- Streamline FP8 casting conditions and remove verbose debug logs
- Improve static allocation parsing and device assignment flow
1 parent: afd8fec

1 file changed: distorch_2.py (1 addition, 11 deletions)
@@ -60,29 +60,20 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     """Override to use our static device assignments"""
     global safetensor_allocation_store
 
-    # Check if we have a device allocation for this model
     debug_hash = create_safetensor_model_hash(self, "partial_load")
     allocations = safetensor_allocation_store.get(debug_hash)
-
 
     if not hasattr(self.model, '_distorch_high_precision_loras') or not allocations:
         result = original_partially_load(self, device_to, extra_memory, force_patch_weights)
-        # Clean up
         if hasattr(self, '_distorch_block_assignments'):
             del self._distorch_block_assignments
-
-
         return result
 
-    logger.info(f"[MultiGPU_DisTorch2] DisTorchV2 Loader activated")
-
     mem_counter = 0
-    patch_counter = 0
 
     logger.info(f"[MultiGPU_DisTorch2] Using static allocation for model {debug_hash[:8]}")
-    # Parse allocation string and apply static assignment
     device_assignments = analyze_safetensor_loading(self, allocations)
-
+    model_original_dtype = comfy.utils.weight_dtype(self.model.state_dict())
     high_precision_loras = self.model._distorch_high_precision_loras
     loading = self._load_list()
     loading.sort(reverse=True)
@@ -109,7 +100,6 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
         has_patches = weight_key in self.patches or bias_key in self.patches
 
         if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
             for param_name, param in module_object.named_parameters():
                 if param.dtype.is_floating_point:
                     cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
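
Note on the new gate: comfy.utils.weight_dtype scans a state dict and reports its dominant weight dtype, so the added model_original_dtype check means the cast path only fires for checkpoints that natively shipped in FP8. A minimal sketch of that condition, assuming a ComfyUI environment; should_fp8_cast and its parameters are illustrative names, not part of this patch:

import torch
import comfy.utils  # assumes ComfyUI is importable

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def should_fp8_cast(patcher, block_target_device, has_patches):
    """Re-derive the diff's gate: cast only CPU-bound, LoRA-patched
    blocks of a model whose checkpoint is natively FP8."""
    # Dominant dtype across the state dict, computed once up front.
    model_original_dtype = comfy.utils.weight_dtype(patcher.model.state_dict())
    # Opt-out flag set by the DisTorch loader; default to keeping precision.
    high_precision_loras = getattr(patcher.model, '_distorch_high_precision_loras', True)
    return (not high_precision_loras
            and block_target_device == "cpu"
            and has_patches
            and model_original_dtype in FP8_DTYPES)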
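
The cast itself goes through comfy.float.stochastic_rounding rather than a plain .to(). As a self-contained illustration of why, here is a dither-based approximation of stochastic rounding to FP8 (this is not ComfyUI's implementation; stochastic_round_fp8 is a hypothetical name, and the ULP estimate is only approximate at exponent boundaries):

import torch

def stochastic_round_fp8(x, dtype=torch.float8_e4m3fn, generator=None):
    """Round x to an FP8 dtype without the bias of round-to-nearest.

    Adding uniform dither of one FP8 ULP before a nearest cast makes the
    expected result equal x, so tiny LoRA deltas survive on average
    instead of being systematically rounded away.
    """
    mantissa_bits = 3 if dtype == torch.float8_e4m3fn else 2  # e5m2: 2 bits
    x32 = x.float()
    # Spacing of the FP8 grid near x: 2**(exponent - mantissa_bits).
    exponent = torch.floor(torch.log2(x32.abs().clamp_min(1e-30)))
    ulp = torch.exp2(exponent - mantissa_bits)
    # Symmetric dither in [-ulp/2, +ulp/2), then an ordinary nearest cast.
    noise = (torch.rand(x32.shape, generator=generator, device=x32.device) - 0.5) * ulp
    return (x32 + noise).to(dtype)

# A plain cast discards a delta far below the FP8 ULP every time; the
# stochastic version preserves it in expectation:
w = torch.full((10000,), 1.0) + 1e-3
print(w.to(torch.float8_e4m3fn).float().mean())   # 1.0: delta vanished
print(stochastic_round_fp8(w).float().mean())     # ~1.001 on average

This is why the condition above matters: baking LoRA patches into weights that must stay FP8 on the CPU would otherwise quietly drop most of the patch contribution.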
