Skip to content

Commit afafc80

Browse files
committed
Add no-device variants for multi-GPU CLIP loaders
1 parent 5bb7add commit afafc80

3 files changed

Lines changed: 193 additions & 13 deletions

File tree

__init__.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,31 @@ def override(self, *args, device=None, **kwargs):
9292

9393
return NodeOverride
9494

95+
def override_class_clip_no_device(cls):
    """Wrap *cls* so its node exposes an optional ``device`` selector.

    The ``device`` input is added to the node's ``optional`` section; when
    a device is supplied at execution time, the current text-encoder
    device is switched before delegating to the wrapped loader.
    """

    class NodeOverride(cls):
        CATEGORY = "multigpu"
        FUNCTION = "override"

        @classmethod
        def INPUT_TYPES(s):
            # Deep-copy so the wrapped class's input schema is never mutated.
            inputs = copy.deepcopy(cls.INPUT_TYPES())
            devices = get_device_list()
            # Prefer the second entry (first accelerator) when more than one
            # device is available; otherwise fall back to the only device.
            if len(devices) > 1:
                default_device = devices[1]
            else:
                default_device = devices[0]
            optional = inputs.setdefault("optional", {})
            optional["device"] = (devices, {"default": default_device})
            return inputs

        def override(self, *args, device=None, **kwargs):
            # Only redirect the text-encoder device when one was chosen.
            if device is not None:
                set_current_text_encoder_device(device)
            original_fn = getattr(super(), cls.FUNCTION)
            return original_fn(*args, **kwargs)

    return NodeOverride
118+
119+
95120
def get_torch_device_patched():
96121
device = None
97122
if (not is_accelerator_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()):
@@ -183,6 +208,7 @@ def check_module_exists(module_path):
183208
override_class_with_distorch_gguf,
184209
override_class_with_distorch_gguf_v2,
185210
override_class_with_distorch_clip,
211+
override_class_with_distorch_clip_no_device,
186212
override_class_with_distorch
187213
)
188214

@@ -194,7 +220,8 @@ def check_module_exists(module_path):
194220
analyze_safetensor_loading,
195221
calculate_safetensor_vvram_allocation,
196222
override_class_with_distorch_safetensor_v2,
197-
override_class_with_distorch_safetensor_v2_clip
223+
override_class_with_distorch_safetensor_v2_clip,
224+
override_class_with_distorch_safetensor_v2_clip_no_device
198225
)
199226

200227
# Import advanced checkpoint loaders
@@ -217,10 +244,10 @@ def check_module_exists(module_path):
217244
NODE_CLASS_MAPPINGS["CLIPLoaderMultiGPU"] = override_class_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPLoader"])
218245
NODE_CLASS_MAPPINGS["DualCLIPLoaderMultiGPU"] = override_class_clip(GLOBAL_NODE_CLASS_MAPPINGS["DualCLIPLoader"])
219246
if "TripleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
220-
NODE_CLASS_MAPPINGS["TripleCLIPLoaderMultiGPU"] = override_class_clip(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
247+
NODE_CLASS_MAPPINGS["TripleCLIPLoaderMultiGPU"] = override_class_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
221248
if "QuadrupleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
222-
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderMultiGPU"] = override_class_clip(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
223-
NODE_CLASS_MAPPINGS["CLIPVisionLoaderMultiGPU"] = override_class_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
249+
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderMultiGPU"] = override_class_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
250+
NODE_CLASS_MAPPINGS["CLIPVisionLoaderMultiGPU"] = override_class_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
224251
NODE_CLASS_MAPPINGS["CheckpointLoaderSimpleMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"])
225252
NODE_CLASS_MAPPINGS["ControlNetLoaderMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["ControlNetLoader"])
226253
if "DiffusersLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
@@ -234,10 +261,10 @@ def check_module_exists(module_path):
234261
NODE_CLASS_MAPPINGS["CLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPLoader"])
235262
NODE_CLASS_MAPPINGS["DualCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["DualCLIPLoader"])
236263
if "TripleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
237-
NODE_CLASS_MAPPINGS["TripleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
264+
NODE_CLASS_MAPPINGS["TripleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
238265
if "QuadrupleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
239-
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
240-
NODE_CLASS_MAPPINGS["CLIPVisionLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
266+
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
267+
NODE_CLASS_MAPPINGS["CLIPVisionLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip_no_device(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
241268
NODE_CLASS_MAPPINGS["CheckpointLoaderSimpleDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"])
242269
NODE_CLASS_MAPPINGS["ControlNetLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["ControlNetLoader"])
243270
if "DiffusersLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
@@ -305,20 +332,20 @@ def register_and_count(module_names, node_map):
305332
"UnetLoaderGGUFAdvancedDisTorchMultiGPU": override_class_with_distorch_gguf(UnetLoaderGGUFAdvanced),
306333
"CLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip(CLIPLoaderGGUF),
307334
"DualCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip(DualCLIPLoaderGGUF),
308-
"TripleCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip(TripleCLIPLoaderGGUF),
309-
"QuadrupleCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip(QuadrupleCLIPLoaderGGUF),
335+
"TripleCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip_no_device(TripleCLIPLoaderGGUF),
336+
"QuadrupleCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip_no_device(QuadrupleCLIPLoaderGGUF),
310337
"UnetLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(UnetLoaderGGUF),
311338
"UnetLoaderGGUFAdvancedDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(UnetLoaderGGUFAdvanced),
312339
"CLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(CLIPLoaderGGUF),
313340
"DualCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(DualCLIPLoaderGGUF),
314-
"TripleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(TripleCLIPLoaderGGUF),
315-
"QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(QuadrupleCLIPLoaderGGUF),
341+
"TripleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip_no_device(TripleCLIPLoaderGGUF),
342+
"QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip_no_device(QuadrupleCLIPLoaderGGUF),
316343
"UnetLoaderGGUFMultiGPU": override_class(UnetLoaderGGUF),
317344
"UnetLoaderGGUFAdvancedMultiGPU": override_class(UnetLoaderGGUFAdvanced),
318345
"CLIPLoaderGGUFMultiGPU": override_class_clip(CLIPLoaderGGUF),
319346
"DualCLIPLoaderGGUFMultiGPU": override_class_clip(DualCLIPLoaderGGUF),
320-
"TripleCLIPLoaderGGUFMultiGPU": override_class_clip(TripleCLIPLoaderGGUF),
321-
"QuadrupleCLIPLoaderGGUFMultiGPU": override_class_clip(QuadrupleCLIPLoaderGGUF)
347+
"TripleCLIPLoaderGGUFMultiGPU": override_class_clip_no_device(TripleCLIPLoaderGGUF),
348+
"QuadrupleCLIPLoaderGGUFMultiGPU": override_class_clip_no_device(QuadrupleCLIPLoaderGGUF)
322349
}
323350
register_and_count(["ComfyUI-GGUF", "comfyui-gguf"], gguf_nodes)
324351

distorch.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,64 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v
462462
return out
463463

464464
return NodeOverrideDisTorch
465+
def override_class_with_distorch_clip_no_device(cls):
    """DisTorch wrapper for CLIP models with GGUF support.

    Variant of the DisTorch CLIP wrapper whose ``device`` selector lives in
    the node's *optional* inputs. When a device is chosen it becomes the
    current text-encoder device before the wrapped loader runs, and the
    resulting model's DisTorch virtual-VRAM allocation string is recorded
    for the patched model patcher to consume later.
    """
    # NOTE: the previous version did `from . import current_text_encoder_device`
    # here, but that name was never used; the setter is imported lazily inside
    # override() instead. The dead import has been removed.

    class NodeOverrideDisTorchClipNoDevice(cls):
        @classmethod
        def INPUT_TYPES(s):
            # Deep-copy so the wrapped class's input schema is never mutated.
            inputs = copy.deepcopy(cls.INPUT_TYPES())
            devices = get_device_list()
            # Default to the first accelerator when more than one device exists.
            default_device = devices[1] if len(devices) > 1 else devices[0]
            inputs["optional"] = inputs.get("optional", {})
            inputs["optional"]["device"] = (devices, {"default": default_device})
            inputs["optional"]["virtual_vram_gb"] = ("FLOAT", {"default": 4.0, "min": 0.0, "max": 24.0, "step": 0.1})
            inputs["optional"]["use_other_vram"] = ("BOOLEAN", {"default": False})
            inputs["optional"]["expert_mode_allocations"] = ("STRING", {
                "multiline": False,
                "default": "",
                "tooltip": "Expert use only: Manual VRAM allocation string. Incorrect values can cause crashes. Do not modify unless you fully understand DisTorch memory management."
            })
            return inputs

        CATEGORY = "multigpu"
        FUNCTION = "override"

        def override(self, *args, device=None, expert_mode_allocations=None, use_other_vram=None, virtual_vram_gb=0.0, **kwargs):
            # Imported lazily to avoid a circular import at module load time.
            from . import set_current_text_encoder_device
            if device is not None:
                set_current_text_encoder_device(device)

            register_patched_ggufmodelpatcher()
            fn = getattr(super(), cls.FUNCTION)
            out = fn(*args, **kwargs)

            # Build the "<device>;<gb>;<donor,...>" virtual-VRAM spec.
            vram_string = ""
            if virtual_vram_gb > 0:
                if use_other_vram:
                    # Donate VRAM from every other accelerator (sorted by
                    # ascending device index), then finally the CPU.
                    available_devices = [d for d in get_device_list() if d != "cpu"]
                    other_devices = [d for d in available_devices if d != device]
                    other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
                    device_string = ','.join(other_devices + ['cpu'])
                    vram_string = f"{device};{virtual_vram_gb};{device_string}"
                else:
                    vram_string = f"{device};{virtual_vram_gb};cpu"

            full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""

            logging.info(f"[MultiGPU_DisTorch] Full allocation string: {full_allocation}")

            # Key the allocation by a hash of the loaded model so the patched
            # GGUF model patcher can look it up when the model is dispatched.
            if hasattr(out[0], 'model'):
                model_hash = create_model_hash(out[0], "override")
                model_allocation_store[model_hash] = full_allocation
            elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
                model_hash = create_model_hash(out[0].patcher, "override")
                model_allocation_store[model_hash] = full_allocation

            return out

    return NodeOverrideDisTorchClipNoDevice
466523

467524
# Alias for backward compatibility
468525
override_class_with_distorch = override_class_with_distorch_gguf

distorch_2.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,99 @@ def override(self, *args, device=None, virtual_vram_gb=4.0, # Changed from comp
772772
return out
773773

774774
return NodeOverrideDisTorchSafetensorV2Clip
775+
776+
def override_class_with_distorch_safetensor_v2_clip_no_device(cls):
    """DisTorch 2.0 wrapper for safetensor CLIP models.

    Variant whose ``device`` selector is an *optional* node input. When a
    device is supplied it becomes the current text-encoder device before
    the wrapped loader runs; the DisTorch virtual-VRAM allocation string
    and a hash of the node settings are then recorded against the loaded
    model so the patched safetensor model patcher can retrieve them.
    """
    # NOTE: the previous version did `from . import current_device` here, but
    # that name was never used anywhere in this factory; removed as dead code.

    class NodeOverrideDisTorchSafetensorV2ClipNoDevice(cls):
        @classmethod
        def INPUT_TYPES(s):
            # Deep-copy so the wrapped class's input schema is never mutated.
            inputs = copy.deepcopy(cls.INPUT_TYPES())
            devices = get_device_list()
            default_device = devices[1] if len(devices) > 1 else devices[0]

            inputs["optional"] = inputs.get("optional", {})
            inputs["optional"]["device"] = (devices, {"default": default_device})
            inputs["optional"]["virtual_vram_gb"] = ("FLOAT", {"default": 4.0, "min": 0.0, "max": 128.0, "step": 0.1})
            inputs["optional"]["donor_device"] = (devices, {"default": "cpu"})
            inputs["optional"]["expert_mode_allocations"] = ("STRING", {"multiline": False, "default": ""})
            inputs["optional"]["high_precision_loras"] = ("BOOLEAN", {"default": True})
            return inputs

        CATEGORY = "multigpu/distorch_2"
        FUNCTION = "override"
        TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch2)"

        @classmethod
        def IS_CHANGED(s, *args, device=None, virtual_vram_gb=4.0,
                       donor_device="cpu", expert_mode_allocations="", high_precision_loras=True, **kwargs):
            # Hash all DisTorch-specific settings so ComfyUI re-executes the
            # node whenever any of them changes.
            settings_str = f"{device}{virtual_vram_gb}{donor_device}{expert_mode_allocations}{high_precision_loras}"
            return hashlib.sha256(settings_str.encode()).hexdigest()

        def override(self, *args, device=None, virtual_vram_gb=4.0,
                     donor_device="cpu", expert_mode_allocations="", high_precision_loras=True, **kwargs):

            # Imported lazily to avoid a circular import at module load time.
            from . import set_current_text_encoder_device
            if device is not None:
                set_current_text_encoder_device(device)

            # Register our patched ModelPatcher before loading.
            register_patched_safetensor_modelpatcher()

            # Resolve the wrapped loader's entry point.
            fn = getattr(super(), cls.FUNCTION)

            # --- Check if we need to unload the model due to settings change ---
            # This logic is partly redundant with IS_CHANGED, but provides
            # clear logging about cache hits/misses.
            # NOTE(review): unlike IS_CHANGED, this hash omits
            # high_precision_loras — confirm whether that is intentional.
            settings_str = f"{device}{virtual_vram_gb}{donor_device}{expert_mode_allocations}"
            settings_hash = hashlib.sha256(settings_str.encode()).hexdigest()

            # First call loads (or fetches the cached) model so we can hash it
            # without applying our patch; relies on the loader's own caching
            # to keep the second call below cheap.
            temp_out = fn(*args, **kwargs)
            model_to_check = None
            if hasattr(temp_out[0], 'model'):
                model_to_check = temp_out[0]
            elif hasattr(temp_out[0], 'patcher') and hasattr(temp_out[0].patcher, 'model'):
                model_to_check = temp_out[0].patcher

            if model_to_check:
                model_hash = create_safetensor_model_hash(model_to_check, "override_check")
                last_settings_hash = safetensor_settings_store.get(model_hash)

                if last_settings_hash != settings_hash:
                    logger.info(f"[MultiGPU_DisTorch2] Settings changed for model {model_hash[:8]}. Previous settings hash: {last_settings_hash}, New settings hash: {settings_hash}. Forcing reload.")
                else:
                    logger.info(f"[MultiGPU_DisTorch2] Settings unchanged for model {model_hash[:8]}. Using cached model.")

            out = fn(*args, **kwargs)

            # Store high_precision_loras on the model for later retrieval.
            if hasattr(out[0], 'model'):
                out[0].model._distorch_high_precision_loras = high_precision_loras
            elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
                out[0].patcher.model._distorch_high_precision_loras = high_precision_loras

            # Build the "<device>;<gb>;<donor>" virtual-VRAM spec.
            vram_string = ""
            if virtual_vram_gb > 0:
                vram_string = f"{device};{virtual_vram_gb};{donor_device}"
            elif expert_mode_allocations:  # Only include device if there's an expert string
                vram_string = device

            full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""

            logger.info(f"[MultiGPU_DisTorch2] Full allocation string: {full_allocation}")

            # Record allocation + settings hash keyed by the model's hash.
            if hasattr(out[0], 'model'):
                model_hash = create_safetensor_model_hash(out[0], "override")
                safetensor_allocation_store[model_hash] = full_allocation
                safetensor_settings_store[model_hash] = settings_hash
            elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
                model_hash = create_safetensor_model_hash(out[0].patcher, "override")
                safetensor_allocation_store[model_hash] = full_allocation
                safetensor_settings_store[model_hash] = settings_hash

            return out

    return NodeOverrideDisTorchSafetensorV2ClipNoDevice

0 commit comments

Comments
 (0)