@@ -32,7 +32,7 @@ def patch_load_state_dict_guess_config():
3232
3333def patched_load_state_dict_guess_config (sd , output_vae = True , output_clip = True , output_clipvision = False ,
3434 embedding_directory = None , output_model = True , model_options = {},
35- te_model_options = {}, metadata = None ):
35+ te_model_options = {}, metadata = None , disable_dynamic = False ):
3636 """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
3737 from . import set_current_device , set_current_text_encoder_device , get_current_device , get_current_text_encoder_device
3838
@@ -42,7 +42,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
4242 distorch_config = checkpoint_distorch_config .get (config_hash )
4343
4444 if not device_config and not distorch_config :
45- return original_load_state_dict_guess_config (sd , output_vae , output_clip , output_clipvision , embedding_directory , output_model , model_options , te_model_options , metadata )
45+ return original_load_state_dict_guess_config (
46+ sd ,
47+ output_vae = output_vae ,
48+ output_clip = output_clip ,
49+ output_clipvision = output_clipvision ,
50+ embedding_directory = embedding_directory ,
51+ output_model = output_model ,
52+ model_options = model_options ,
53+ te_model_options = te_model_options ,
54+ metadata = metadata ,
55+ disable_dynamic = disable_dynamic ,
56+ )
4657
4758 logger .debug ("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader" )
4859 logger .debug (f"[MultiGPU Checkpoint] Received Device Config: { device_config } " )
@@ -73,7 +84,12 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
7384 logger .warning ("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only." )
7485 # Simplified fallback for non-checkpoints
7586 set_current_device (device_config .get ('unet_device' , original_main_device ))
76- diffusion_model = comfy .sd .load_diffusion_model_state_dict (sd , model_options = {})
87+ diffusion_model = comfy .sd .load_diffusion_model_state_dict (
88+ sd ,
89+ model_options = {},
90+ metadata = metadata ,
91+ disable_dynamic = disable_dynamic ,
92+ )
7793 if diffusion_model is None :
7894 return None
7995 return (diffusion_model , None , VAE (sd = {}), None )
@@ -90,11 +106,11 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
90106 if unet_dtype is None :
91107 unet_dtype = mm .unet_dtype (model_params = parameters , supported_dtypes = unet_weight_dtype , weight_dtype = weight_dtype )
92108
93- unet_compute_device = device_config .get ('unet_device' , original_main_device )
109+ unet_compute_device = torch . device ( device_config .get ('unet_device' , original_main_device ) )
94110 if model_config .scaled_fp8 is not None :
95- manual_cast_dtype = mm .unet_manual_cast (None , torch . device ( unet_compute_device ) , model_config .supported_inference_dtypes )
111+ manual_cast_dtype = mm .unet_manual_cast (None , unet_compute_device , model_config .supported_inference_dtypes )
96112 else :
97- manual_cast_dtype = mm .unet_manual_cast (unet_dtype , torch . device ( unet_compute_device ) , model_config .supported_inference_dtypes )
113+ manual_cast_dtype = mm .unet_manual_cast (unet_dtype , unet_compute_device , model_config .supported_inference_dtypes )
98114 model_config .set_inference_dtype (unet_dtype , manual_cast_dtype )
99115 logger .info (f"UNet DType: { unet_dtype } , Manual Cast: { manual_cast_dtype } " )
100116
@@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
103119 clipvision = comfy .clip_vision .load_clipvision_from_sd (sd , model_config .clip_vision_prefix , True )
104120
105121 if output_model :
106- unet_compute_device = device_config .get ('unet_device' , original_main_device )
122+ unet_compute_device = torch . device ( device_config .get ('unet_device' , original_main_device ) )
107123 set_current_device (unet_compute_device )
108124 inital_load_device = mm .unet_inital_load_device (parameters , unet_dtype )
109125
110126 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "pre-load" )
111127
112128 model = model_config .get_model (sd , diffusion_model_prefix , device = inital_load_device )
113- model .load_model_weights (sd , diffusion_model_prefix )
129+ model_patcher_class = comfy .model_patcher .ModelPatcher if disable_dynamic else comfy .model_patcher .CoreModelPatcher
130+ model_patcher = model_patcher_class (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
131+ model .load_model_weights (sd , diffusion_model_prefix , assign = model_patcher .is_dynamic ())
114132 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-weights" )
115133
116134 logger .mgpu_mm_log ("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup" )
117135 soft_empty_cache_multigpu ()
118- model_patcher = comfy .model_patcher .ModelPatcher (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
119136 multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-model" )
120137
121138 if distorch_config and 'unet_allocation' in distorch_config :
@@ -159,7 +176,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
159176 out_sd [k ] = quant_sd [k ]
160177 sd = out_sd
161178
162- clip_target_device = device_config .get ('clip_device' , original_clip_device )
179+ clip_target_device = torch . device ( device_config .get ('clip_device' , original_clip_device ) )
163180 set_current_text_encoder_device (clip_target_device )
164181
165182 clip_target = model_config .clip_target (state_dict = sd )
@@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
170187 multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "pre-load" )
171188 soft_empty_cache_multigpu ()
172189 clip_params = comfy .utils .calculate_parameters (clip_sd )
173- clip = CLIP (clip_target , embedding_directory = embedding_directory , tokenizer_data = clip_sd , parameters = clip_params , model_options = te_model_options )
190+ clip = CLIP (
191+ clip_target ,
192+ embedding_directory = embedding_directory ,
193+ tokenizer_data = clip_sd ,
194+ parameters = clip_params ,
195+ state_dict = clip_sd ,
196+ model_options = te_model_options ,
197+ disable_dynamic = disable_dynamic ,
198+ )
174199
175200 if distorch_config and 'clip_allocation' in distorch_config :
176201 clip_alloc = distorch_config ['clip_allocation' ]
@@ -181,11 +206,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
181206 logger .info (f"[CHECKPOINT_META] CLIP inner_model id=0x{ id (inner_clip ):x} " )
182207 clip .patcher .model ._distorch_high_precision_loras = distorch_config .get ('high_precision_loras' , True )
183208
184- m , u = clip .load_sd (clip_sd , full_model = True ) # This respects the patched text_encoder_device
185- if len (m ) > 0 :
186- logger .warning (f"CLIP missing keys: { m } " )
187- if len (u ) > 0 :
188- logger .debug (f"CLIP unexpected keys: { u } " )
189209 logger .info ("CLIP Loaded." )
190210 multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "post-load" )
191211 else :
0 commit comments