Skip to content

Commit 2aaa033

Browse files
committed
feat: add starter workflow for MultiGPU setup with detailed instructions, and fix checkpoint loading.
1 parent 0b438f7 commit 2aaa033

3 files changed

Lines changed: 779 additions & 16 deletions

File tree

checkpoint_multigpu.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def patch_load_state_dict_guess_config():
3232

3333
def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False,
3434
embedding_directory=None, output_model=True, model_options={},
35-
te_model_options={}, metadata=None):
35+
te_model_options={}, metadata=None, disable_dynamic=False):
3636
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
3737
from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
3838

@@ -42,7 +42,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
4242
distorch_config = checkpoint_distorch_config.get(config_hash)
4343

4444
if not device_config and not distorch_config:
45-
return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata)
45+
return original_load_state_dict_guess_config(
46+
sd,
47+
output_vae=output_vae,
48+
output_clip=output_clip,
49+
output_clipvision=output_clipvision,
50+
embedding_directory=embedding_directory,
51+
output_model=output_model,
52+
model_options=model_options,
53+
te_model_options=te_model_options,
54+
metadata=metadata,
55+
disable_dynamic=disable_dynamic,
56+
)
4657

4758
logger.debug("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader")
4859
logger.debug(f"[MultiGPU Checkpoint] Received Device Config: {device_config}")
@@ -73,7 +84,12 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
7384
logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.")
7485
# Simplified fallback for non-checkpoints
7586
set_current_device(device_config.get('unet_device', original_main_device))
76-
diffusion_model = comfy.sd.load_diffusion_model_state_dict(sd, model_options={})
87+
diffusion_model = comfy.sd.load_diffusion_model_state_dict(
88+
sd,
89+
model_options={},
90+
metadata=metadata,
91+
disable_dynamic=disable_dynamic,
92+
)
7793
if diffusion_model is None:
7894
return None
7995
return (diffusion_model, None, VAE(sd={}), None)
@@ -90,11 +106,11 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
90106
if unet_dtype is None:
91107
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
92108

93-
unet_compute_device = device_config.get('unet_device', original_main_device)
109+
unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
94110
if model_config.scaled_fp8 is not None:
95-
manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
111+
manual_cast_dtype = mm.unet_manual_cast(None, unet_compute_device, model_config.supported_inference_dtypes)
96112
else:
97-
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
113+
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, unet_compute_device, model_config.supported_inference_dtypes)
98114
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
99115
logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")
100116

@@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
103119
clipvision = comfy.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
104120

105121
if output_model:
106-
unet_compute_device = device_config.get('unet_device', original_main_device)
122+
unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
107123
set_current_device(unet_compute_device)
108124
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
109125

110126
multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
111127

112128
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
113-
model.load_model_weights(sd, diffusion_model_prefix)
129+
model_patcher_class = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
130+
model_patcher = model_patcher_class(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
131+
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
114132
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
115133

116134
logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
117135
soft_empty_cache_multigpu()
118-
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
119136
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model")
120137

121138
if distorch_config and 'unet_allocation' in distorch_config:
@@ -159,7 +176,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
159176
out_sd[k] = quant_sd[k]
160177
sd = out_sd
161178

162-
clip_target_device = device_config.get('clip_device', original_clip_device)
179+
clip_target_device = torch.device(device_config.get('clip_device', original_clip_device))
163180
set_current_text_encoder_device(clip_target_device)
164181

165182
clip_target = model_config.clip_target(state_dict=sd)
@@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
170187
multigpu_memory_log(f"clip:{config_hash[:8]}", "pre-load")
171188
soft_empty_cache_multigpu()
172189
clip_params = comfy.utils.calculate_parameters(clip_sd)
173-
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
190+
clip = CLIP(
191+
clip_target,
192+
embedding_directory=embedding_directory,
193+
tokenizer_data=clip_sd,
194+
parameters=clip_params,
195+
state_dict=clip_sd,
196+
model_options=te_model_options,
197+
disable_dynamic=disable_dynamic,
198+
)
174199

175200
if distorch_config and 'clip_allocation' in distorch_config:
176201
clip_alloc = distorch_config['clip_allocation']
@@ -181,11 +206,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
181206
logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}")
182207
clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
183208

184-
m, u = clip.load_sd(clip_sd, full_model=True) # This respects the patched text_encoder_device
185-
if len(m) > 0:
186-
logger.warning(f"CLIP missing keys: {m}")
187-
if len(u) > 0:
188-
logger.debug(f"CLIP unexpected keys: {u}")
189209
logger.info("CLIP Loaded.")
190210
multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
191211
else:
239 KB
Loading

0 commit comments

Comments
 (0)