@@ -83,94 +83,52 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     # Parse allocation string and apply static assignment
     device_assignments = analyze_safetensor_loading(self, allocations)
 
-    model_type = type(self.model).__name__
-
-    if model_type == "SDXL" or model_type == "LTXV":
-        on_compute_patching = False
-    else:
-        on_compute_patching = True
-
-    if on_compute_patching:
-        high_precision_loras = self.model._distorch_high_precision_loras
-        loading = self._load_list()
-        loading.sort(reverse=True)
-        for module_size, module_name, module_object, params in loading:
-            # Step 1: Write block/tensor to compute device first
-            module_object.to(device_to)
-
-            # Step 2: Apply LoRa patches while on compute device
-            weight_key = "{}.weight".format(module_name)
-            bias_key = "{}.bias".format(module_name)
-
-            if weight_key in self.patches:
-                self.patch_weight_to_device(weight_key, device_to=device_to)
-            if weight_key in self.weight_wrapper_patches:
-                module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
-
-            if bias_key in self.patches:
-                self.patch_weight_to_device(bias_key, device_to=device_to)
-            if bias_key in self.weight_wrapper_patches:
-                module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
-
-            # Step 3: FP8 casting for CPU storage (if enabled)
-            block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
-            has_patches = weight_key in self.patches or bias_key in self.patches
-
-            logger.info(f"[MultiGPU_DisTorch2] Patch-on-Compute: Processing {module_name} -> block_target_device={block_target_device}")
-
-            if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
-                for param_name, param in module_object.named_parameters():
-                    if param.dtype.is_floating_point:
-                        cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
-                        new_param = torch.nn.Parameter(cast_data.to(torch.float8_e4m3fn))
-                        new_param.requires_grad = param.requires_grad
-                        setattr(module_object, param_name, new_param)
-                        logger.debug(f"[MultiGPU_DisTorch2] Cast {module_name}.{param_name} to FP8 for CPU storage")
-
-            # Step 4: Move to ultimate destination based on DisTorch assignment
-            if block_target_device != device_to:
-                logger.debug(f"[MultiGPU_DisTorch2] Moving {module_name} from {device_to} to {block_target_device}")
-                module_object.to(block_target_device)
-
-            # Mark as patched and update memory counter
-            module_object.comfy_patched_weights = True
-            mem_counter += module_size
-
-        logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f} MB")
-
-    else:
-        # Apply our static assignments instead of ComfyUI's dynamic ones
-        for block_name, target_device in device_assignments['block_assignments'].items():
-            # Find the module by name
-            parts = block_name.split('.')
-            module = self.model
-            for part in parts:
-                if hasattr(module, part):
-                    module = getattr(module, part)
-                else:
-                    break
-
-            if hasattr(module, 'weight') or hasattr(module, 'comfy_cast_weights'):
-                # Move to our assigned device
-                logger.info(f"[MultiGPU_DisTorch2] Patch-on-Device: Moving {block_name} to {target_device}")
-                module.to(target_device)
-                # Mark for ComfyUI's cast system if not already marked
-                if hasattr(module, 'comfy_cast_weights'):
-                    module.comfy_cast_weights = True
-
-                weight_key = "{}.weight".format(block_name)
-                bias_key = "{}.bias".format(block_name)
-
-                if weight_key in self.patches:
-                    self.patch_weight_to_device(weight_key, device_to=target_device)
-                if weight_key in self.weight_wrapper_patches:
-                    module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
-
-                if bias_key in self.patches:
-                    self.patch_weight_to_device(bias_key, device_to=target_device)
-                if bias_key in self.weight_wrapper_patches:
-                    module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+    high_precision_loras = self.model._distorch_high_precision_loras
+    loading = self._load_list()
+    loading.sort(reverse=True)
+    for module_size, module_name, module_object, params in loading:
+        # Step 1: Write block/tensor to compute device first
+        module_object.to(device_to)
+
+        # Step 2: Apply LoRa patches while on compute device
+        weight_key = "{}.weight".format(module_name)
+        bias_key = "{}.bias".format(module_name)
+
+        if weight_key in self.patches:
+            self.patch_weight_to_device(weight_key, device_to=device_to)
+        if weight_key in self.weight_wrapper_patches:
+            module_object.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+        if bias_key in self.patches:
+            self.patch_weight_to_device(bias_key, device_to=device_to)
+        if bias_key in self.weight_wrapper_patches:
+            module_object.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+        # Step 3: FP8 casting for CPU storage (if enabled)
+        block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
+        has_patches = weight_key in self.patches or bias_key in self.patches
+
+        if not high_precision_loras and block_target_device == "cpu" and has_patches and model_original_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            logger.info(f"[MultiGPU_DisTorch2] FP8 casting conditions met for {module_name}")
+            for param_name, param in module_object.named_parameters():
+                if param.dtype.is_floating_point:
+                    cast_data = comfy.float.stochastic_rounding(param.data, torch.float8_e4m3fn)
+                    new_param = torch.nn.Parameter(cast_data.to(torch.float8_e4m3fn))
+                    new_param.requires_grad = param.requires_grad
+                    setattr(module_object, param_name, new_param)
+                    logger.debug(f"[MultiGPU_DisTorch2] Cast {module_name}.{param_name} to FP8 for CPU storage")
+
+        # Step 4: Move to ultimate destination based on DisTorch assignment
+        if block_target_device != device_to:
+            logger.debug(f"[MultiGPU_DisTorch2] Moving {module_name} from {device_to} to {block_target_device}")
+            module_object.to(block_target_device)
+            module_object.comfy_cast_weights = True
+
+        # Mark as patched and update memory counter
+        module_object.comfy_patched_weights = True
+        mem_counter += module_size
+
+    logger.info(f"[MultiGPU_DisTorch2] DisTorch loading completed. Total memory: {mem_counter / (1024 * 1024):.2f} MB")
 
     return 0
 
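For reference, here is a minimal, self-contained sketch of the flow the unified loop above follows: stage a block on the compute device, apply patches there, optionally down-cast patched CPU-bound blocks to FP8, then park the block on its statically assigned device. Everything in it is illustrative: `distribute_blocks`, the `block_assignments` dict, and the `patches` callables are hypothetical stand-ins rather than DisTorch2 or ComfyUI APIs, and a plain dtype cast takes the place of `comfy.float.stochastic_rounding` (requires a PyTorch build with `torch.float8_e4m3fn`, i.e. 2.1+).

```python
import torch
import torch.nn as nn


def distribute_blocks(model, block_assignments, compute_device, patches, high_precision_loras=True):
    """Stage each top-level block on the compute device, patch it, optionally
    down-cast patched CPU-bound blocks to FP8, then move it to its assigned device."""
    for name, module in model.named_children():
        # Step 1: stage the block on the compute device
        module.to(compute_device)

        # Step 2: apply any patch while the weights sit on the compute device
        has_patches = name in patches
        if has_patches:
            patches[name](module)

        # Step 3: patched blocks headed for CPU storage may be stored in FP8
        # (plain cast here; the diff above uses stochastic rounding instead)
        target = block_assignments.get(name, compute_device)
        if not high_precision_loras and target == "cpu" and has_patches:
            for pname, param in list(module.named_parameters(recurse=False)):
                if param.dtype.is_floating_point:
                    fp8 = nn.Parameter(param.data.to(torch.float8_e4m3fn), requires_grad=False)
                    setattr(module, pname, fp8)

        # Step 4: move the block to its statically assigned device
        if target != compute_device:
            module.to(target)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).requires_grad_(False)
    assignments = {"0": "cpu", "1": "cpu"}               # every block offloaded in this toy run
    patches = {"0": lambda m: m.weight.data.mul_(1.01)}  # stand-in for a LoRA weight patch
    distribute_blocks(model, assignments, compute_device="cpu", patches=patches,
                      high_precision_loras=False)
    print({n: p.dtype for n, p in model.named_parameters()})
    # block "0" ends up stored in torch.float8_e4m3fn, block "1" stays float32
```

The reason for staging on the compute device first is that patching then happens at full precision on the fast device; any FP8 down-cast or CPU offload only occurs afterwards, once the patched weights are final.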