@@ -67,12 +67,22 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
6767 allocations = safetensor_allocation_store .get (debug_hash )
6868
6969 if not hasattr (self .model , '_distorch_high_precision_loras' ) or not allocations :
70+
7071 result = original_partially_load (self , device_to , extra_memory , force_patch_weights )
7172 if hasattr (self , '_distorch_block_assignments' ):
7273 del self ._distorch_block_assignments
7374 return result
74-
75- # soft_empty_cache_multigpu(logger)
75+
76+ if not hasattr (self .model , 'current_weight_patches_uuid' ):
77+ self .model .current_weight_patches_uuid = None
78+
79+ unpatch_weights = self .model .current_weight_patches_uuid is not None and (self .model .current_weight_patches_uuid != self .patches_uuid or force_patch_weights )
80+
81+ if unpatch_weights :
82+ logger .info (f"[MultiGPU_DisTorch2] Patches changed or forced. Unpatching model." )
83+ self .unpatch_model (self .offload_device , unpatch_weights = True )
84+
85+ self .patch_model (load_weights = False )
7686
7787 mem_counter = 0
7888
@@ -83,6 +93,22 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
8393 loading = self ._load_list ()
8494 loading .sort (reverse = True )
8595 for module_size , module_name , module_object , params in loading :
96+ if not unpatch_weights and hasattr (module_object , "comfy_patched_weights" ) and module_object .comfy_patched_weights == True :
97+ block_target_device = device_assignments ['block_assignments' ].get (module_name , device_to )
98+ current_module_device = None
99+ try :
100+ if any (p .numel () > 0 for p in module_object .parameters (recurse = False )):
101+ current_module_device = next (module_object .parameters (recurse = False )).device
102+ except StopIteration :
103+ pass
104+
105+ if current_module_device is not None and str (current_module_device ) != str (block_target_device ):
106+ logger .debug (f"[MultiGPU_DisTorch2] Moving already patched { module_name } to { block_target_device } " )
107+ module_object .to (block_target_device )
108+
109+ mem_counter += module_size
110+ continue
111+
86112 # Step 1: Write block/tensor to compute device first
87113 module_object .to (device_to )
88114
@@ -123,6 +149,8 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
123149 module_object .comfy_patched_weights = True
124150 mem_counter += module_size
125151
152+ self .model .current_weight_patches_uuid = self .patches_uuid
153+
126154 logger .info (f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: { mem_counter / (1024 * 1024 ):.2f} MB" )
127155
128156 return 0
0 commit comments