@@ -92,6 +92,31 @@ def override(self, *args, device=None, **kwargs):
9292
9393 return NodeOverride
9494
95+ def override_class_clip_no_device (cls ):
96+ class NodeOverride (cls ):
97+ @classmethod
98+ def INPUT_TYPES (s ):
99+ inputs = copy .deepcopy (cls .INPUT_TYPES ())
100+ devices = get_device_list ()
101+ default_device = devices [1 ] if len (devices ) > 1 else devices [0 ]
102+ inputs ["optional" ] = inputs .get ("optional" , {})
103+ inputs ["optional" ]["device" ] = (devices , {"default" : default_device })
104+ return inputs
105+
106+ CATEGORY = "multigpu"
107+ FUNCTION = "override"
108+
109+ def override (self , * args , device = None , ** kwargs ):
110+ if device is not None :
111+ set_current_text_encoder_device (device )
112+ fn = getattr (super (), cls .FUNCTION )
113+ out = fn (* args , ** kwargs )
114+
115+ return out
116+
117+ return NodeOverride
118+
119+
95120def get_torch_device_patched ():
96121 device = None
97122 if (not is_accelerator_available () or mm .cpu_state == mm .CPUState .CPU or "cpu" in str (current_device ).lower ()):
@@ -183,6 +208,7 @@ def check_module_exists(module_path):
183208 override_class_with_distorch_gguf ,
184209 override_class_with_distorch_gguf_v2 ,
185210 override_class_with_distorch_clip ,
211+ override_class_with_distorch_clip_no_device ,
186212 override_class_with_distorch
187213)
188214
@@ -194,7 +220,8 @@ def check_module_exists(module_path):
194220 analyze_safetensor_loading ,
195221 calculate_safetensor_vvram_allocation ,
196222 override_class_with_distorch_safetensor_v2 ,
197- override_class_with_distorch_safetensor_v2_clip
223+ override_class_with_distorch_safetensor_v2_clip ,
224+ override_class_with_distorch_safetensor_v2_clip_no_device
198225)
199226
200227# Import advanced checkpoint loaders
@@ -217,10 +244,10 @@ def check_module_exists(module_path):
217244NODE_CLASS_MAPPINGS ["CLIPLoaderMultiGPU" ] = override_class_clip (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPLoader" ])
218245NODE_CLASS_MAPPINGS ["DualCLIPLoaderMultiGPU" ] = override_class_clip (GLOBAL_NODE_CLASS_MAPPINGS ["DualCLIPLoader" ])
219246if "TripleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
220- NODE_CLASS_MAPPINGS ["TripleCLIPLoaderMultiGPU" ] = override_class_clip (GLOBAL_NODE_CLASS_MAPPINGS ["TripleCLIPLoader" ])
247+ NODE_CLASS_MAPPINGS ["TripleCLIPLoaderMultiGPU" ] = override_class_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["TripleCLIPLoader" ])
221248if "QuadrupleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
222- NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoaderMultiGPU" ] = override_class_clip (GLOBAL_NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoader" ])
223- NODE_CLASS_MAPPINGS ["CLIPVisionLoaderMultiGPU" ] = override_class_clip (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPVisionLoader" ])
249+ NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoaderMultiGPU" ] = override_class_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoader" ])
250+ NODE_CLASS_MAPPINGS ["CLIPVisionLoaderMultiGPU" ] = override_class_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPVisionLoader" ])
224251NODE_CLASS_MAPPINGS ["CheckpointLoaderSimpleMultiGPU" ] = override_class (GLOBAL_NODE_CLASS_MAPPINGS ["CheckpointLoaderSimple" ])
225252NODE_CLASS_MAPPINGS ["ControlNetLoaderMultiGPU" ] = override_class (GLOBAL_NODE_CLASS_MAPPINGS ["ControlNetLoader" ])
226253if "DiffusersLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
@@ -234,10 +261,10 @@ def check_module_exists(module_path):
234261NODE_CLASS_MAPPINGS ["CLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPLoader" ])
235262NODE_CLASS_MAPPINGS ["DualCLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip (GLOBAL_NODE_CLASS_MAPPINGS ["DualCLIPLoader" ])
236263if "TripleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
237- NODE_CLASS_MAPPINGS ["TripleCLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip (GLOBAL_NODE_CLASS_MAPPINGS ["TripleCLIPLoader" ])
264+ NODE_CLASS_MAPPINGS ["TripleCLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["TripleCLIPLoader" ])
238265if "QuadrupleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
239- NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip (GLOBAL_NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoader" ])
240- NODE_CLASS_MAPPINGS ["CLIPVisionLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPVisionLoader" ])
266+ NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["QuadrupleCLIPLoader" ])
267+ NODE_CLASS_MAPPINGS ["CLIPVisionLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2_clip_no_device (GLOBAL_NODE_CLASS_MAPPINGS ["CLIPVisionLoader" ])
241268NODE_CLASS_MAPPINGS ["CheckpointLoaderSimpleDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2 (GLOBAL_NODE_CLASS_MAPPINGS ["CheckpointLoaderSimple" ])
242269NODE_CLASS_MAPPINGS ["ControlNetLoaderDisTorch2MultiGPU" ] = override_class_with_distorch_safetensor_v2 (GLOBAL_NODE_CLASS_MAPPINGS ["ControlNetLoader" ])
243270if "DiffusersLoader" in GLOBAL_NODE_CLASS_MAPPINGS :
@@ -305,20 +332,20 @@ def register_and_count(module_names, node_map):
305332 "UnetLoaderGGUFAdvancedDisTorchMultiGPU" : override_class_with_distorch_gguf (UnetLoaderGGUFAdvanced ),
306333 "CLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip (CLIPLoaderGGUF ),
307334 "DualCLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip (DualCLIPLoaderGGUF ),
308- "TripleCLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip (TripleCLIPLoaderGGUF ),
309- "QuadrupleCLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip (QuadrupleCLIPLoaderGGUF ),
335+ "TripleCLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip_no_device (TripleCLIPLoaderGGUF ),
336+ "QuadrupleCLIPLoaderGGUFDisTorchMultiGPU" : override_class_with_distorch_clip_no_device (QuadrupleCLIPLoaderGGUF ),
310337 "UnetLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2 (UnetLoaderGGUF ),
311338 "UnetLoaderGGUFAdvancedDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2 (UnetLoaderGGUFAdvanced ),
312339 "CLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip (CLIPLoaderGGUF ),
313340 "DualCLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip (DualCLIPLoaderGGUF ),
314- "TripleCLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip (TripleCLIPLoaderGGUF ),
315- "QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip (QuadrupleCLIPLoaderGGUF ),
341+ "TripleCLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip_no_device (TripleCLIPLoaderGGUF ),
342+ "QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU" : override_class_with_distorch_safetensor_v2_clip_no_device (QuadrupleCLIPLoaderGGUF ),
316343 "UnetLoaderGGUFMultiGPU" : override_class (UnetLoaderGGUF ),
317344 "UnetLoaderGGUFAdvancedMultiGPU" : override_class (UnetLoaderGGUFAdvanced ),
318345 "CLIPLoaderGGUFMultiGPU" : override_class_clip (CLIPLoaderGGUF ),
319346 "DualCLIPLoaderGGUFMultiGPU" : override_class_clip (DualCLIPLoaderGGUF ),
320- "TripleCLIPLoaderGGUFMultiGPU" : override_class_clip (TripleCLIPLoaderGGUF ),
321- "QuadrupleCLIPLoaderGGUFMultiGPU" : override_class_clip (QuadrupleCLIPLoaderGGUF )
347+ "TripleCLIPLoaderGGUFMultiGPU" : override_class_clip_no_device (TripleCLIPLoaderGGUF ),
348+ "QuadrupleCLIPLoaderGGUFMultiGPU" : override_class_clip_no_device (QuadrupleCLIPLoaderGGUF )
322349}
323350register_and_count (["ComfyUI-GGUF" , "comfyui-gguf" ], gguf_nodes )
324351
0 commit comments