Skip to content

Commit 0b1511e

Browse files
committed
refactor: Simplify checkpoint loading and fix text encoder device
This commit introduces two main improvements: refactoring the checkpoint loading mechanism and fixing the initial device placement for the text encoder (CLIP). 1. **Fix Text Encoder Device Handling:** - A new patch is applied to `mm.text_encoder_initial_device` to gain control over the device used when the text encoder is first loaded. - The `CLIPLoader` override now forces `device='default'` to ensure ComfyUI's patching mechanism is triggered correctly, preventing the text encoder from being placed on the wrong GPU. 2. **Refactor Checkpoint Loaders:** - Removed the global stores (`checkpoint_dtype_store`, `checkpoint_half_store`, `checkpoint_config_store`). - The `CheckpointLoaderSimpleMultiGPU` and `AdvCheckpointLoaderMultiGPU` nodes now use arguments and ComfyUI's internal defaults directly. This simplifies the logic, reduces global state, and makes the code easier to follow. Additionally, log message prefixes have been updated to be more descriptive, aiding in debugging.
1 parent 9e14e46 commit 0b1511e

2 files changed

Lines changed: 27 additions & 45 deletions

File tree

__init__.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
handler.setFormatter(formatter)
2424
logger.addHandler(handler)
2525
logger.setLevel(log_level)
26-
logger.info(f"[MultiGPU] Logger initialized with level: {logging.getLevelName(log_level)}")
26+
logger.info(f"[MultiGPU Initialization] Logger initialized with level: {logging.getLevelName(log_level)}")
2727

2828

2929
# Global device state management
@@ -33,12 +33,13 @@
3333
def set_current_device(device):
3434
global current_device
3535
current_device = device
36-
logger.info(f"[MultiGPU] current_device set to: {device}")
36+
logger.info(f"[MultiGPU Initialization] current_device set to: {device}")
3737

3838
def set_current_text_encoder_device(device):
3939
global current_text_encoder_device
4040
current_text_encoder_device = device
41-
logger.info(f"[MultiGPU] current_text_encoder_device set to: {device}")
41+
current_text_encoder_initial_device = device
42+
logger.info(f"[MultiGPU Initialization] current_text_encoder_device and current_text_encoder_initial_device set to: {device}")
4243

4344
def override_class(cls):
4445
class NodeOverride(cls):
@@ -55,15 +56,12 @@ def INPUT_TYPES(s):
5556
FUNCTION = "override"
5657

5758
def override(self, *args, device=None, **kwargs):
58-
logger.debug(f"[MultiGPU] override_class called for {cls.__name__} with device={device}")
59-
59+
6060
if device is not None:
6161
set_current_device(device)
62-
6362
fn = getattr(super(), cls.FUNCTION)
6463
out = fn(*args, **kwargs)
65-
logger.debug(f"[MultiGPU] override_class for {cls.__name__} completed successfully")
66-
64+
6765
return out
6866

6967
return NodeOverride
@@ -85,7 +83,7 @@ def INPUT_TYPES(s):
8583
def override(self, *args, device=None, **kwargs):
8684
if device is not None:
8785
set_current_text_encoder_device(device)
88-
86+
kwargs['device'] = 'default'
8987
fn = getattr(super(), cls.FUNCTION)
9088
out = fn(*args, **kwargs)
9189

@@ -100,7 +98,7 @@ def get_torch_device_patched():
10098
else:
10199
devs = set(get_device_list())
102100
device = torch.device(current_device) if str(current_device) in devs else torch.device("cpu")
103-
logger.debug(f"[MultiGPU] get_torch_device_patched returning device: {device} (current_device={current_device})")
101+
logger.debug(f"[MultiGPU Core Patching] get_torch_device_patched returning device: {device} (current_device={current_device})")
104102
return device
105103

106104
def text_encoder_device_patched():
@@ -110,16 +108,22 @@ def text_encoder_device_patched():
110108
else:
111109
devs = set(get_device_list())
112110
device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu")
113-
logger.debug(f"[MultiGPU] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
111+
logger.debug(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
114112
return device
115113

116-
# Apply patches
117-
logger.info(f"[MultiGPU] Patching mm.get_torch_device and mm.text_encoder_device")
118-
logger.debug(f"[MultiGPU] Initial current_device: {current_device}")
119-
logger.debug(f"[MultiGPU] Initial current_text_encoder_device: {current_text_encoder_device}")
114+
def text_encoder_initial_device_patched(*args, **kwargs):
115+
logger.debug(f"[MultiGPU Core Patching] text_encoder_initial_device_patched called with args={args}, kwargs={kwargs}")
116+
# look at this later - I am not convinced that this isn't the better choice:
117+
# return text_encoder_device_patched()
118+
return mm.text_encoder_device()
119+
120+
121+
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, and mm.text_encoder_initial_device")
122+
logger.debug(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
123+
logger.debug(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
120124
mm.get_torch_device = get_torch_device_patched
121125
mm.text_encoder_device = text_encoder_device_patched
122-
logger.debug(f"[MultiGPU] Patches applied successfully")
126+
mm.text_encoder_initial_device = text_encoder_initial_device_patched
123127

124128
def check_module_exists(module_path):
125129
full_path = os.path.join(folder_paths.get_folder_paths("custom_nodes")[0], module_path)

checkpoint_multigpu.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717

1818
logger = logging.getLogger("MultiGPU")
1919

20-
# --- Global Stores for Configuration ---
2120
checkpoint_device_config = {}
2221
checkpoint_distorch_config = {}
2322

24-
# --- Original Function Store ---
2523
original_load_state_dict_guess_config = None
2624

2725
def patch_load_state_dict_guess_config():
@@ -45,34 +43,29 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
4543

4644
from . import set_current_device, set_current_text_encoder_device, current_device, current_text_encoder_device
4745

48-
# --- Check for custom configuration ---
4946
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
5047
config_hash = str(sd_size)
5148
device_config = checkpoint_device_config.get(config_hash)
5249
distorch_config = checkpoint_distorch_config.get(config_hash)
5350

5451
if not device_config and not distorch_config:
55-
# No config, fall back to original untouched function
5652
return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata)
5753

5854
logger.info("--- [MultiGPU] ENTERING Patched Checkpoint Loader ---")
5955
logger.info(f"Received Device Config: {device_config}")
6056
logger.info(f"Received DisTorch2 Config: {distorch_config}")
6157

62-
# --- Start of Rewritten Logic ---
6358
clip = None
6459
clipvision = None
6560
vae = None
6661
model = None
6762
model_patcher = None
6863

69-
# Store original device contexts to restore later
7064
original_main_device = current_device
7165
original_clip_device = current_text_encoder_device
7266
logger.info(f"Saved original device contexts: UNet/VAE='{original_main_device}', CLIP='{original_clip_device}'")
7367

7468
try:
75-
# --- Model Configuration Detection (Replicated from original) ---
7669
diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)
7770
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
7871
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
@@ -94,31 +87,26 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
9487
weight_dtype = None
9588

9689
model_config.custom_operations = model_options.get("custom_operations", None)
97-
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)))
90+
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
9891
if unet_dtype is None:
9992
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
10093

101-
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(device_config.get('unet_device')), model_config.supported_inference_dtypes)
94+
unet_compute_device = device_config.get('unet_device', original_main_device)
95+
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
10296
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
10397
logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")
10498

105-
# --- CLIP Vision Loading ---
99+
106100
if model_config.clip_vision_prefix is not None and output_clipvision:
107-
logger.info("--- Loading CLIP Vision ---")
108101
clipvision = comfy.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
109-
logger.info("CLIP Vision Loaded.")
110102

111-
# --- UNet Loading Block ---
112103
if output_model:
113-
logger.info("--- Loading UNet ---")
114104
unet_compute_device = device_config.get('unet_device', original_main_device)
115-
set_current_device(unet_compute_device)
116-
logger.info(f"Set UNet context to: {unet_compute_device}")
117-
105+
set_current_device(unet_compute_device)
118106
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
119-
logger.info(f"UNet initial load device: {inital_load_device}")
120-
107+
121108
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
109+
122110
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
123111

124112
if distorch_config and 'unet_allocation' in distorch_config:
@@ -131,28 +119,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
131119
logger.info(f"Stored DisTorch2 config for UNet (hash {model_hash[:8]}): {distorch_config['unet_allocation']}")
132120

133121
model.load_model_weights(sd, diffusion_model_prefix)
134-
logger.info("UNet Loaded.")
135122

136-
# --- VAE Loading Block ---
137123
if output_vae:
138-
logger.info("--- Loading VAE ---")
139124
vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
140125
set_current_device(vae_target_device) # Use main device context for VAE
141-
logger.info(f"Set VAE context to: {vae_target_device}")
142126

143127
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
144128
vae_sd = model_config.process_vae_state_dict(vae_sd)
145-
146-
# The VAE class itself respects the mm.get_torch_device() patch
147129
vae = VAE(sd=vae_sd, metadata=metadata)
148-
logger.info(f"VAE Loaded. Final device should be: {vae_target_device}")
149130

150-
# --- CLIP Loading Block ---
151131
if output_clip:
152-
logger.info("--- Loading CLIP ---")
153132
clip_target_device = device_config.get('clip_device', original_clip_device)
154133
set_current_text_encoder_device(clip_target_device)
155-
logger.info(f"Set CLIP context to: {clip_target_device}")
156134

157135
clip_target = model_config.clip_target(state_dict=sd)
158136
if clip_target is not None:

0 commit comments

Comments (0)