Skip to content

Commit 0d6eba7

Browse files
authored
Close #123, Merge pull request #126 from pollockjj/ew
2 parents d6137f3 + 5d61440 commit 0d6eba7

87 files changed

Lines changed: 8676 additions & 14180 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

README.md

Lines changed: 185 additions & 52 deletions
Large diffs are not rendered by default.

__init__.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424

2525
WEB_DIRECTORY = "./web"
26-
MGPU_MM_LOG = True
26+
MGPU_MM_LOG = False
2727
DEBUG_LOG = False
2828

2929
logger = logging.getLogger("MultiGPU")
@@ -148,6 +148,7 @@ def check_module_exists(module_path):
148148

149149
current_device = mm.get_torch_device()
150150
current_text_encoder_device = mm.text_encoder_device()
151+
current_unet_offload_device = mm.unet_offload_device()
151152

152153
def set_current_device(device):
153154
"""Set the current device context for MultiGPU operations."""
@@ -161,6 +162,12 @@ def set_current_text_encoder_device(device):
161162
current_text_encoder_device = device
162163
logger.debug(f"[MultiGPU Initialization] current_text_encoder_device set to: {device}")
163164

165+
def set_current_unet_offload_device(device):
    """Record *device* as the module-wide UNet offload device.

    Rebinds the module-level ``current_unet_offload_device`` global and
    emits a debug log line with the new value.
    """
    global current_unet_offload_device

    current_unet_offload_device = device
    logger.debug(
        f"[MultiGPU Initialization] current_unet_offload_device set to: {device}"
    )
170+
164171
def get_torch_device_patched():
165172
"""Return MultiGPU-aware device selection for patched mm.get_torch_device."""
166173
device = None
@@ -183,11 +190,25 @@ def text_encoder_device_patched():
183190
logger.info(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
184191
return device
185192

186-
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device and mm.text_encoder_device")
193+
def unet_offload_device_patched():
    """Pick the UNet offload device; installed in place of ``mm.unet_offload_device``.

    Returns ``torch.device("cpu")`` when any of the following holds:

    * ``is_accelerator_available()`` is false,
    * ``mm.cpu_state`` equals ``mm.CPUState.CPU``,
    * the string form of ``current_unet_offload_device`` contains "cpu",
    * the string form of ``current_unet_offload_device`` is not among the
      entries returned by ``get_device_list()``.

    Otherwise returns ``torch.device(current_unet_offload_device)``. The
    chosen device is logged at debug level before being returned.
    """
    # Checks are kept in their original short-circuit order so that
    # get_device_list() is only consulted when CPU is not already forced.
    cpu_forced = (
        not is_accelerator_available()
        or mm.cpu_state == mm.CPUState.CPU
        or "cpu" in str(current_unet_offload_device).lower()
    )

    if not cpu_forced and str(current_unet_offload_device) in set(get_device_list()):
        device = torch.device(current_unet_offload_device)
    else:
        device = torch.device("cpu")

    logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
    return device
203+
204+
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
187205
logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
188206
logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
207+
logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}")
208+
189209
mm.get_torch_device = get_torch_device_patched
190210
mm.text_encoder_device = text_encoder_device_patched
211+
mm.unet_offload_device = unet_offload_device_patched
191212

192213
from .nodes import (
193214
UnetLoaderGGUF,
@@ -235,6 +256,7 @@ def text_encoder_device_patched():
235256

236257
from .wrappers import (
237258
override_class,
259+
override_class_offload,
238260
override_class_clip,
239261
override_class_clip_no_device,
240262
override_class_with_distorch_gguf,
@@ -319,8 +341,8 @@ def register_and_count(module_names, node_map):
319341
register_and_count(["ComfyUI-LTXVideo", "comfyui-ltxvideo"], ltx_nodes)
320342

321343
florence_nodes = {
322-
"Florence2ModelLoaderMultiGPU": override_class(Florence2ModelLoader),
323-
"DownloadAndLoadFlorence2ModelMultiGPU": override_class(DownloadAndLoadFlorence2Model)
344+
"Florence2ModelLoaderMultiGPU": override_class_offload(Florence2ModelLoader),
345+
"DownloadAndLoadFlorence2ModelMultiGPU": override_class_offload(DownloadAndLoadFlorence2Model)
324346
}
325347
register_and_count(["ComfyUI-Florence2", "comfyui-florence2"], florence_nodes)
326348

256 KB
Loading

0 commit comments

Comments
 (0)