@@ -21,11 +21,11 @@
 def patch_load_state_dict_guess_config():
     """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading."""
     global original_load_state_dict_guess_config
-
+
     if original_load_state_dict_guess_config is not None:
         logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
         return
-
+
     logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
     original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
     comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
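The guard above makes the patch idempotent: the original function is stashed once in a module-level global, so running a loader node repeatedly never wraps the wrapper. A minimal sketch of the reverse operation, assuming the same global; no unpatch helper appears in this diff, so the name below is hypothetical:

    def unpatch_load_state_dict_guess_config():
        # Hypothetical helper: restore ComfyUI's stock loader.
        global original_load_state_dict_guess_config
        if original_load_state_dict_guess_config is not None:
            comfy.sd.load_state_dict_guess_config = original_load_state_dict_guess_config
            original_load_state_dict_guess_config = None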
@@ -35,7 +35,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                                          te_model_options={}, metadata=None):
     """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
     from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
-
+
     sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
     config_hash = str(sd_size)
     device_config = checkpoint_device_config.get(config_hash)
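This lookup is the rendezvous point with the loader nodes further down: each node computes the same key (the total element count of the state dict, as a string) and stores its device map in checkpoint_device_config before delegating to the stock loader. A sketch of the handshake with illustrative devices; note that two different checkpoints with an identical total parameter count would collide on this key:

    # Node side, before calling CheckpointLoaderSimple (illustrative values):
    checkpoint_device_config[config_hash] = {
        'unet_device': 'cuda:0', 'clip_device': 'cuda:1', 'vae_device': 'cpu'}
    # Patched-loader side, as above: recompute the key from sd and fetch the map.
    device_config = checkpoint_device_config.get(config_hash)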
@@ -53,7 +53,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     vae = None
     model = None
     model_patcher = None
-
+
     # Capture the current devices at runtime so we can restore them after loading
     original_main_device = get_current_device()
     original_clip_device = get_current_text_encoder_device()
@@ -68,7 +68,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
         sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
 
         model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
-
+
         if model_config is None:
             logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.")
             # Simplified fallback for non-checkpoints
@@ -83,13 +83,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
         unet_weight_dtype = list(model_config.supported_inference_dtypes)
         if model_config.scaled_fp8 is not None:
             weight_dtype = None
-
+
         if custom_operations is not None:
             model_config.custom_operations = custom_operations
         unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
         if unet_dtype is None:
             unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
-
+
         unet_compute_device = device_config.get('unet_device', original_main_device)
         if model_config.scaled_fp8 is not None:
             manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
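The manual-cast decision depends on the capabilities of the resolved compute device, not on the device ComfyUI would pick by default, which is why the stored unet_device is read back out of device_config before the call. A rough illustration of that dependency, reusing the call shape above with made-up devices:

    # Illustrative only: the same checkpoint can resolve to different cast
    # behavior depending on which device it is placed on.
    for dev in ('cuda:0', 'cpu'):
        cast = mm.unet_manual_cast(None, torch.device(dev), model_config.supported_inference_dtypes)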
@@ -104,7 +104,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
 
         if output_model:
             unet_compute_device = device_config.get('unet_device', original_main_device)
-            set_current_device(unet_compute_device)
+            set_current_device(unet_compute_device)
             inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
 
             multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
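Each component load is bracketed by multigpu_memory_log calls tagged with the first eight characters of the config hash, so per-device memory deltas from interleaved loads can be told apart in the log. The pairing as it appears for the UNet (the post-load call sits in unchanged context below this hunk, so its exact tag is assumed from the CLIP counterpart shown later):

    multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
    # ... UNet weights are loaded here ...
    multigpu_memory_log(f"unet:{config_hash[:8]}", "post-load")  # assumed counterpart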
@@ -131,7 +131,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
             vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
             set_current_device(vae_target_device)  # Use main device context for VAE
             multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load")
-
+
             vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
             vae_sd = model_config.process_vae_state_dict(vae_sd)
             vae = VAE(sd=vae_sd, metadata=metadata)
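The VAE is constructed without an explicit device, so it falls back to the device-resolution path this extension patches; per the comment above, the code therefore points the main device context at vae_device just for this call. The pattern in isolation, using the setter imported at the top of the function:

    set_current_device(vae_target_device)   # VAE() will now resolve to this device
    vae = VAE(sd=vae_sd, metadata=metadata)
    # The original device is restored later, in the function's finally block.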
@@ -151,7 +151,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                 for pref in scaled_fp8_list:
                     skip = skip or k.startswith(pref)
                 if not skip:
-                    out_sd[k] = sd[k]
+                    out_sd[k] = sd[k]
 
             for pref in scaled_fp8_list:
                 quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
@@ -161,7 +161,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
 
             clip_target_device = device_config.get('clip_device', original_clip_device)
             set_current_text_encoder_device(clip_target_device)
-
+
             clip_target = model_config.clip_target(state_dict=sd)
             if clip_target is not None:
                 clip_sd = model_config.process_clip_state_dict(sd)
@@ -182,15 +182,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                         clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
 
                     m, u = clip.load_sd(clip_sd, full_model=True)  # This respects the patched text_encoder_device
-                    if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")
-                    if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}")
+                    if len(m) > 0:
+                        logger.warning(f"CLIP missing keys: {m}")
+                    if len(u) > 0:
+                        logger.debug(f"CLIP unexpected keys: {u}")
                     logger.info("CLIP Loaded.")
                     multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
                 else:
                     logger.warning("No CLIP/text encoder weights in checkpoint.")
             else:
                 logger.warning("CLIP target not found in model config.")
-
+
     finally:
         set_current_device(original_main_device)
         set_current_text_encoder_device(original_clip_device)
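The try/finally shape is what keeps the patch safe under failure: whatever happens while loading the UNet, VAE, or CLIP, the device contexts captured at entry are always restored, so a failed load cannot leak a device override into later nodes. The skeleton of the control flow above, condensed:

    original_main_device = get_current_device()
    original_clip_device = get_current_text_encoder_device()
    try:
        pass  # UNet, VAE, and CLIP each load under their own device context
    finally:
        set_current_device(original_main_device)
        set_current_text_encoder_device(original_clip_device)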
@@ -206,7 +208,7 @@ def INPUT_TYPES(s):
         import folder_paths
         devices = get_device_list()
         default_device = devices[1] if len(devices) > 1 else devices[0]
-
+
         return {
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -215,27 +217,27 @@ def INPUT_TYPES(s):
215217 "vae_device" : (devices , {"default" : default_device }),
216218 }
217219 }
218-
220+
219221 RETURN_TYPES = ("MODEL" , "CLIP" , "VAE" )
220222 FUNCTION = "load_checkpoint"
221223 CATEGORY = "multigpu"
222224 TITLE = "Checkpoint Loader Advanced (MultiGPU)"
223-
225+
224226 def load_checkpoint (self , ckpt_name , unet_device , clip_device , vae_device ):
225227 patch_load_state_dict_guess_config ()
226-
228+
227229 import folder_paths
228230 import comfy .utils
229-
231+
230232 ckpt_path = folder_paths .get_full_path ("checkpoints" , ckpt_name )
231233 sd = comfy .utils .load_torch_file (ckpt_path )
232234 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
233235 config_hash = str (sd_size )
234-
236+
235237 checkpoint_device_config [config_hash ] = {
236238 'unet_device' : unet_device , 'clip_device' : clip_device , 'vae_device' : vae_device
237239 }
238-
240+
239241 # Load using standard loader, our patch will intercept
240242 from nodes import CheckpointLoaderSimple
241243 return CheckpointLoaderSimple ().load_checkpoint (ckpt_name )
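End to end, the node is a thin front: it records the device map under the parameter-count key, then lets the stock CheckpointLoaderSimple run so the patched loader applies the map. A hypothetical direct call outside the graph UI (the class name and filename here are assumed, not taken from this diff):

    loader = CheckpointLoaderAdvancedMultiGPU()   # hypothetical class name
    model, clip, vae = loader.load_checkpoint(
        'sd_xl_base_1.0.safetensors',             # assumed checkpoint filename
        unet_device='cuda:0', clip_device='cuda:1', vae_device='cpu')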
@@ -247,7 +249,7 @@ def INPUT_TYPES(s):
         import folder_paths
         devices = get_device_list()
         compute_device = devices[1] if len(devices) > 1 else devices[0]
-
+
         return {
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -265,18 +267,18 @@ def INPUT_TYPES(s):
265267 "eject_models" : ("BOOLEAN" , {"default" : True }),
266268 }
267269 }
268-
270+
269271 RETURN_TYPES = ("MODEL" , "CLIP" , "VAE" )
270272 FUNCTION = "load_checkpoint"
271273 CATEGORY = "multigpu/distorch_2"
272274 TITLE = "Checkpoint Loader Advanced (DisTorch2)"
273-
275+
274276 def load_checkpoint (self , ckpt_name , unet_compute_device , unet_virtual_vram_gb , unet_donor_device ,
275277 clip_compute_device , clip_virtual_vram_gb , clip_donor_device , vae_device ,
276278 unet_expert_mode_allocations = "" , clip_expert_mode_allocations = "" , high_precision_loras = True , eject_models = True ):
277-
279+
278280 if eject_models :
279- logger .mgpu_mm_log (f "[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction" )
281+ logger .mgpu_mm_log ("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction" )
280282 ejection_count = 0
281283 for i , lm in enumerate (mm .current_loaded_models ):
282284 model_name = type (getattr (lm .model , 'model' , lm .model )).__name__ if lm .model else 'Unknown'
@@ -289,17 +291,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
                     logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
                     ejection_count += 1
             logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
-
-        patch_load_state_dict_guess_config()
-
+
+        patch_load_state_dict_guess_config()
+
         import folder_paths
         import comfy.utils
-
+
         ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
         sd = comfy.utils.load_torch_file(ckpt_path)
         sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
         config_hash = str(sd_size)
-
+
         checkpoint_device_config[config_hash] = {
             'unet_device': unet_compute_device,
             'clip_device': clip_compute_device,
@@ -312,7 +314,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
         elif unet_expert_mode_allocations:
             unet_vram_str = unet_compute_device
         unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else ""
-
+
         clip_vram_str = ""
         if clip_virtual_vram_gb > 0:
             clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
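Both virtual-VRAM strings use the same compact grammar, 'compute_device;virtual_vram_gb;donor_device', and unet_alloc above shows how expert-mode allocations are joined on with '#'. A worked example with illustrative values:

    clip_compute_device, clip_virtual_vram_gb, clip_donor_device = 'cuda:1', 4.0, 'cpu'
    clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
    # -> 'cuda:1;4.0;cpu'; with expert allocations this becomes '<expert>#cuda:1;4.0;cpu'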
@@ -327,6 +329,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
             'unet_settings': hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest(),
             'clip_settings': hashlib.sha256(f"{clip_alloc}{high_precision_loras}".encode()).hexdigest(),
         }
-
+
         from nodes import CheckpointLoaderSimple
         return CheckpointLoaderSimple().load_checkpoint(ckpt_name)
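The sha256 digests fingerprint each component's DisTorch settings (allocation string plus the precision flag), which gives the patched loader a cheap equality check for whether a cached model was loaded under the same configuration; the consuming side is outside these hunks, so that purpose is inferred. The digest construction in isolation:

    import hashlib
    unet_alloc, high_precision_loras = 'cuda:0;8.0;cpu', True  # illustrative values
    digest = hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest()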