
Commit f07c2d2

feat: Add advanced checkpoint loaders for MultiGPU and DisTorch2
1 parent 4d0d4a6 commit f07c2d2

2 files changed

Lines changed: 318 additions & 0 deletions

File tree

__init__.py

Lines changed: 8 additions & 0 deletions
@@ -191,10 +191,18 @@ def check_module_exists(module_path):
     override_class_with_distorch_safetensor_v2
 )
 
+# Import advanced checkpoint loaders
+from .checkpoint_multigpu import (
+    CheckpointLoaderAdvancedMultiGPU,
+    CheckpointLoaderAdvancedDisTorch2MultiGPU
+)
+
 # Initialize NODE_CLASS_MAPPINGS
 NODE_CLASS_MAPPINGS = {
     "DeviceSelectorMultiGPU": DeviceSelectorMultiGPU,
     "HunyuanVideoEmbeddingsAdapter": HunyuanVideoEmbeddingsAdapter,
+    "CheckpointLoaderAdvancedMultiGPU": CheckpointLoaderAdvancedMultiGPU,
+    "CheckpointLoaderAdvancedDisTorch2MultiGPU": CheckpointLoaderAdvancedDisTorch2MultiGPU,
 }
 
 # Standard MultiGPU nodes
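For context, ComfyUI builds its node registry from each package's exported NODE_CLASS_MAPPINGS. A minimal sketch of resolving one of the newly registered loaders by name, outside the graph executor (illustrative only):

# Resolve a node class by its registered name, the same lookup the
# ComfyUI frontend performs, then instantiate it directly.
node_cls = NODE_CLASS_MAPPINGS["CheckpointLoaderAdvancedMultiGPU"]
node = node_cls()
print(node.TITLE)  # "Checkpoint Loader Advanced (MultiGPU)"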

checkpoint_multigpu.py

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
"""
Advanced Checkpoint Loaders for MultiGPU
Provides device-specific and DisTorch2 sharding for checkpoint components
"""

import torch
import logging
import hashlib
import copy
import comfy.sd
import comfy.utils
import comfy.model_management as mm
from .device_utils import get_device_list
from .distorch_2 import safetensor_allocation_store, create_safetensor_model_hash

logger = logging.getLogger("MultiGPU")

# Store checkpoint loading configurations
checkpoint_device_config = {}
checkpoint_distorch_config = {}

# Store the original function
original_load_state_dict_guess_config = None

def create_checkpoint_config_hash(checkpoint_name, config_str):
    """Create a unique hash for a checkpoint configuration."""
    identifier = f"{checkpoint_name}_{config_str}"
    return hashlib.sha256(identifier.encode()).hexdigest()
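A minimal sketch of this helper in isolation; the checkpoint name and config string below are made-up values. Note that the loaders in this commit key the config stores by state-dict size (str(sd_size)) rather than by this hash:

# Hypothetical usage: identical (checkpoint_name, config_str) inputs
# always yield the same SHA-256 hex digest.
key = create_checkpoint_config_hash(
    "sd_xl_base_1.0.safetensors",        # placeholder checkpoint name
    "unet=cuda:0,clip=cuda:1,vae=cpu",   # placeholder config string
)
assert len(key) == 64  # SHA-256 hex digest length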
def patch_load_state_dict_guess_config():
    """Monkey-patch comfy.sd.load_state_dict_guess_config to support per-component device selection."""
    global original_load_state_dict_guess_config

    if original_load_state_dict_guess_config is not None:
        return  # Already patched

    original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config

    def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False,
                                             embedding_directory=None, output_model=True, model_options={},
                                             te_model_options={}, metadata=None):

        # Import here to avoid circular imports
        from . import set_current_device, set_current_text_encoder_device, current_device, current_text_encoder_device

        # Check if we have a device configuration for this checkpoint.
        # The total state-dict element count serves as a simple identifier.
        sd_size = sum(t.numel() for t in sd.values() if hasattr(t, 'numel'))
        config_hash = str(sd_size)

        device_config = checkpoint_device_config.get(config_hash)
        distorch_config = checkpoint_distorch_config.get(config_hash)

        if device_config or distorch_config:
            logger.info("[MultiGPU] Using custom device configuration for checkpoint")

            # Save original devices
            original_unet_device = current_device
            original_clip_device = current_text_encoder_device

            # Handle UNet device config
            if device_config and 'unet_device' in device_config:
                set_current_device(device_config['unet_device'])
                logger.info(f"[MultiGPU] Setting UNet device to: {device_config['unet_device']}")

            # DisTorch2 config for the UNet is applied after loading,
            # once the model patcher exists
            if distorch_config and 'unet_allocation' in distorch_config:
                logger.info(f"[MultiGPU] DisTorch2 UNet allocation will be applied: {distorch_config['unet_allocation']}")

            # Call the original function to load the checkpoint
            result = original_load_state_dict_guess_config(
                sd, output_vae=output_vae, output_clip=output_clip, output_clipvision=output_clipvision,
                embedding_directory=embedding_directory, output_model=output_model,
                model_options=model_options, te_model_options=te_model_options, metadata=metadata
            )

            model_patcher, clip, vae, clipvision = result

            # Apply DisTorch2 configurations after loading
            if distorch_config:
                if model_patcher and 'unet_allocation' in distorch_config:
                    model_hash = create_safetensor_model_hash(model_patcher, "checkpoint_loader")
                    safetensor_allocation_store[model_hash] = distorch_config['unet_allocation']
                    if 'unet_settings' in distorch_config:
                        from .distorch_2 import safetensor_settings_store
                        safetensor_settings_store[model_hash] = distorch_config['unet_settings']
                    logger.info(f"[MultiGPU] Applied DisTorch2 config to UNet: {model_hash[:8]}")

                if clip and 'clip_allocation' in distorch_config:
                    # For CLIP, the underlying model patcher lives on the CLIP object
                    if hasattr(clip, 'patcher'):
                        clip_hash = create_safetensor_model_hash(clip.patcher, "checkpoint_loader_clip")
                        safetensor_allocation_store[clip_hash] = distorch_config['clip_allocation']
                        if 'clip_settings' in distorch_config:
                            from .distorch_2 import safetensor_settings_store
                            safetensor_settings_store[clip_hash] = distorch_config['clip_settings']
                        logger.info(f"[MultiGPU] Applied DisTorch2 config to CLIP: {clip_hash[:8]}")

            # Handle CLIP device
            if device_config and 'clip_device' in device_config and clip:
                set_current_text_encoder_device(device_config['clip_device'])
                logger.info(f"[MultiGPU] Setting CLIP device to: {device_config['clip_device']}")
                # Force CLIP to load on the specified device
                if hasattr(clip, 'patcher'):
                    clip.patcher.load(force_patch_weights=True)

            # Handle VAE device
            if device_config and 'vae_device' in device_config and vae:
                vae_device = torch.device(device_config['vae_device'])
                logger.info(f"[MultiGPU] Setting VAE device to: {device_config['vae_device']}")
                # Move VAE to the specified device
                if hasattr(vae, 'first_stage_model'):
                    vae.first_stage_model = vae.first_stage_model.to(vae_device)

            # Clean up stored configs
            if config_hash in checkpoint_device_config:
                del checkpoint_device_config[config_hash]
            if config_hash in checkpoint_distorch_config:
                del checkpoint_distorch_config[config_hash]

            return result
        else:
            # No custom config, use original behavior
            return original_load_state_dict_guess_config(
                sd, output_vae=output_vae, output_clip=output_clip, output_clipvision=output_clipvision,
                embedding_directory=embedding_directory, output_model=output_model,
                model_options=model_options, te_model_options=te_model_options, metadata=metadata
            )

    # Apply the patch
    comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
    logger.info("[MultiGPU] Successfully patched load_state_dict_guess_config")
class CheckpointLoaderAdvancedMultiGPU:
    """
    Checkpoint loader that loads the UNet, CLIP, and VAE onto different devices.
    """
    @classmethod
    def INPUT_TYPES(s):
        import folder_paths
        devices = get_device_list()
        default_device = devices[1] if len(devices) > 1 else devices[0]

        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "unet_device": (devices, {"default": default_device}),
                "clip_device": (devices, {"default": default_device}),
                "vae_device": (devices, {"default": default_device}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "multigpu"
    TITLE = "Checkpoint Loader Advanced (MultiGPU)"

    def load_checkpoint(self, ckpt_name, unet_device, clip_device, vae_device):
        # Apply the patch if not already applied
        patch_load_state_dict_guess_config()

        # Load the state dict once to derive the configuration identifier
        import folder_paths
        import comfy.utils

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        sd = comfy.utils.load_torch_file(ckpt_path)

        # Use state-dict size as the identifier
        sd_size = sum(t.numel() for t in sd.values() if hasattr(t, 'numel'))
        config_hash = str(sd_size)

        # Store the device configuration
        checkpoint_device_config[config_hash] = {
            'unet_device': unet_device,
            'clip_device': clip_device,
            'vae_device': vae_device
        }

        logger.info(f"[MultiGPU] CheckpointLoaderAdvanced configured - UNet: {unet_device}, CLIP: {clip_device}, VAE: {vae_device}")

        # Load the checkpoint - the patched function handles device placement
        from nodes import CheckpointLoaderSimple
        loader = CheckpointLoaderSimple()
        return loader.load_checkpoint(ckpt_name)
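A minimal usage sketch, calling the node directly rather than through a ComfyUI graph; the checkpoint filename is a placeholder and the device names must exist in your setup:

# Hypothetical direct invocation; in practice ComfyUI calls this node
# through the graph executor after registration in NODE_CLASS_MAPPINGS.
loader = CheckpointLoaderAdvancedMultiGPU()
model, clip, vae = loader.load_checkpoint(
    ckpt_name="sd_xl_base_1.0.safetensors",  # placeholder filename
    unet_device="cuda:0",
    clip_device="cuda:1",
    vae_device="cpu",
)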
class CheckpointLoaderAdvancedDisTorch2MultiGPU:
    """
    Checkpoint loader with full DisTorch2 sharding for the UNet and CLIP, plus device selection for the VAE.
    """
    @classmethod
    def INPUT_TYPES(s):
        import folder_paths
        devices = get_device_list()
        compute_device = devices[1] if len(devices) > 1 else devices[0]

        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                # UNet DisTorch2 settings
                "unet_compute_device": (devices, {"default": compute_device}),
                "unet_virtual_vram_gb": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 128.0, "step": 0.1}),
                "unet_donor_device": (devices, {"default": "cpu"}),
                # CLIP DisTorch2 settings
                "clip_compute_device": (devices, {"default": compute_device}),
                "clip_virtual_vram_gb": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 128.0, "step": 0.1}),
                "clip_donor_device": (devices, {"default": "cpu"}),
                # VAE simple device
                "vae_device": (devices, {"default": compute_device}),
            },
            "optional": {
                "unet_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
                "clip_expert_mode_allocations": ("STRING", {"multiline": False, "default": ""}),
                "high_precision_loras": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "multigpu/distorch_2"
    TITLE = "Checkpoint Loader Advanced (DisTorch2)"

    def load_checkpoint(self, ckpt_name,
                        unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
                        clip_compute_device, clip_virtual_vram_gb, clip_donor_device,
                        vae_device,
                        unet_expert_mode_allocations="", clip_expert_mode_allocations="",
                        high_precision_loras=True):

        # Apply the patch if not already applied
        patch_load_state_dict_guess_config()

        # Register the DisTorch2 model patcher
        from .distorch_2 import register_patched_safetensor_modelpatcher
        register_patched_safetensor_modelpatcher()

        # Load the state dict once to derive the configuration identifier
        import folder_paths
        import comfy.utils

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        sd = comfy.utils.load_torch_file(ckpt_path)

        # Use state-dict size as the identifier
        sd_size = sum(t.numel() for t in sd.values() if hasattr(t, 'numel'))
        config_hash = str(sd_size)

        # Store the device configuration
        checkpoint_device_config[config_hash] = {
            'unet_device': unet_compute_device,
            'clip_device': clip_compute_device,
            'vae_device': vae_device
        }

        # Build DisTorch2 allocation strings
        unet_vram_string = ""
        if unet_virtual_vram_gb > 0:
            unet_vram_string = f"{unet_compute_device};{unet_virtual_vram_gb};{unet_donor_device}"
        elif unet_expert_mode_allocations:
            unet_vram_string = unet_compute_device

        unet_allocation = f"{unet_expert_mode_allocations}#{unet_vram_string}" if unet_expert_mode_allocations or unet_vram_string else ""

        clip_vram_string = ""
        if clip_virtual_vram_gb > 0:
            clip_vram_string = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
        elif clip_expert_mode_allocations:
            clip_vram_string = clip_compute_device

        clip_allocation = f"{clip_expert_mode_allocations}#{clip_vram_string}" if clip_expert_mode_allocations or clip_vram_string else ""

        # Create settings hashes for DisTorch2
        unet_settings_str = f"{unet_compute_device}{unet_virtual_vram_gb}{unet_donor_device}{unet_expert_mode_allocations}{high_precision_loras}"
        unet_settings_hash = hashlib.sha256(unet_settings_str.encode()).hexdigest()

        clip_settings_str = f"{clip_compute_device}{clip_virtual_vram_gb}{clip_donor_device}{clip_expert_mode_allocations}{high_precision_loras}"
        clip_settings_hash = hashlib.sha256(clip_settings_str.encode()).hexdigest()

        # Store the DisTorch2 configuration
        checkpoint_distorch_config[config_hash] = {
            'unet_allocation': unet_allocation,
            'unet_settings': unet_settings_hash,
            'clip_allocation': clip_allocation,
            'clip_settings': clip_settings_hash,
            'high_precision_loras': high_precision_loras
        }

        logger.info("[MultiGPU] CheckpointLoaderDisTorch2 configured:")
        logger.info(f"  UNet: compute={unet_compute_device}, vram={unet_virtual_vram_gb}GB, donor={unet_donor_device}")
        logger.info(f"  CLIP: compute={clip_compute_device}, vram={clip_virtual_vram_gb}GB, donor={clip_donor_device}")
        logger.info(f"  VAE: device={vae_device}")

        # Load the checkpoint - the patched function handles device placement and DisTorch2
        from nodes import CheckpointLoaderSimple
        loader = CheckpointLoaderSimple()

        result = loader.load_checkpoint(ckpt_name)

        # Store the high_precision_loras flag on the loaded models
        model_patcher, clip, vae = result
        if model_patcher and hasattr(model_patcher, 'model'):
            model_patcher.model._distorch_high_precision_loras = high_precision_loras
        if clip and hasattr(clip, 'patcher') and hasattr(clip.patcher, 'model'):
            clip.patcher.model._distorch_high_precision_loras = high_precision_loras

        return result
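For reference, a worked example of the allocation strings built above, under assumed inputs (device names and values are illustrative):

# Assumed inputs: unet_compute_device="cuda:0", unet_virtual_vram_gb=4.0,
# unet_donor_device="cpu", unet_expert_mode_allocations="".
# The branches above then yield:
#   unet_vram_string = "cuda:0;4.0;cpu"
#   unet_allocation  = "#cuda:0;4.0;cpu"
# With a non-empty expert-mode string such as "cuda:0,0.5;cpu,0.5"
# (hypothetical syntax) and unet_virtual_vram_gb=0, the result becomes:
#   unet_vram_string = "cuda:0"
#   unet_allocation  = "cuda:0,0.5;cpu,0.5#cuda:0"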
