@@ -59,6 +59,10 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
5959
6060 device_value = kwargs .get (device_param_name )
6161
62+ # Capture the current device at runtime so we can restore it later
63+ from . import get_current_device , get_current_text_encoder_device
64+ original_device = get_current_device () if device_param_name == "compute_device" else get_current_text_encoder_device ()
65+
6266 import comfy .model_management as mm
6367
6468 if eject_models :
@@ -118,7 +122,11 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
118122
119123 logger .info (f"[MultiGPU DisTorch V2] Full allocation string: { full_allocation } " )
120124
121- return out
125+ try :
126+ return out
127+ finally :
128+ # Restore the device that was in use when the override started
129+ device_setter_func (original_device )
122130
123131 return NodeOverrideDisTorchSafetensorV2
124132
@@ -164,7 +172,7 @@ def override_class_with_distorch_safetensor_v2_clip_no_device(cls):
164172
165173def override_class_with_distorch_gguf (cls ):
166174 """DisTorch V1 Legacy wrapper - maintains V1 UI but calls V2 backend"""
167- from . import set_current_device
175+ from . import set_current_device , get_current_device
168176 from .distorch_2 import register_patched_safetensor_modelpatcher
169177
170178 class NodeOverrideDisTorchGGUFLegacy (cls ):
@@ -185,6 +193,8 @@ def INPUT_TYPES(s):
185193 TITLE = f"{ cls .TITLE if hasattr (cls , 'TITLE' ) else cls .__name__ } (Legacy)"
186194
187195 def override (self , * args , device = None , expert_mode_allocations = "" , use_other_vram = False , virtual_vram_gb = 0.0 , ** kwargs ):
196+ # Capture and restore the current device to avoid leaking global state
197+ original_device = get_current_device ()
188198 if device is not None :
189199 set_current_device (device )
190200
@@ -220,15 +230,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
220230 if model_to_check and full_allocation :
221231 inner_model = model_to_check .model
222232 inner_model ._distorch_v2_meta = {"full_allocation" : full_allocation }
223-
224- return out
233+ try :
234+ return out
235+ finally :
236+ set_current_device (original_device )
225237
226238 return NodeOverrideDisTorchGGUFLegacy
227239
228240
229241def override_class_with_distorch_gguf_v2 (cls ):
230242 """DisTorch V2 wrapper for GGUF models"""
231- from . import set_current_device
243+ from . import set_current_device , get_current_device
232244 from .distorch_2 import register_patched_safetensor_modelpatcher
233245
234246 class NodeOverrideDisTorchGGUFv2 (cls ):
@@ -250,6 +262,7 @@ def INPUT_TYPES(s):
250262 TITLE = f"{ cls .TITLE if hasattr (cls , 'TITLE' ) else cls .__name__ } (DisTorch2)"
251263
252264 def override (self , * args , compute_device = None , virtual_vram_gb = 4.0 , donor_device = "cpu" , expert_mode_allocations = "" , ** kwargs ):
265+ original_device = get_current_device ()
253266 if compute_device is not None :
254267 set_current_device (compute_device )
255268
@@ -282,15 +295,17 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device
282295 if model_to_check and full_allocation :
283296 inner_model = model_to_check .model
284297 inner_model ._distorch_v2_meta = {"full_allocation" : full_allocation }
285-
286- return out
298+ try :
299+ return out
300+ finally :
301+ set_current_device (original_device )
287302
288303 return NodeOverrideDisTorchGGUFv2
289304
290305
291306def override_class_with_distorch_clip (cls ):
292307 """DisTorch V1 wrapper for CLIP models - calls V2 backend"""
293- from . import set_current_text_encoder_device
308+ from . import set_current_text_encoder_device , get_current_text_encoder_device
294309 from .distorch_2 import register_patched_safetensor_modelpatcher
295310
296311 class NodeOverrideDisTorchClip (cls ):
@@ -311,6 +326,7 @@ def INPUT_TYPES(s):
311326 TITLE = f"{ cls .TITLE if hasattr (cls , 'TITLE' ) else cls .__name__ } (DisTorch)"
312327
313328 def override (self , * args , device = None , expert_mode_allocations = "" , use_other_vram = False , virtual_vram_gb = 0.0 , ** kwargs ):
329+ original_text_device = get_current_text_encoder_device ()
314330 if device is not None :
315331 set_current_text_encoder_device (device )
316332
@@ -346,15 +362,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
346362 if model_to_check and full_allocation :
347363 inner_model = model_to_check .model
348364 inner_model ._distorch_v2_meta = {"full_allocation" : full_allocation }
349-
350- return out
365+ try :
366+ return out
367+ finally :
368+ set_current_text_encoder_device (original_text_device )
351369
352370 return NodeOverrideDisTorchClip
353371
354372
355373def override_class_with_distorch_clip_no_device (cls ):
356374 """DisTorch V1 wrapper for Triple/Quad CLIP models - calls V2 backend"""
357- from . import set_current_text_encoder_device
375+ from . import set_current_text_encoder_device , get_current_text_encoder_device
358376 from .distorch_2 import register_patched_safetensor_modelpatcher
359377
360378 class NodeOverrideDisTorchClipNoDevice (cls ):
@@ -375,6 +393,7 @@ def INPUT_TYPES(s):
375393 TITLE = f"{ cls .TITLE if hasattr (cls , 'TITLE' ) else cls .__name__ } (DisTorch)"
376394
377395 def override (self , * args , device = None , expert_mode_allocations = "" , use_other_vram = False , virtual_vram_gb = 0.0 , ** kwargs ):
396+ original_text_device = get_current_text_encoder_device ()
378397 if device is not None :
379398 set_current_text_encoder_device (device )
380399
@@ -410,8 +429,10 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
410429 if model_to_check and full_allocation :
411430 inner_model = model_to_check .model
412431 inner_model ._distorch_v2_meta = {"full_allocation" : full_allocation }
413-
414- return out
432+ try :
433+ return out
434+ finally :
435+ set_current_text_encoder_device (original_text_device )
415436
416437 return NodeOverrideDisTorchClipNoDevice
417438
@@ -426,7 +447,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
426447
427448def override_class (cls ):
428449 """Standard MultiGPU device override for UNet/VAE models"""
429- from . import set_current_device
450+ from . import set_current_device , get_current_device
430451
431452 class NodeOverride (cls ):
432453 @classmethod
@@ -442,17 +463,21 @@ def INPUT_TYPES(s):
442463 FUNCTION = "override"
443464
444465 def override (self , * args , device = None , ** kwargs ):
466+ original_device = get_current_device ()
445467 if device is not None :
446468 set_current_device (device )
447469 fn = getattr (super (), cls .FUNCTION )
448470 out = fn (* args , ** kwargs )
449- return out
471+ try :
472+ return out
473+ finally :
474+ set_current_device (original_device )
450475
451476 return NodeOverride
452477
453478def override_class_offload (cls ):
454479 """Standard MultiGPU device override for UNet/VAE models"""
455- from . import set_current_device , set_current_unet_offload_device
480+ from . import set_current_device , set_current_unet_offload_device , get_current_device , get_current_unet_offload_device
456481
457482 class NodeOverride (cls ):
458483 @classmethod
@@ -469,21 +494,27 @@ def INPUT_TYPES(s):
469494 FUNCTION = "override"
470495
471496 def override (self , * args , device = None , offload_device = None , ** kwargs ):
497+ original_device = get_current_device ()
498+ original_offload_device = get_current_unet_offload_device ()
472499 if device is not None :
473500 set_current_device (device )
474501 if offload_device is not None :
475502 set_current_unet_offload_device (offload_device )
476503 fn = getattr (super (), cls .FUNCTION )
477504 out = fn (* args , ** kwargs )
478- return out
505+ try :
506+ return out
507+ finally :
508+ set_current_device (original_device )
509+ set_current_unet_offload_device (original_offload_device )
479510
480511 return NodeOverride
481512
482513
483514
484515def override_class_clip (cls ):
485516 """Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""
486- from . import set_current_text_encoder_device
517+ from . import set_current_text_encoder_device , get_current_text_encoder_device
487518
488519 class NodeOverride (cls ):
489520 @classmethod
@@ -499,19 +530,23 @@ def INPUT_TYPES(s):
499530 FUNCTION = "override"
500531
501532 def override (self , * args , device = None , ** kwargs ):
533+ original_text_device = get_current_text_encoder_device ()
502534 if device is not None :
503535 set_current_text_encoder_device (device )
504536 kwargs ['device' ] = 'default'
505537 fn = getattr (super (), cls .FUNCTION )
506538 out = fn (* args , ** kwargs )
507- return out
539+ try :
540+ return out
541+ finally :
542+ set_current_text_encoder_device (original_text_device )
508543
509544 return NodeOverride
510545
511546
512547def override_class_clip_no_device (cls ):
513548 """Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)"""
514- from . import set_current_text_encoder_device
549+ from . import set_current_text_encoder_device , get_current_text_encoder_device
515550
516551 class NodeOverride (cls ):
517552 @classmethod
@@ -527,10 +562,14 @@ def INPUT_TYPES(s):
527562 FUNCTION = "override"
528563
529564 def override (self , * args , device = None , ** kwargs ):
565+ original_text_device = get_current_text_encoder_device ()
530566 if device is not None :
531567 set_current_text_encoder_device (device )
532568 fn = getattr (super (), cls .FUNCTION )
533569 out = fn (* args , ** kwargs )
534- return out
570+ try :
571+ return out
572+ finally :
573+ set_current_text_encoder_device (original_text_device )
535574
536575 return NodeOverride
0 commit comments