Skip to content

Commit 7a3ca1c

Browse files
committed
fix: restore global device state in wrapper overrides and checkpoint loader
Add runtime getters in __init__.py: get_current_device(), get_current_text_encoder_device(), and get_current_unet_offload_device(). Update wrappers.py to follow the wanvideo.py pattern: each override that sets a device now captures the original device at runtime via the appropriate getter, performs the existing override logic unchanged, and returns inside a try ... finally block that restores the original device in the finally clause. This applies to the DisTorch V2 factory override, the GGUF legacy/V2 overrides, the CLIP overrides, and the standard UNet/VAE wrappers and their offload variants. Fix checkpoint_multigpu.py to use the getters (instead of reading globals directly) when capturing original devices before modifying them, and to restore them in the existing finally block. Rationale: this prevents MultiGPU nodes from leaving the ComfyUI global device context "stuck" on a non-default device. The changes are minimal and localized; no public API or functional behavior is altered except the guaranteed restoration of global device state after node execution.
1 parent 0a0716b commit 7a3ca1c

3 files changed

Lines changed: 79 additions & 24 deletions

File tree

__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def set_current_unet_offload_device(device):
168168
current_unet_offload_device = device
169169
logger.debug(f"[MultiGPU Initialization] current_unet_offload_device set to: {device}")
170170

171+
172+
def get_current_device():
    """Get the current device context for MultiGPU operations at runtime.

    Returns the module-level ``current_device`` value so callers capture the
    live device context at call time rather than a stale import-time copy.
    """
    return current_device
175+
176+
177+
def get_current_text_encoder_device():
    """Get the current text encoder device context for CLIP models at runtime.

    Returns the module-level ``current_text_encoder_device`` value, read at
    call time so wrappers can snapshot it before overriding and restore it
    afterwards.
    """
    return current_text_encoder_device
180+
181+
182+
def get_current_unet_offload_device():
    """Get the current UNet offload device context at runtime.

    Returns the module-level ``current_unet_offload_device`` value, read at
    call time (not at import time) so restoration logic sees the live state.
    """
    return current_unet_offload_device
185+
171186
def get_torch_device_patched():
172187
"""Return MultiGPU-aware device selection for patched mm.get_torch_device."""
173188
device = None

checkpoint_multigpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
3434
embedding_directory=None, output_model=True, model_options={},
3535
te_model_options={}, metadata=None):
3636
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
37-
from . import set_current_device, set_current_text_encoder_device, current_device, current_text_encoder_device
37+
from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
3838

3939
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
4040
config_hash = str(sd_size)
@@ -54,8 +54,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5454
model = None
5555
model_patcher = None
5656

57-
original_main_device = current_device
58-
original_clip_device = current_text_encoder_device
57+
# Capture the current devices at runtime so we can restore them after loading
58+
original_main_device = get_current_device()
59+
original_clip_device = get_current_text_encoder_device()
5960

6061
try:
6162
diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)

wrappers.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
5959

6060
device_value = kwargs.get(device_param_name)
6161

62+
# Capture the current device at runtime so we can restore it later
63+
from . import get_current_device, get_current_text_encoder_device
64+
original_device = get_current_device() if device_param_name == "compute_device" else get_current_text_encoder_device()
65+
6266
import comfy.model_management as mm
6367

6468
if eject_models:
@@ -118,7 +122,11 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
118122

119123
logger.info(f"[MultiGPU DisTorch V2] Full allocation string: {full_allocation}")
120124

121-
return out
125+
try:
126+
return out
127+
finally:
128+
# Restore the device that was in use when the override started
129+
device_setter_func(original_device)
122130

123131
return NodeOverrideDisTorchSafetensorV2
124132

@@ -164,7 +172,7 @@ def override_class_with_distorch_safetensor_v2_clip_no_device(cls):
164172

165173
def override_class_with_distorch_gguf(cls):
166174
"""DisTorch V1 Legacy wrapper - maintains V1 UI but calls V2 backend"""
167-
from . import set_current_device
175+
from . import set_current_device, get_current_device
168176
from .distorch_2 import register_patched_safetensor_modelpatcher
169177

170178
class NodeOverrideDisTorchGGUFLegacy(cls):
@@ -185,6 +193,8 @@ def INPUT_TYPES(s):
185193
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (Legacy)"
186194

187195
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
196+
# Capture and restore the current device to avoid leaking global state
197+
original_device = get_current_device()
188198
if device is not None:
189199
set_current_device(device)
190200

@@ -220,15 +230,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
220230
if model_to_check and full_allocation:
221231
inner_model = model_to_check.model
222232
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
223-
224-
return out
233+
try:
234+
return out
235+
finally:
236+
set_current_device(original_device)
225237

226238
return NodeOverrideDisTorchGGUFLegacy
227239

228240

229241
def override_class_with_distorch_gguf_v2(cls):
230242
"""DisTorch V2 wrapper for GGUF models"""
231-
from . import set_current_device
243+
from . import set_current_device, get_current_device
232244
from .distorch_2 import register_patched_safetensor_modelpatcher
233245

234246
class NodeOverrideDisTorchGGUFv2(cls):
@@ -250,6 +262,7 @@ def INPUT_TYPES(s):
250262
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch2)"
251263

252264
def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device="cpu", expert_mode_allocations="", **kwargs):
265+
original_device = get_current_device()
253266
if compute_device is not None:
254267
set_current_device(compute_device)
255268

@@ -282,15 +295,17 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device
282295
if model_to_check and full_allocation:
283296
inner_model = model_to_check.model
284297
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
285-
286-
return out
298+
try:
299+
return out
300+
finally:
301+
set_current_device(original_device)
287302

288303
return NodeOverrideDisTorchGGUFv2
289304

290305

291306
def override_class_with_distorch_clip(cls):
292307
"""DisTorch V1 wrapper for CLIP models - calls V2 backend"""
293-
from . import set_current_text_encoder_device
308+
from . import set_current_text_encoder_device, get_current_text_encoder_device
294309
from .distorch_2 import register_patched_safetensor_modelpatcher
295310

296311
class NodeOverrideDisTorchClip(cls):
@@ -311,6 +326,7 @@ def INPUT_TYPES(s):
311326
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch)"
312327

313328
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
329+
original_text_device = get_current_text_encoder_device()
314330
if device is not None:
315331
set_current_text_encoder_device(device)
316332

@@ -346,15 +362,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
346362
if model_to_check and full_allocation:
347363
inner_model = model_to_check.model
348364
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
349-
350-
return out
365+
try:
366+
return out
367+
finally:
368+
set_current_text_encoder_device(original_text_device)
351369

352370
return NodeOverrideDisTorchClip
353371

354372

355373
def override_class_with_distorch_clip_no_device(cls):
356374
"""DisTorch V1 wrapper for Triple/Quad CLIP models - calls V2 backend"""
357-
from . import set_current_text_encoder_device
375+
from . import set_current_text_encoder_device, get_current_text_encoder_device
358376
from .distorch_2 import register_patched_safetensor_modelpatcher
359377

360378
class NodeOverrideDisTorchClipNoDevice(cls):
@@ -375,6 +393,7 @@ def INPUT_TYPES(s):
375393
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch)"
376394

377395
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
396+
original_text_device = get_current_text_encoder_device()
378397
if device is not None:
379398
set_current_text_encoder_device(device)
380399

@@ -410,8 +429,10 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
410429
if model_to_check and full_allocation:
411430
inner_model = model_to_check.model
412431
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
413-
414-
return out
432+
try:
433+
return out
434+
finally:
435+
set_current_text_encoder_device(original_text_device)
415436

416437
return NodeOverrideDisTorchClipNoDevice
417438

@@ -426,7 +447,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
426447

427448
def override_class(cls):
428449
"""Standard MultiGPU device override for UNet/VAE models"""
429-
from . import set_current_device
450+
from . import set_current_device, get_current_device
430451

431452
class NodeOverride(cls):
432453
@classmethod
@@ -442,17 +463,21 @@ def INPUT_TYPES(s):
442463
FUNCTION = "override"
443464

444465
def override(self, *args, device=None, **kwargs):
    """Run the wrapped node function with the global MultiGPU device
    temporarily overridden to *device*, restoring the previous device
    afterwards.

    Args:
        device: Target device name, or None to leave the context unchanged.
        *args, **kwargs: Forwarded verbatim to the wrapped node function.

    Returns:
        Whatever the wrapped node function returns.

    Note: the ``try`` must enclose both the device mutation and the wrapped
    call — if only the ``return`` were inside it, an exception raised by the
    wrapped function would skip ``finally`` and leave the global device
    context stuck on the override.
    """
    # Snapshot the device in effect before this node runs.
    original_device = get_current_device()
    try:
        if device is not None:
            set_current_device(device)
        fn = getattr(super(), cls.FUNCTION)
        return fn(*args, **kwargs)
    finally:
        # Always restore, even when the wrapped call raises.
        set_current_device(original_device)
450475

451476
return NodeOverride
452477

453478
def override_class_offload(cls):
454479
"""Standard MultiGPU device override for UNet/VAE models"""
455-
from . import set_current_device, set_current_unet_offload_device
480+
from . import set_current_device, set_current_unet_offload_device, get_current_device, get_current_unet_offload_device
456481

457482
class NodeOverride(cls):
458483
@classmethod
@@ -469,21 +494,27 @@ def INPUT_TYPES(s):
469494
FUNCTION = "override"
470495

471496
def override(self, *args, device=None, offload_device=None, **kwargs):
    """Run the wrapped node function with the global compute and UNet
    offload devices temporarily overridden, restoring both afterwards.

    Args:
        device: Target compute device, or None to leave it unchanged.
        offload_device: Target UNet offload device, or None to leave it
            unchanged.
        *args, **kwargs: Forwarded verbatim to the wrapped node function.

    Returns:
        Whatever the wrapped node function returns.

    Note: the ``try`` must enclose the device mutations and the wrapped
    call; otherwise an exception from the wrapped function would bypass
    ``finally`` and leak the overridden devices into global state.
    """
    # Snapshot both device contexts before mutating them.
    original_device = get_current_device()
    original_offload_device = get_current_unet_offload_device()
    try:
        if device is not None:
            set_current_device(device)
        if offload_device is not None:
            set_current_unet_offload_device(offload_device)
        fn = getattr(super(), cls.FUNCTION)
        return fn(*args, **kwargs)
    finally:
        # Always restore, even when the wrapped call raises.
        set_current_device(original_device)
        set_current_unet_offload_device(original_offload_device)
479510

480511
return NodeOverride
481512

482513

483514

484515
def override_class_clip(cls):
485516
"""Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""
486-
from . import set_current_text_encoder_device
517+
from . import set_current_text_encoder_device, get_current_text_encoder_device
487518

488519
class NodeOverride(cls):
489520
@classmethod
@@ -499,19 +530,23 @@ def INPUT_TYPES(s):
499530
FUNCTION = "override"
500531

501532
def override(self, *args, device=None, **kwargs):
    """Run the wrapped CLIP loader with the global text-encoder device
    temporarily overridden, restoring the previous device afterwards.

    Args:
        device: Target text-encoder device, or None to leave it unchanged.
        *args, **kwargs: Forwarded to the wrapped loader; ``kwargs['device']``
            is forced to ``'default'`` (see below).

    Returns:
        Whatever the wrapped loader returns.

    Note: the ``try`` must enclose the device mutation and the wrapped call;
    otherwise an exception from the loader would skip ``finally`` and leave
    the global text-encoder device stuck on the override.
    """
    # Snapshot the text-encoder device in effect before this node runs.
    original_text_device = get_current_text_encoder_device()
    try:
        if device is not None:
            set_current_text_encoder_device(device)
        # Workaround: force the wrapped loader's own device kwarg to
        # 'default' so MultiGPU's patched device selection takes effect
        # instead of the loader's explicit placement.
        kwargs['device'] = 'default'
        fn = getattr(super(), cls.FUNCTION)
        return fn(*args, **kwargs)
    finally:
        # Always restore, even when the wrapped call raises.
        set_current_text_encoder_device(original_text_device)
508543

509544
return NodeOverride
510545

511546

512547
def override_class_clip_no_device(cls):
513548
"""Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)"""
514-
from . import set_current_text_encoder_device
549+
from . import set_current_text_encoder_device, get_current_text_encoder_device
515550

516551
class NodeOverride(cls):
517552
@classmethod
@@ -527,10 +562,14 @@ def INPUT_TYPES(s):
527562
FUNCTION = "override"
528563

529564
def override(self, *args, device=None, **kwargs):
    """Run the wrapped Triple/Quad CLIP loader with the global text-encoder
    device temporarily overridden, restoring the previous device afterwards.

    Args:
        device: Target text-encoder device, or None to leave it unchanged.
        *args, **kwargs: Forwarded verbatim to the wrapped loader (this
            variant has no device-kwarg workaround).

    Returns:
        Whatever the wrapped loader returns.

    Note: the ``try`` must enclose the device mutation and the wrapped call;
    otherwise an exception from the loader would skip ``finally`` and leave
    the global text-encoder device stuck on the override.
    """
    # Snapshot the text-encoder device in effect before this node runs.
    original_text_device = get_current_text_encoder_device()
    try:
        if device is not None:
            set_current_text_encoder_device(device)
        fn = getattr(super(), cls.FUNCTION)
        return fn(*args, **kwargs)
    finally:
        # Always restore, even when the wrapped call raises.
        set_current_text_encoder_device(original_text_device)
535574

536575
return NodeOverride

0 commit comments

Comments
 (0)