Skip to content

Commit a1b7b1f

Browse files
authored
Merge pull request #114 from pollockjj/d2_clip
Major Refactor
2 parents 57c7d3d + 2a6a8f4 commit a1b7b1f

49 files changed

Lines changed: 2661 additions & 2567 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
# Python and IDE
22
__pycache__/
3-
.vscode/settings.json
3+
.clinerules
4+
.vscode
5+
memory-bank/

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ Currently supported nodes (automatically detected if available):
112112

113113
All MultiGPU nodes available for your install can be found in the "multigpu" category in the node menu.
114114

115+
## Node Documentation
116+
117+
Detailed technical documentation is available for all **automatically-detected core MultiGPU and DisTorch2 nodes**, covering 36+ documented nodes with comprehensive parameter details, output specifications, and DisTorch2 allocation guidance where applicable.
118+
119+
- **To access documentation**: Click on any core MultiGPU or DisTorch2 node in ComfyUI and select "Help" (question mark inside a circle) from the resulting menu
120+
- **Coverage**: All standard ComfyUI loader nodes (UNet, VAE, Checkpoints, CLIP, ControlNet, Diffusers) plus popular GGUF loader variants
121+
- **Contents**: Input parameters with data types and descriptions, output specifications, usage examples, and DisTorch2 distributed loading explanations with allocation modes and strategies
122+
- **Note**: Documentation covers core ComfyUI-MultiGPU functionality only. Third-party custom node integrations (WanVideoWrapper, Florence2, etc.) have their own separate documentation.
123+
115124
## Example workflows
116125

117126
All workflows have been tested on a 2x 3090 + 1060ti linux setup, a 4070 win 11 setup, and a 3090/1070ti linux setup.

__init__.py

Lines changed: 59 additions & 145 deletions
Large diffs are not rendered by default.

checkpoint_multigpu.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
"""
2-
Advanced Checkpoint Loaders for MultiGPU
3-
Provides device-specific and DisTorch2 sharding for checkpoint components
4-
"""
5-
61
import torch
72
import logging
83
import hashlib
@@ -13,6 +8,7 @@
138
import comfy.clip_vision
149
from comfy.sd import VAE, CLIP
1510
from .device_utils import get_device_list, soft_empty_cache_multigpu
11+
from .model_management_mgpu import multigpu_memory_log
1612
from .distorch_2 import safetensor_allocation_store, safetensor_settings_store, create_safetensor_model_hash, register_patched_safetensor_modelpatcher
1713

1814
logger = logging.getLogger("MultiGPU")
@@ -23,24 +19,21 @@
2319
original_load_state_dict_guess_config = None
2420

2521
def patch_load_state_dict_guess_config():
26-
"""
27-
Monkey patch the load_state_dict_guess_config function to replace its logic
28-
with a MultiGPU-aware implementation.
29-
"""
22+
"""Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading."""
3023
global original_load_state_dict_guess_config
3124

3225
if original_load_state_dict_guess_config is not None:
33-
logger.info("[MultiGPU] load_state_dict_guess_config is already patched.")
26+
logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
3427
return
3528

36-
logger.info("[MultiGPU] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
29+
logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
3730
original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
3831
comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
3932

4033
def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False,
4134
embedding_directory=None, output_model=True, model_options={},
4235
te_model_options={}, metadata=None):
43-
36+
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
4437
from . import set_current_device, set_current_text_encoder_device, current_device, current_text_encoder_device
4538

4639
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
@@ -51,9 +44,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5144
if not device_config and not distorch_config:
5245
return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata)
5346

54-
logger.info("--- [MultiGPU] ENTERING Patched Checkpoint Loader ---")
55-
logger.info(f"Received Device Config: {device_config}")
56-
logger.info(f"Received DisTorch2 Config: {distorch_config}")
47+
logger.debug("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader")
48+
logger.debug(f"[MultiGPU Checkpoint] Received Device Config: {device_config}")
49+
logger.debug(f"[MultiGPU Checkpoint] Received DisTorch2 Config: {distorch_config}")
5750

5851
clip = None
5952
clipvision = None
@@ -63,7 +56,6 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
6356

6457
original_main_device = current_device
6558
original_clip_device = current_text_encoder_device
66-
logger.info(f"Saved original device contexts: UNet/VAE='{original_main_device}', CLIP='{original_clip_device}'")
6759

6860
try:
6961
diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)
@@ -80,7 +72,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
8072
return None
8173
return (diffusion_model, None, VAE(sd={}), None)
8274

83-
logger.info(f"[MultiGPU] Detected Model Config: {type(model_config).__name__}, Parameters: {parameters/10**9:.2f}B")
75+
logger.debug(f"[MultiGPU] Detected Model Config: {type(model_config).__name__}, Parameters: {parameters/10**9:.2f}B")
8476

8577
unet_weight_dtype = list(model_config.supported_inference_dtypes)
8678
if model_config.scaled_fp8 is not None:
@@ -105,10 +97,14 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
10597
set_current_device(unet_compute_device)
10698
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
10799

100+
multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
101+
108102
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
109103

110-
soft_empty_cache_multigpu(logger)
104+
logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
105+
soft_empty_cache_multigpu()
111106
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
107+
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model")
112108

113109
if distorch_config and 'unet_allocation' in distorch_config:
114110
register_patched_safetensor_modelpatcher()
@@ -117,17 +113,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
117113
safetensor_settings_store[model_hash] = distorch_config.get('unet_settings','')
118114
model.is_distorch = True
119115
model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
120-
logger.info(f"Stored DisTorch2 config for UNet (hash {model_hash[:8]}): {distorch_config['unet_allocation']}")
116+
logger.mgpu_mm_log(f"Stored DisTorch2 config for UNet (hash {model_hash[:8]}): {distorch_config['unet_allocation']}")
121117

122118
model.load_model_weights(sd, diffusion_model_prefix)
119+
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
123120

124121
if output_vae:
125122
vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
126123
set_current_device(vae_target_device) # Use main device context for VAE
124+
multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load")
127125

128126
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
129127
vae_sd = model_config.process_vae_state_dict(vae_sd)
130128
vae = VAE(sd=vae_sd, metadata=metadata)
129+
multigpu_memory_log(f"vae:{config_hash[:8]}", "post-load")
131130

132131
if output_clip:
133132
clip_target_device = device_config.get('clip_device', original_clip_device)
@@ -137,7 +136,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
137136
if clip_target is not None:
138137
clip_sd = model_config.process_clip_state_dict(sd)
139138
if len(clip_sd) > 0:
140-
soft_empty_cache_multigpu(logger)
139+
logger.debug("[MultiGPU Checkpoint] Invoking soft_empty_cache_multigpu before CLIP construction")
140+
multigpu_memory_log(f"clip:{config_hash[:8]}", "pre-load")
141+
soft_empty_cache_multigpu()
141142
clip_params = comfy.utils.calculate_parameters(clip_sd)
142143
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
143144

@@ -155,22 +156,19 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
155156
if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")
156157
if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}")
157158
logger.info("CLIP Loaded.")
159+
multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
158160
else:
159161
logger.warning("No CLIP/text encoder weights in checkpoint.")
160162
else:
161163
logger.warning("CLIP target not found in model config.")
162164

163165
finally:
164-
# --- Restore original device contexts and clean up ---
165166
set_current_device(original_main_device)
166167
set_current_text_encoder_device(original_clip_device)
167168
if config_hash in checkpoint_device_config:
168169
del checkpoint_device_config[config_hash]
169170
if config_hash in checkpoint_distorch_config:
170171
del checkpoint_distorch_config[config_hash]
171-
logger.info(f"Restored original device contexts. UNet/VAE='{original_main_device}', CLIP='{original_clip_device}'")
172-
logger.info("--- [MultiGPU] EXITING Patched Checkpoint Loader ---")
173-
174172
return (model_patcher, clip, vae, clipvision)
175173

176174
class CheckpointLoaderAdvancedMultiGPU:

0 commit comments

Comments
 (0)