1717
# Module logger, registered under the fixed name "MultiGPU".
1818logger = logging .getLogger ("MultiGPU" )
1919
20- # --- Global Stores for Configuration ---
# Module-level lookup tables read by the patched checkpoint loader. Both are
# queried with the same string key: str() of the summed numel() over the
# state-dict values. A missing entry yields None via .get().
# NOTE(review): the code that inserts entries is not visible in this excerpt;
# an element-count key is presumably shared by any two checkpoints with the
# same total parameter count -- confirm that collision is acceptable.
2120checkpoint_device_config = {}
2221checkpoint_distorch_config = {}
2322
24- # --- Original Function Store ---
# Reference to the unpatched loader. The patched loader calls it directly
# when neither lookup table has an entry for the checkpoint. It starts as
# None; presumably patch_load_state_dict_guess_config() assigns it before the
# patch becomes active -- TODO confirm, the assignment is not shown here.
2523original_load_state_dict_guess_config = None
2624
2725def patch_load_state_dict_guess_config ():
@@ -45,34 +43,29 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
4543
4644 from . import set_current_device , set_current_text_encoder_device , current_device , current_text_encoder_device
4745
48- # --- Check for custom configuration ---
4946 sd_size = sum (p .numel () for p in sd .values () if hasattr (p , 'numel' ))
5047 config_hash = str (sd_size )
5148 device_config = checkpoint_device_config .get (config_hash )
5249 distorch_config = checkpoint_distorch_config .get (config_hash )
5350
5451 if not device_config and not distorch_config :
55- # No config, fall back to original untouched function
5652 return original_load_state_dict_guess_config (sd , output_vae , output_clip , output_clipvision , embedding_directory , output_model , model_options , te_model_options , metadata )
5753
5854 logger .info ("--- [MultiGPU] ENTERING Patched Checkpoint Loader ---" )
5955 logger .info (f"Received Device Config: { device_config } " )
6056 logger .info (f"Received DisTorch2 Config: { distorch_config } " )
6157
62- # --- Start of Rewritten Logic ---
6358 clip = None
6459 clipvision = None
6560 vae = None
6661 model = None
6762 model_patcher = None
6863
69- # Store original device contexts to restore later
7064 original_main_device = current_device
7165 original_clip_device = current_text_encoder_device
7266 logger .info (f"Saved original device contexts: UNet/VAE='{ original_main_device } ', CLIP='{ original_clip_device } '" )
7367
7468 try :
75- # --- Model Configuration Detection (Replicated from original) ---
7669 diffusion_model_prefix = comfy .model_detection .unet_prefix_from_state_dict (sd )
7770 parameters = comfy .utils .calculate_parameters (sd , diffusion_model_prefix )
7871 weight_dtype = comfy .utils .weight_dtype (sd , diffusion_model_prefix )
@@ -94,31 +87,26 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
9487 weight_dtype = None
9588
9689 model_config .custom_operations = model_options .get ("custom_operations" , None )
97- unet_dtype = model_options .get ("dtype" , model_options .get ("weight_dtype" , None )))
90+ unet_dtype = model_options .get ("dtype" , model_options .get ("weight_dtype" , None ))
9891 if unet_dtype is None :
9992 unet_dtype = mm .unet_dtype (model_params = parameters , supported_dtypes = unet_weight_dtype , weight_dtype = weight_dtype )
10093
101- manual_cast_dtype = mm .unet_manual_cast (unet_dtype , torch .device (device_config .get ('unet_device' )), model_config .supported_inference_dtypes )
94+ unet_compute_device = device_config .get ('unet_device' , original_main_device )
95+ manual_cast_dtype = mm .unet_manual_cast (unet_dtype , torch .device (unet_compute_device ), model_config .supported_inference_dtypes )
10296 model_config .set_inference_dtype (unet_dtype , manual_cast_dtype )
10397 logger .info (f"UNet DType: { unet_dtype } , Manual Cast: { manual_cast_dtype } " )
10498
105- # --- CLIP Vision Loading ---
99+
106100 if model_config .clip_vision_prefix is not None and output_clipvision :
107- logger .info ("--- Loading CLIP Vision ---" )
108101 clipvision = comfy .clip_vision .load_clipvision_from_sd (sd , model_config .clip_vision_prefix , True )
109- logger .info ("CLIP Vision Loaded." )
110102
111- # --- UNet Loading Block ---
112103 if output_model :
113- logger .info ("--- Loading UNet ---" )
114104 unet_compute_device = device_config .get ('unet_device' , original_main_device )
115- set_current_device (unet_compute_device )
116- logger .info (f"Set UNet context to: { unet_compute_device } " )
117-
105+ set_current_device (unet_compute_device )
118106 inital_load_device = mm .unet_inital_load_device (parameters , unet_dtype )
119- logger .info (f"UNet initial load device: { inital_load_device } " )
120-
107+
121108 model = model_config .get_model (sd , diffusion_model_prefix , device = inital_load_device )
109+
122110 model_patcher = comfy .model_patcher .ModelPatcher (model , load_device = unet_compute_device , offload_device = mm .unet_offload_device ())
123111
124112 if distorch_config and 'unet_allocation' in distorch_config :
@@ -131,28 +119,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
131119 logger .info (f"Stored DisTorch2 config for UNet (hash { model_hash [:8 ]} ): { distorch_config ['unet_allocation' ]} " )
132120
133121 model .load_model_weights (sd , diffusion_model_prefix )
134- logger .info ("UNet Loaded." )
135122
136- # --- VAE Loading Block ---
137123 if output_vae :
138- logger .info ("--- Loading VAE ---" )
139124 vae_target_device = torch .device (device_config .get ('vae_device' , original_main_device ))
140125 set_current_device (vae_target_device ) # Use main device context for VAE
141- logger .info (f"Set VAE context to: { vae_target_device } " )
142126
143127 vae_sd = comfy .utils .state_dict_prefix_replace (sd , {k : "" for k in model_config .vae_key_prefix }, filter_keys = True )
144128 vae_sd = model_config .process_vae_state_dict (vae_sd )
145-
146- # The VAE class itself respects the mm.get_torch_device() patch
147129 vae = VAE (sd = vae_sd , metadata = metadata )
148- logger .info (f"VAE Loaded. Final device should be: { vae_target_device } " )
149130
150- # --- CLIP Loading Block ---
151131 if output_clip :
152- logger .info ("--- Loading CLIP ---" )
153132 clip_target_device = device_config .get ('clip_device' , original_clip_device )
154133 set_current_text_encoder_device (clip_target_device )
155- logger .info (f"Set CLIP context to: { clip_target_device } " )
156134
157135 clip_target = model_config .clip_target (state_dict = sd )
158136 if clip_target is not None :
0 commit comments