def patch_load_state_dict_guess_config():
    """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading.

    Idempotent: the saved original in ``original_load_state_dict_guess_config``
    doubles as the "already patched" sentinel, so repeated calls are no-ops and
    the stock loader is never overwritten twice.
    """
    global original_load_state_dict_guess_config

    if original_load_state_dict_guess_config is not None:
        logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
        return

    logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
    # Keep a reference to the stock loader so the patched version can delegate
    # to it when no MultiGPU/DisTorch2 device config is registered.
    original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
    comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
3333def patched_load_state_dict_guess_config (sd , output_vae = True , output_clip = True , output_clipvision = False ,
3434 embedding_directory = None , output_model = True , model_options = {},
35- te_model_options = {}, metadata = None ):
35+ te_model_options = {}, metadata = None , disable_dynamic = False ):
3636 """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
3737 from . import set_current_device , set_current_text_encoder_device , get_current_device , get_current_text_encoder_device
38-
38+
3939 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
4040 config_hash = str (sd_size )
4141 device_config = checkpoint_device_config .get (config_hash )
4242 distorch_config = checkpoint_distorch_config .get (config_hash )
4343
4444 if not device_config and not distorch_config :
45- return original_load_state_dict_guess_config (sd , output_vae , output_clip , output_clipvision , embedding_directory , output_model , model_options , te_model_options , metadata )
45+ return original_load_state_dict_guess_config (
46+ sd ,
47+ output_vae = output_vae ,
48+ output_clip = output_clip ,
49+ output_clipvision = output_clipvision ,
50+ embedding_directory = embedding_directory ,
51+ output_model = output_model ,
52+ model_options = model_options ,
53+ te_model_options = te_model_options ,
54+ metadata = metadata ,
55+ disable_dynamic = disable_dynamic ,
56+ )
4657
4758 logger .debug ("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader" )
4859 logger .debug (f"[MultiGPU Checkpoint] Received Device Config: { device_config } " )
@@ -53,7 +64,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5364 vae = None
5465 model = None
5566 model_patcher = None
56-
67+
5768 # Capture the current devices at runtime so we can restore them after loading
5869 original_main_device = get_current_device ()
5970 original_clip_device = get_current_text_encoder_device ()
@@ -68,12 +79,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
6879 sd , metadata = comfy .utils .convert_old_quants (sd , diffusion_model_prefix , metadata = metadata )
6980
7081 model_config = comfy .model_detection .model_config_from_unet (sd , diffusion_model_prefix , metadata = metadata )
71-
82+
7283 if model_config is None :
7384 logger .warning ("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only." )
7485 # Simplified fallback for non-checkpoints
7586 set_current_device (device_config .get ('unet_device' , original_main_device ))
76- diffusion_model = comfy .sd .load_diffusion_model_state_dict (sd , model_options = {})
87+ diffusion_model = comfy .sd .load_diffusion_model_state_dict (
88+ sd ,
89+ model_options = {},
90+ metadata = metadata ,
91+ disable_dynamic = disable_dynamic ,
92+ )
7793 if diffusion_model is None :
7894 return None
7995 return (diffusion_model , None , VAE (sd = {}), None )
@@ -83,18 +99,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
8399 unet_weight_dtype = list (model_config .supported_inference_dtypes )
84100 if model_config .scaled_fp8 is not None :
85101 weight_dtype = None
86-
102+
87103 if custom_operations is not None :
88104 model_config .custom_operations = custom_operations
89105 unet_dtype = model_options .get ("dtype" , model_options .get ("weight_dtype" , None ))
90106 if unet_dtype is None :
91107 unet_dtype = mm .unet_dtype (model_params = parameters , supported_dtypes = unet_weight_dtype , weight_dtype = weight_dtype )
92-
93- unet_compute_device = device_config .get ('unet_device' , original_main_device )
108+
109+ unet_compute_device = torch . device ( device_config .get ('unet_device' , original_main_device ) )
94110 if model_config .scaled_fp8 is not None :
95- manual_cast_dtype = mm .unet_manual_cast (None , torch . device ( unet_compute_device ) , model_config .supported_inference_dtypes )
111+ manual_cast_dtype = mm .unet_manual_cast (None , unet_compute_device , model_config .supported_inference_dtypes )
96112 else :
97- manual_cast_dtype = mm .unet_manual_cast (unet_dtype , torch . device ( unet_compute_device ) , model_config .supported_inference_dtypes )
113+ manual_cast_dtype = mm .unet_manual_cast (unet_dtype , unet_compute_device , model_config .supported_inference_dtypes )
98114 model_config .set_inference_dtype (unet_dtype , manual_cast_dtype )
99115 logger .info (f"UNet DType: { unet_dtype } , Manual Cast: { manual_cast_dtype } " )
100116
@@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
103119 clipvision = comfy .clip_vision .load_clipvision_from_sd (sd , model_config .clip_vision_prefix , True )
104120
105121 if output_model :
106- unet_compute_device = device_config .get ('unet_device' , original_main_device )
107- set_current_device (unet_compute_device )
122+ unet_compute_device = torch . device ( device_config .get ('unet_device' , original_main_device ) )
123+ set_current_device (unet_compute_device )
108124 inital_load_device = mm .unet_inital_load_device (parameters , unet_dtype )
109125
110126 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "pre-load" )
111127
112128 model = model_config .get_model (sd , diffusion_model_prefix , device = inital_load_device )
113- model .load_model_weights (sd , diffusion_model_prefix )
129+ model_patcher_class = comfy .model_patcher .ModelPatcher if disable_dynamic else comfy .model_patcher .CoreModelPatcher
130+ model_patcher = model_patcher_class (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
131+ model .load_model_weights (sd , diffusion_model_prefix , assign = model_patcher .is_dynamic ())
114132 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-weights" )
115133
116134 logger .mgpu_mm_log ("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup" )
117135 soft_empty_cache_multigpu ()
118- model_patcher = comfy .model_patcher .ModelPatcher (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
119136 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-model" )
120137
121138 if distorch_config and 'unet_allocation' in distorch_config :
@@ -131,7 +148,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
131148 vae_target_device = torch .device (device_config .get ('vae_device' , original_main_device ))
132149 set_current_device (vae_target_device ) # Use main device context for VAE
133150 multigpu_memory_log (f"vae:{ config_hash [:8 ]} " , "pre-load" )
134-
151+
135152 vae_sd = comfy .utils .state_dict_prefix_replace (sd , {k : "" for k in model_config .vae_key_prefix }, filter_keys = True )
136153 vae_sd = model_config .process_vae_state_dict (vae_sd )
137154 vae = VAE (sd = vae_sd , metadata = metadata )
@@ -151,17 +168,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
151168 for pref in scaled_fp8_list :
152169 skip = skip or k .startswith (pref )
153170 if not skip :
154- out_sd [k ] = sd [k ]
171+ out_sd [k ] = sd [k ]
155172
156173 for pref in scaled_fp8_list :
157174 quant_sd , qmetadata = comfy .utils .convert_old_quants (sd , pref , metadata = {})
158175 for k in quant_sd :
159176 out_sd [k ] = quant_sd [k ]
160177 sd = out_sd
161178
162- clip_target_device = device_config .get ('clip_device' , original_clip_device )
179+ clip_target_device = torch . device ( device_config .get ('clip_device' , original_clip_device ) )
163180 set_current_text_encoder_device (clip_target_device )
164-
181+
165182 clip_target = model_config .clip_target (state_dict = sd )
166183 if clip_target is not None :
167184 clip_sd = model_config .process_clip_state_dict (sd )
@@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
170187 multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "pre-load" )
171188 soft_empty_cache_multigpu ()
172189 clip_params = comfy .utils .calculate_parameters (clip_sd )
173- clip = CLIP (clip_target , embedding_directory = embedding_directory , tokenizer_data = clip_sd , parameters = clip_params , model_options = te_model_options )
190+ clip = CLIP (
191+ clip_target ,
192+ embedding_directory = embedding_directory ,
193+ tokenizer_data = clip_sd ,
194+ parameters = clip_params ,
195+ state_dict = clip_sd ,
196+ model_options = te_model_options ,
197+ disable_dynamic = disable_dynamic ,
198+ )
174199
175200 if distorch_config and 'clip_allocation' in distorch_config :
176201 clip_alloc = distorch_config ['clip_allocation' ]
@@ -181,16 +206,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
181206 logger .info (f"[CHECKPOINT_META] CLIP inner_model id=0x{ id (inner_clip ):x} " )
182207 clip .patcher .model ._distorch_high_precision_loras = distorch_config .get ('high_precision_loras' , True )
183208
184- m , u = clip .load_sd (clip_sd , full_model = True ) # This respects the patched text_encoder_device
185- if len (m ) > 0 : logger .warning (f"CLIP missing keys: { m } " )
186- if len (u ) > 0 : logger .debug (f"CLIP unexpected keys: { u } " )
187209 logger .info ("CLIP Loaded." )
188210 multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "post-load" )
189211 else :
190212 logger .warning ("No CLIP/text encoder weights in checkpoint." )
191213 else :
192214 logger .warning ("CLIP target not found in model config." )
193-
215+
194216 finally :
195217 set_current_device (original_main_device )
196218 set_current_text_encoder_device (original_clip_device )
@@ -206,7 +228,7 @@ def INPUT_TYPES(s):
206228 import folder_paths
207229 devices = get_device_list ()
208230 default_device = devices [1 ] if len (devices ) > 1 else devices [0 ]
209-
231+
210232 return {
211233 "required" : {
212234 "ckpt_name" : (folder_paths .get_filename_list ("checkpoints" ), ),
@@ -215,27 +237,27 @@ def INPUT_TYPES(s):
215237 "vae_device" : (devices , {"default" : default_device }),
216238 }
217239 }
218-
240+
219241 RETURN_TYPES = ("MODEL" , "CLIP" , "VAE" )
220242 FUNCTION = "load_checkpoint"
221243 CATEGORY = "multigpu"
222244 TITLE = "Checkpoint Loader Advanced (MultiGPU)"
223-
245+
224246 def load_checkpoint (self , ckpt_name , unet_device , clip_device , vae_device ):
225247 patch_load_state_dict_guess_config ()
226-
248+
227249 import folder_paths
228250 import comfy .utils
229-
251+
230252 ckpt_path = folder_paths .get_full_path ("checkpoints" , ckpt_name )
231253 sd = comfy .utils .load_torch_file (ckpt_path )
232254 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
233255 config_hash = str (sd_size )
234-
256+
235257 checkpoint_device_config [config_hash ] = {
236258 'unet_device' : unet_device , 'clip_device' : clip_device , 'vae_device' : vae_device
237259 }
238-
260+
239261 # Load using standard loader, our patch will intercept
240262 from nodes import CheckpointLoaderSimple
241263 return CheckpointLoaderSimple ().load_checkpoint (ckpt_name )
@@ -247,7 +269,7 @@ def INPUT_TYPES(s):
247269 import folder_paths
248270 devices = get_device_list ()
249271 compute_device = devices [1 ] if len (devices ) > 1 else devices [0 ]
250-
272+
251273 return {
252274 "required" : {
253275 "ckpt_name" : (folder_paths .get_filename_list ("checkpoints" ), ),
@@ -265,18 +287,18 @@ def INPUT_TYPES(s):
265287 "eject_models" : ("BOOLEAN" , {"default" : True }),
266288 }
267289 }
268-
290+
269291 RETURN_TYPES = ("MODEL" , "CLIP" , "VAE" )
270292 FUNCTION = "load_checkpoint"
271293 CATEGORY = "multigpu/distorch_2"
272294 TITLE = "Checkpoint Loader Advanced (DisTorch2)"
273-
295+
274296 def load_checkpoint (self , ckpt_name , unet_compute_device , unet_virtual_vram_gb , unet_donor_device ,
275297 clip_compute_device , clip_virtual_vram_gb , clip_donor_device , vae_device ,
276298 unet_expert_mode_allocations = "" , clip_expert_mode_allocations = "" , high_precision_loras = True , eject_models = True ):
277-
299+
278300 if eject_models :
279- logger .mgpu_mm_log (f "[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction" )
301+ logger .mgpu_mm_log ("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction" )
280302 ejection_count = 0
281303 for i , lm in enumerate (mm .current_loaded_models ):
282304 model_name = type (getattr (lm .model , 'model' , lm .model )).__name__ if lm .model else 'Unknown'
@@ -289,17 +311,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
289311 logger .mgpu_mm_log (f"[EJECT_MARKED] Model { i } : { model_name } (direct patcher) → marked for eviction" )
290312 ejection_count += 1
291313 logger .mgpu_mm_log (f"[EJECT_MODELS_SETUP_COMPLETE] Marked { ejection_count } models for Comfy Core eviction during load_models_gpu" )
292-
293- patch_load_state_dict_guess_config ()
294-
314+
315+ patch_load_state_dict_guess_config ()
316+
295317 import folder_paths
296318 import comfy .utils
297-
319+
298320 ckpt_path = folder_paths .get_full_path ("checkpoints" , ckpt_name )
299321 sd = comfy .utils .load_torch_file (ckpt_path )
300322 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
301323 config_hash = str (sd_size )
302-
324+
303325 checkpoint_device_config [config_hash ] = {
304326 'unet_device' : unet_compute_device ,
305327 'clip_device' : clip_compute_device ,
@@ -312,7 +334,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
312334 elif unet_expert_mode_allocations :
313335 unet_vram_str = unet_compute_device
314336 unet_alloc = f"{ unet_expert_mode_allocations } #{ unet_vram_str } " if unet_expert_mode_allocations or unet_vram_str else ""
315-
337+
316338 clip_vram_str = ""
317339 if clip_virtual_vram_gb > 0 :
318340 clip_vram_str = f"{ clip_compute_device } ;{ clip_virtual_vram_gb } ;{ clip_donor_device } "
@@ -327,6 +349,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
327349 'unet_settings' : hashlib .sha256 (f"{ unet_alloc } { high_precision_loras } " .encode ()).hexdigest (),
328350 'clip_settings' : hashlib .sha256 (f"{ clip_alloc } { high_precision_loras } " .encode ()).hexdigest (),
329351 }
330-
352+
331353 from nodes import CheckpointLoaderSimple
332354 return CheckpointLoaderSimple ().load_checkpoint (ckpt_name )