2323)
2424
2525WEB_DIRECTORY = "./web"
26- MGPU_MM_LOG = True
26+ MGPU_MM_LOG = False
2727DEBUG_LOG = False
2828
2929logger = logging .getLogger ("MultiGPU" )
@@ -148,6 +148,7 @@ def check_module_exists(module_path):
148148
149149current_device = mm .get_torch_device ()
150150current_text_encoder_device = mm .text_encoder_device ()
151+ current_unet_offload_device = mm .unet_offload_device ()
151152
152153def set_current_device (device ):
153154 """Set the current device context for MultiGPU operations."""
@@ -161,6 +162,12 @@ def set_current_text_encoder_device(device):
161162 current_text_encoder_device = device
162163 logger .debug (f"[MultiGPU Initialization] current_text_encoder_device set to: { device } " )
163164
165+ def set_current_unet_offload_device (device ):
166+ """Set the current UNet offload device context."""
167+ global current_unet_offload_device
168+ current_unet_offload_device = device
169+ logger .debug (f"[MultiGPU Initialization] current_unet_offload_device set to: { device } " )
170+
164171def get_torch_device_patched ():
165172 """Return MultiGPU-aware device selection for patched mm.get_torch_device."""
166173 device = None
@@ -183,11 +190,25 @@ def text_encoder_device_patched():
183190 logger .info (f"[MultiGPU Core Patching] text_encoder_device_patched returning device: { device } (current_text_encoder_device={ current_text_encoder_device } )" )
184191 return device
185192
186- logger .info (f"[MultiGPU Core Patching] Patching mm.get_torch_device and mm.text_encoder_device" )
193+ def unet_offload_device_patched ():
194+ """Return MultiGPU-aware UNet offload device for patched mm.unet_offload_device."""
195+ device = None
196+ if (not is_accelerator_available () or mm .cpu_state == mm .CPUState .CPU or "cpu" in str (current_unet_offload_device ).lower ()):
197+ device = torch .device ("cpu" )
198+ else :
199+ devs = set (get_device_list ())
200+ device = torch .device (current_unet_offload_device ) if str (current_unet_offload_device ) in devs else torch .device ("cpu" )
201+ logger .debug (f"[MultiGPU Core Patching] unet_offload_device_patched returning device: { device } (current_unet_offload_device={ current_unet_offload_device } )" )
202+ return device
203+
204+ logger .info (f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device" )
187205logger .info (f"[MultiGPU DEBUG] Initial current_device: { current_device } " )
188206logger .info (f"[MultiGPU DEBUG] Initial current_text_encoder_device: { current_text_encoder_device } " )
207+ logger .info (f"[MultiGPU DEBUG] Initial current_unet_offload_device: { current_unet_offload_device } " )
208+
189209mm .get_torch_device = get_torch_device_patched
190210mm .text_encoder_device = text_encoder_device_patched
211+ mm .unet_offload_device = unet_offload_device_patched
191212
192213from .nodes import (
193214 UnetLoaderGGUF ,
@@ -235,6 +256,7 @@ def text_encoder_device_patched():
235256
236257from .wrappers import (
237258 override_class ,
259+ override_class_offload ,
238260 override_class_clip ,
239261 override_class_clip_no_device ,
240262 override_class_with_distorch_gguf ,
@@ -319,8 +341,8 @@ def register_and_count(module_names, node_map):
319341register_and_count (["ComfyUI-LTXVideo" , "comfyui-ltxvideo" ], ltx_nodes )
320342
321343florence_nodes = {
322- "Florence2ModelLoaderMultiGPU" : override_class (Florence2ModelLoader ),
323- "DownloadAndLoadFlorence2ModelMultiGPU" : override_class (DownloadAndLoadFlorence2Model )
344+ "Florence2ModelLoaderMultiGPU" : override_class_offload (Florence2ModelLoader ),
345+ "DownloadAndLoadFlorence2ModelMultiGPU" : override_class_offload (DownloadAndLoadFlorence2Model )
324346}
325347register_and_count (["ComfyUI-Florence2" , "comfyui-florence2" ], florence_nodes )
326348
0 commit comments