Skip to content

Commit 35e81e9

Browse files
authored
Merge pull request #131 from pollockjj/dev_restore
Restore global device state in MultiGPU operations
2 parents 0a0716b + 6e5225e commit 35e81e9

4 files changed

Lines changed: 80 additions & 25 deletions

File tree

__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ def set_current_unet_offload_device(device):
168168
current_unet_offload_device = device
169169
logger.debug(f"[MultiGPU Initialization] current_unet_offload_device set to: {device}")
170170

def get_current_device():
    """Return the device currently selected for MultiGPU compute operations."""
    return current_device


def get_current_text_encoder_device():
    """Return the device currently selected for CLIP/text-encoder models."""
    return current_text_encoder_device


def get_current_unet_offload_device():
    """Return the device currently selected for UNet offloading."""
    return current_unet_offload_device
171186
def get_torch_device_patched():
172187
"""Return MultiGPU-aware device selection for patched mm.get_torch_device."""
173188
device = None

checkpoint_multigpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
3434
embedding_directory=None, output_model=True, model_options={},
3535
te_model_options={}, metadata=None):
3636
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
37-
from . import set_current_device, set_current_text_encoder_device, current_device, current_text_encoder_device
37+
from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
3838

3939
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
4040
config_hash = str(sd_size)
@@ -54,8 +54,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
5454
model = None
5555
model_patcher = None
5656

57-
original_main_device = current_device
58-
original_clip_device = current_text_encoder_device
57+
# Capture the current devices at runtime so we can restore them after loading
58+
original_main_device = get_current_device()
59+
original_clip_device = get_current_text_encoder_device()
5960

6061
try:
6162
diffusion_model_prefix = comfy.model_detection.unet_prefix_from_state_dict(sd)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-multigpu"
33
description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes."
4-
version = "2.5.8"
4+
version = "2.5.9"
55
license = {file = "LICENSE"}
66

77
[project.urls]

wrappers.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
5959

6060
device_value = kwargs.get(device_param_name)
6161

62+
# Capture the current device at runtime so we can restore it later
63+
from . import get_current_device, get_current_text_encoder_device
64+
original_device = get_current_device() if device_param_name == "compute_device" else get_current_text_encoder_device()
65+
6266
import comfy.model_management as mm
6367

6468
if eject_models:
@@ -118,7 +122,11 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu",
118122

119123
logger.info(f"[MultiGPU DisTorch V2] Full allocation string: {full_allocation}")
120124

121-
return out
125+
try:
126+
return out
127+
finally:
128+
# Restore the device that was in use when the override started
129+
device_setter_func(original_device)
122130

123131
return NodeOverrideDisTorchSafetensorV2
124132

@@ -164,7 +172,7 @@ def override_class_with_distorch_safetensor_v2_clip_no_device(cls):
164172

165173
def override_class_with_distorch_gguf(cls):
166174
"""DisTorch V1 Legacy wrapper - maintains V1 UI but calls V2 backend"""
167-
from . import set_current_device
175+
from . import set_current_device, get_current_device
168176
from .distorch_2 import register_patched_safetensor_modelpatcher
169177

170178
class NodeOverrideDisTorchGGUFLegacy(cls):
@@ -185,6 +193,8 @@ def INPUT_TYPES(s):
185193
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (Legacy)"
186194

187195
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
196+
# Capture and restore the current device to avoid leaking global state
197+
original_device = get_current_device()
188198
if device is not None:
189199
set_current_device(device)
190200

@@ -220,15 +230,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
220230
if model_to_check and full_allocation:
221231
inner_model = model_to_check.model
222232
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
223-
224-
return out
233+
try:
234+
return out
235+
finally:
236+
set_current_device(original_device)
225237

226238
return NodeOverrideDisTorchGGUFLegacy
227239

228240

229241
def override_class_with_distorch_gguf_v2(cls):
230242
"""DisTorch V2 wrapper for GGUF models"""
231-
from . import set_current_device
243+
from . import set_current_device, get_current_device
232244
from .distorch_2 import register_patched_safetensor_modelpatcher
233245

234246
class NodeOverrideDisTorchGGUFv2(cls):
@@ -250,6 +262,7 @@ def INPUT_TYPES(s):
250262
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch2)"
251263

252264
def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device="cpu", expert_mode_allocations="", **kwargs):
265+
original_device = get_current_device()
253266
if compute_device is not None:
254267
set_current_device(compute_device)
255268

@@ -282,15 +295,17 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device
282295
if model_to_check and full_allocation:
283296
inner_model = model_to_check.model
284297
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
285-
286-
return out
298+
try:
299+
return out
300+
finally:
301+
set_current_device(original_device)
287302

288303
return NodeOverrideDisTorchGGUFv2
289304

290305

291306
def override_class_with_distorch_clip(cls):
292307
"""DisTorch V1 wrapper for CLIP models - calls V2 backend"""
293-
from . import set_current_text_encoder_device
308+
from . import set_current_text_encoder_device, get_current_text_encoder_device
294309
from .distorch_2 import register_patched_safetensor_modelpatcher
295310

296311
class NodeOverrideDisTorchClip(cls):
@@ -311,6 +326,7 @@ def INPUT_TYPES(s):
311326
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch)"
312327

313328
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
329+
original_text_device = get_current_text_encoder_device()
314330
if device is not None:
315331
set_current_text_encoder_device(device)
316332

@@ -346,15 +362,17 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
346362
if model_to_check and full_allocation:
347363
inner_model = model_to_check.model
348364
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
349-
350-
return out
365+
try:
366+
return out
367+
finally:
368+
set_current_text_encoder_device(original_text_device)
351369

352370
return NodeOverrideDisTorchClip
353371

354372

355373
def override_class_with_distorch_clip_no_device(cls):
356374
"""DisTorch V1 wrapper for Triple/Quad CLIP models - calls V2 backend"""
357-
from . import set_current_text_encoder_device
375+
from . import set_current_text_encoder_device, get_current_text_encoder_device
358376
from .distorch_2 import register_patched_safetensor_modelpatcher
359377

360378
class NodeOverrideDisTorchClipNoDevice(cls):
@@ -375,6 +393,7 @@ def INPUT_TYPES(s):
375393
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch)"
376394

377395
def override(self, *args, device=None, expert_mode_allocations="", use_other_vram=False, virtual_vram_gb=0.0, **kwargs):
396+
original_text_device = get_current_text_encoder_device()
378397
if device is not None:
379398
set_current_text_encoder_device(device)
380399

@@ -410,8 +429,10 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
410429
if model_to_check and full_allocation:
411430
inner_model = model_to_check.model
412431
inner_model._distorch_v2_meta = {"full_allocation": full_allocation}
413-
414-
return out
432+
try:
433+
return out
434+
finally:
435+
set_current_text_encoder_device(original_text_device)
415436

416437
return NodeOverrideDisTorchClipNoDevice
417438

@@ -426,7 +447,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra
426447

427448
def override_class(cls):
428449
"""Standard MultiGPU device override for UNet/VAE models"""
429-
from . import set_current_device
450+
from . import set_current_device, get_current_device
430451

431452
class NodeOverride(cls):
432453
@classmethod
@@ -442,17 +463,21 @@ def INPUT_TYPES(s):
442463
FUNCTION = "override"
443464

444465
def override(self, *args, device=None, **kwargs):
466+
original_device = get_current_device()
445467
if device is not None:
446468
set_current_device(device)
447469
fn = getattr(super(), cls.FUNCTION)
448470
out = fn(*args, **kwargs)
449-
return out
471+
try:
472+
return out
473+
finally:
474+
set_current_device(original_device)
450475

451476
return NodeOverride
452477

453478
def override_class_offload(cls):
454479
"""Standard MultiGPU device override for UNet/VAE models"""
455-
from . import set_current_device, set_current_unet_offload_device
480+
from . import set_current_device, set_current_unet_offload_device, get_current_device, get_current_unet_offload_device
456481

457482
class NodeOverride(cls):
458483
@classmethod
@@ -469,21 +494,27 @@ def INPUT_TYPES(s):
469494
FUNCTION = "override"
470495

471496
def override(self, *args, device=None, offload_device=None, **kwargs):
497+
original_device = get_current_device()
498+
original_offload_device = get_current_unet_offload_device()
472499
if device is not None:
473500
set_current_device(device)
474501
if offload_device is not None:
475502
set_current_unet_offload_device(offload_device)
476503
fn = getattr(super(), cls.FUNCTION)
477504
out = fn(*args, **kwargs)
478-
return out
505+
try:
506+
return out
507+
finally:
508+
set_current_device(original_device)
509+
set_current_unet_offload_device(original_offload_device)
479510

480511
return NodeOverride
481512

482513

483514

484515
def override_class_clip(cls):
485516
"""Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""
486-
from . import set_current_text_encoder_device
517+
from . import set_current_text_encoder_device, get_current_text_encoder_device
487518

488519
class NodeOverride(cls):
489520
@classmethod
@@ -499,19 +530,23 @@ def INPUT_TYPES(s):
499530
FUNCTION = "override"
500531

501532
def override(self, *args, device=None, **kwargs):
533+
original_text_device = get_current_text_encoder_device()
502534
if device is not None:
503535
set_current_text_encoder_device(device)
504536
kwargs['device'] = 'default'
505537
fn = getattr(super(), cls.FUNCTION)
506538
out = fn(*args, **kwargs)
507-
return out
539+
try:
540+
return out
541+
finally:
542+
set_current_text_encoder_device(original_text_device)
508543

509544
return NodeOverride
510545

511546

512547
def override_class_clip_no_device(cls):
513548
"""Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)"""
514-
from . import set_current_text_encoder_device
549+
from . import set_current_text_encoder_device, get_current_text_encoder_device
515550

516551
class NodeOverride(cls):
517552
@classmethod
@@ -527,10 +562,14 @@ def INPUT_TYPES(s):
527562
FUNCTION = "override"
528563

529564
def override(self, *args, device=None, **kwargs):
565+
original_text_device = get_current_text_encoder_device()
530566
if device is not None:
531567
set_current_text_encoder_device(device)
532568
fn = getattr(super(), cls.FUNCTION)
533569
out = fn(*args, **kwargs)
534-
return out
570+
try:
571+
return out
572+
finally:
573+
set_current_text_encoder_device(original_text_device)
535574

536575
return NodeOverride

0 commit comments

Comments (0)