@@ -63,86 +63,115 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
6363 # Check if we have a device allocation for this model
6464 debug_hash = create_safetensor_model_hash (self , "partial_load" )
6565 allocations = safetensor_allocation_store .get (debug_hash )
66+
6667
6768 if not hasattr (self .model , '_distorch_high_precision_loras' ) or not allocations :
68- result = original_partially_load (self , device_to , extra_memory , force_patch_weights )
69-
69+ result = original_partially_load (self , device_to , extra_memory , force_patch_weights )
7070 # Clean up
7171 if hasattr (self , '_distorch_block_assignments' ):
7272 del self ._distorch_block_assignments
73-
73+
74+
7475 return result
75-
76- logger .info (f"[MultiGPU_DisTorch2] high_precision_loras flag not retrieved from model. DisTorchV2 Loader not used. Reverting to normal loading behavior " )
76+
77+ logger .info (f"[MultiGPU_DisTorch2] DisTorchV2 Loader activated " )
7778
7879 mem_counter = 0
7980 patch_counter = 0
8081
81- loading = self ._load_list ()
82-
83- load_completely = []
84- loading .sort (reverse = True )
85- for x in loading :
86- n = x [1 ]
87- m = x [2 ]
88- params = x [3 ]
89- module_mem = x [0 ]
90-
91- weight_key = "{}.weight" .format (n )
92- bias_key = "{}.bias" .format (n )
93-
94- cast_weight = self .force_cast_weights
95-
96- if hasattr (m , "comfy_cast_weights" ):
97- #logging.info(f"Unpatching weight {weight_key} for Distorch2")
98- wipe_lowvram_weight (m )
99-
100- #logging.info(f"Adding {n} to 'load_completely' list")
101- mem_counter += module_mem
102- load_completely .append ((module_mem , n , m , params ))
103-
104- if cast_weight and hasattr (m , "comfy_cast_weights" ):
105- #logging.info(f"Setting cast weights for {weight_key}")
106- m .prev_comfy_cast_weights = m .comfy_cast_weights
107- m .comfy_cast_weights = True
108-
109- if weight_key in self .weight_wrapper_patches :
110- logging .info (f"Patching weight wrapper { m } for Distorch2" )
111- m .weight_function .extend (self .weight_wrapper_patches [weight_key ])
112-
113- if bias_key in self .weight_wrapper_patches :
114- logging .info (f"Patching bias wrapper { bias_key } for Distorch2" )
115- m .bias_function .extend (self .weight_wrapper_patches [bias_key ])
116-
117- mem_counter += move_weight_functions (m , device_to )
118-
119-
12082 logger .info (f"[MultiGPU_DisTorch2] Using static allocation for model { debug_hash [:8 ]} " )
12183 # Parse allocation string and apply static assignment
12284 device_assignments = analyze_safetensor_loading (self , allocations )
123-
124-
12585
126- # Apply our static assignments instead of ComfyUI's dynamic ones
127- for block_name , target_device in device_assignments ['block_assignments' ].items ():
128- # Find the module by name
129- parts = block_name .split ('.' )
130- module = self .model
131- for part in parts :
132- if hasattr (module , part ):
133- module = getattr (module , part )
134- else :
135- break
86+ model_type = type (self .model ).__name__
87+
88+ if model_type == "SDXL" or model_type == "LTXV" :
89+ on_compute_patching = False
90+ else :
91+ on_compute_patching = True
92+
93+ if on_compute_patching :
94+ high_precision_loras = self .model ._distorch_high_precision_loras
95+ loading = self ._load_list ()
96+ loading .sort (reverse = True )
97+ for module_size , module_name , module_object , params in loading :
98+ # Step 1: Write block/tensor to compute device first
99+ module_object .to (device_to )
100+
101+ # Step 2: Apply LoRa patches while on compute device
102+ weight_key = "{}.weight" .format (module_name )
103+ bias_key = "{}.bias" .format (module_name )
104+
105+ if weight_key in self .patches :
106+ self .patch_weight_to_device (weight_key , device_to = device_to )
107+ if weight_key in self .weight_wrapper_patches :
108+ module_object .weight_function .extend (self .weight_wrapper_patches [weight_key ])
109+
110+ if bias_key in self .patches :
111+ self .patch_weight_to_device (bias_key , device_to = device_to )
112+ if bias_key in self .weight_wrapper_patches :
113+ module_object .bias_function .extend (self .weight_wrapper_patches [bias_key ])
114+
115+ # Step 3: FP8 casting for CPU storage (if enabled)
116+ block_target_device = device_assignments ['block_assignments' ].get (module_name , device_to )
117+ has_patches = weight_key in self .patches or bias_key in self .patches
118+
119+ logger .info (f"[MultiGPU_DisTorch2] Patch-on-Compute: Processing { module_name } -> block_target_device={ block_target_device } " )
120+
121+ if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch .float8_e4m3fn , torch .float8_e5m2 ]:
122+ logger .info (f"[MultiGPU_DisTorch2] FP8 casting conditions met for { module_name } " )
123+ for param_name , param in module_object .named_parameters ():
124+ if param .dtype .is_floating_point :
125+ cast_data = comfy .float .stochastic_rounding (param .data , torch .float8_e4m3fn )
126+ new_param = torch .nn .Parameter (cast_data .to (torch .float8_e4m3fn ))
127+ new_param .requires_grad = param .requires_grad
128+ setattr (module_object , param_name , new_param )
129+ logger .debug (f"[MultiGPU_DisTorch2] Cast { module_name } .{ param_name } to FP8 for CPU storage" )
130+
131+ # Step 4: Move to ultimate destination based on DisTorch assignment
132+ if block_target_device != device_to :
133+ logger .debug (f"[MultiGPU_DisTorch2] Moving { module_name } from { device_to } to { block_target_device } " )
134+ module_object .to (block_target_device )
135+
136+ # Mark as patched and update memory counter
137+ module_object .comfy_patched_weights = True
138+ mem_counter += module_size
139+
140+ logger .info (f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: { mem_counter / (1024 * 1024 ):.2f} MB" )
141+
142+ else :
143+ # Apply our static assignments instead of ComfyUI's dynamic ones
144+ for block_name , target_device in device_assignments ['block_assignments' ].items ():
145+ # Find the module by name
146+ parts = block_name .split ('.' )
147+ module = self .model
148+ for part in parts :
149+ if hasattr (module , part ):
150+ module = getattr (module , part )
151+ else :
152+ break
153+
154+ if hasattr (module , 'weight' ) or hasattr (module , 'comfy_cast_weights' ):
155+ # Move to our assigned device
156+ logger .info (f"[MultiGPU_DisTorch2] Patch-on-Device: Moving { block_name } to { target_device } " )
157+ module .to (target_device )
158+ # Mark for ComfyUI's cast system if not already marked
159+ if hasattr (module , 'comfy_cast_weights' ):
160+ module .comfy_cast_weights = True
136161
137- if hasattr (module , 'weight' ) or hasattr (module , 'comfy_cast_weights' ):
138- # Move to our assigned device
139- logger .debug (f"[MultiGPU_DisTorch2] Moving { block_name } to { target_device } " )
140- module .to (target_device )
141- # Mark for ComfyUI's cast system if not already marked
142- if hasattr (module , 'comfy_cast_weights' ):
143- module .comfy_cast_weights = True
144-
145- # Return 0 to indicate no additional memory used on compute device
162+ weight_key = "{}.weight" .format (block_name )
163+ bias_key = "{}.bias" .format (block_name )
164+
165+ if weight_key in self .patches :
166+ self .patch_weight_to_device (weight_key , device_to = target_device )
167+ if weight_key in self .weight_wrapper_patches :
168+ module_object .weight_function .extend (self .weight_wrapper_patches [weight_key ])
169+
170+ if bias_key in self .patches :
171+ self .patch_weight_to_device (bias_key , device_to = target_device )
172+ if bias_key in self .weight_wrapper_patches :
173+ module_object .bias_function .extend (self .weight_wrapper_patches [bias_key ])
174+
146175 return 0
147176
148177
@@ -248,7 +277,12 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
248277 device_assignments = {device : [] for device in DEVICE_RATIOS_DISTORCH .keys ()}
249278 block_assignments = {}
250279
251- compute_device = str (current_device )
280+ # Determine the primary compute device (first non-cpu device)
281+ compute_device = "cuda:0" # Fallback
282+ for dev in sorted_devices :
283+ if dev != "cpu" :
284+ compute_device = dev
285+ break
252286
253287 # Calculate total memory to be offloaded to donor devices
254288 total_offload_gb = sum (DEVICE_RATIOS_DISTORCH .get (d , 0 ) for d in sorted_devices if d != compute_device )
0 commit comments