Skip to content

Commit fabbc9e

Browse files
committed
Merge branch 'lora_reapply_fix'
2 parents e9fb4a8 + 5b62671 commit fabbc9e

1 file changed

Lines changed: 30 additions & 2 deletions

File tree

distorch_2.py

Lines changed: 30 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,12 +67,22 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
6767
allocations = safetensor_allocation_store.get(debug_hash)
6868

6969
if not hasattr(self.model, '_distorch_high_precision_loras') or not allocations:
70+
7071
result = original_partially_load(self, device_to, extra_memory, force_patch_weights)
7172
if hasattr(self, '_distorch_block_assignments'):
7273
del self._distorch_block_assignments
7374
return result
74-
75-
# soft_empty_cache_multigpu(logger)
75+
76+
if not hasattr(self.model, 'current_weight_patches_uuid'):
77+
self.model.current_weight_patches_uuid = None
78+
79+
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
80+
81+
if unpatch_weights:
82+
logger.info(f"[MultiGPU_DisTorch2] Patches changed or forced. Unpatching model.")
83+
self.unpatch_model(self.offload_device, unpatch_weights=True)
84+
85+
self.patch_model(load_weights=False)
7686

7787
mem_counter = 0
7888

@@ -83,6 +93,22 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
8393
loading = self._load_list()
8494
loading.sort(reverse=True)
8595
for module_size, module_name, module_object, params in loading:
96+
if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True:
97+
block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
98+
current_module_device = None
99+
try:
100+
if any(p.numel() > 0 for p in module_object.parameters(recurse=False)):
101+
current_module_device = next(module_object.parameters(recurse=False)).device
102+
except StopIteration:
103+
pass
104+
105+
if current_module_device is not None and str(current_module_device) != str(block_target_device):
106+
logger.debug(f"[MultiGPU_DisTorch2] Moving already patched {module_name} to {block_target_device}")
107+
module_object.to(block_target_device)
108+
109+
mem_counter += module_size
110+
continue
111+
86112
# Step 1: Write block/tensor to compute device first
87113
module_object.to(device_to)
88114

@@ -123,6 +149,8 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
123149
module_object.comfy_patched_weights = True
124150
mem_counter += module_size
125151

152+
self.model.current_weight_patches_uuid = self.patches_uuid
153+
126154
logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f}MB")
127155

128156
return 0

0 commit comments

Comments (0)