1- """
2- Advanced Checkpoint Loaders for MultiGPU
3- Provides device-specific and DisTorch2 sharding for checkpoint components
4- """
5-
61import torch
72import logging
83import hashlib
138import comfy .clip_vision
149from comfy .sd import VAE , CLIP
1510from .device_utils import get_device_list , soft_empty_cache_multigpu
11+ from .model_management_mgpu import multigpu_memory_log
1612from .distorch_2 import safetensor_allocation_store , safetensor_settings_store , create_safetensor_model_hash , register_patched_safetensor_modelpatcher
1713
1814logger = logging .getLogger ("MultiGPU" )
2319original_load_state_dict_guess_config = None
2420
def patch_load_state_dict_guess_config():
    """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading.

    Saves the stock implementation into the module-level
    ``original_load_state_dict_guess_config`` so the patched loader can
    delegate to it, then swaps in ``patched_load_state_dict_guess_config``.
    Safe to call more than once: repeat calls are no-ops.
    """
    global original_load_state_dict_guess_config

    # Idempotency guard: if the original has already been captured, patching
    # again would save the *patched* function as the "original" and cause
    # infinite recursion when the patched loader delegates to it.
    already_patched = original_load_state_dict_guess_config is not None
    if already_patched:
        logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
        return

    logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
    original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
    comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
3932
4033def patched_load_state_dict_guess_config (sd , output_vae = True , output_clip = True , output_clipvision = False ,
4134 embedding_directory = None , output_model = True , model_options = {},
4235 te_model_options = {}, metadata = None ):
43-
36+ """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
4437 from . import set_current_device , set_current_text_encoder_device , current_device , current_text_encoder_device
4538
4639 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
@@ -51,9 +44,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5144 if not device_config and not distorch_config :
5245 return original_load_state_dict_guess_config (sd , output_vae , output_clip , output_clipvision , embedding_directory , output_model , model_options , te_model_options , metadata )
5346
54- logger .info ( "--- [MultiGPU] ENTERING Patched Checkpoint Loader --- " )
55- logger .info (f"Received Device Config: { device_config } " )
56- logger .info (f"Received DisTorch2 Config: { distorch_config } " )
47+ logger .debug ( " [MultiGPU Checkpoint ] ENTERING Patched Checkpoint Loader" )
48+ logger .debug (f"[MultiGPU Checkpoint] Received Device Config: { device_config } " )
49+ logger .debug (f"[MultiGPU Checkpoint] Received DisTorch2 Config: { distorch_config } " )
5750
5851 clip = None
5952 clipvision = None
@@ -63,7 +56,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
6356
6457 original_main_device = current_device
6558 original_clip_device = current_text_encoder_device
66- logger .info (f"Saved original device contexts: UNet/VAE='{ original_main_device } ', CLIP='{ original_clip_device } '" )
6759
6860 try :
6961 diffusion_model_prefix = comfy .model_detection .unet_prefix_from_state_dict (sd )
@@ -80,7 +72,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
8072 return None
8173 return (diffusion_model , None , VAE (sd = {}), None )
8274
83- logger .info (f"[MultiGPU] Detected Model Config: { type (model_config ).__name__ } , Parameters: { parameters / 10 ** 9 :.2f} B" )
75+ logger .debug (f"[MultiGPU] Detected Model Config: { type (model_config ).__name__ } , Parameters: { parameters / 10 ** 9 :.2f} B" )
8476
8577 unet_weight_dtype = list (model_config .supported_inference_dtypes )
8678 if model_config .scaled_fp8 is not None :
@@ -105,10 +97,14 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
10597 set_current_device (unet_compute_device )
10698 inital_load_device = mm .unet_inital_load_device (parameters , unet_dtype )
10799
100+ multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "pre-load" )
101+
108102 model = model_config .get_model (sd , diffusion_model_prefix , device = inital_load_device )
109103
110- soft_empty_cache_multigpu (logger )
104+ logger .mgpu_mm_log ("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup" )
105+ soft_empty_cache_multigpu ()
111106 model_patcher = comfy .model_patcher .ModelPatcher (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
107+ multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-model" )
112108
113109 if distorch_config and 'unet_allocation' in distorch_config :
114110 register_patched_safetensor_modelpatcher ()
@@ -117,17 +113,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
117113 safetensor_settings_store [model_hash ] = distorch_config .get ('unet_settings' ,'' )
118114 model .is_distorch = True
119115 model ._distorch_high_precision_loras = distorch_config .get ('high_precision_loras' , True )
120- logger .info (f"Stored DisTorch2 config for UNet (hash { model_hash [:8 ]} ): { distorch_config ['unet_allocation' ]} " )
116+ logger .mgpu_mm_log (f"Stored DisTorch2 config for UNet (hash { model_hash [:8 ]} ): { distorch_config ['unet_allocation' ]} " )
121117
122118 model .load_model_weights (sd , diffusion_model_prefix )
119+ multigpu_memory_log (f"unet:{ config_hash [:8 ]} " , "post-weights" )
123120
124121 if output_vae :
125122 vae_target_device = torch .device (device_config .get ('vae_device' , original_main_device ))
126123 set_current_device (vae_target_device ) # Use main device context for VAE
124+ multigpu_memory_log (f"vae:{ config_hash [:8 ]} " , "pre-load" )
127125
128126 vae_sd = comfy .utils .state_dict_prefix_replace (sd , {k : "" for k in model_config .vae_key_prefix }, filter_keys = True )
129127 vae_sd = model_config .process_vae_state_dict (vae_sd )
130128 vae = VAE (sd = vae_sd , metadata = metadata )
129+ multigpu_memory_log (f"vae:{ config_hash [:8 ]} " , "post-load" )
131130
132131 if output_clip :
133132 clip_target_device = device_config .get ('clip_device' , original_clip_device )
@@ -137,7 +136,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
137136 if clip_target is not None :
138137 clip_sd = model_config .process_clip_state_dict (sd )
139138 if len (clip_sd ) > 0 :
140- soft_empty_cache_multigpu (logger )
139+ logger .debug ("[MultiGPU Checkpoint] Invoking soft_empty_cache_multigpu before CLIP construction" )
140+ multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "pre-load" )
141+ soft_empty_cache_multigpu ()
141142 clip_params = comfy .utils .calculate_parameters (clip_sd )
142143 clip = CLIP (clip_target , embedding_directory = embedding_directory , tokenizer_data = clip_sd , parameters = clip_params , model_options = te_model_options )
143144
@@ -155,22 +156,19 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
155156 if len (m ) > 0 : logger .warning (f"CLIP missing keys: { m } " )
156157 if len (u ) > 0 : logger .debug (f"CLIP unexpected keys: { u } " )
157158 logger .info ("CLIP Loaded." )
159+ multigpu_memory_log (f"clip:{ config_hash [:8 ]} " , "post-load" )
158160 else :
159161 logger .warning ("No CLIP/text encoder weights in checkpoint." )
160162 else :
161163 logger .warning ("CLIP target not found in model config." )
162164
163165 finally :
164- # --- Restore original device contexts and clean up ---
165166 set_current_device (original_main_device )
166167 set_current_text_encoder_device (original_clip_device )
167168 if config_hash in checkpoint_device_config :
168169 del checkpoint_device_config [config_hash ]
169170 if config_hash in checkpoint_distorch_config :
170171 del checkpoint_distorch_config [config_hash ]
171- logger .info (f"Restored original device contexts. UNet/VAE='{ original_main_device } ', CLIP='{ original_clip_device } '" )
172- logger .info ("--- [MultiGPU] EXITING Patched Checkpoint Loader ---" )
173-
174172 return (model_patcher , clip , vae , clipvision )
175173
176174class CheckpointLoaderAdvancedMultiGPU :
0 commit comments