Skip to content

Commit 1611742

Browse files
committed
Add UNet offload device support and enhance Florence2 model loading with a safetensors conversion option.
1 parent 0d6dfc7 commit 1611742

3 files changed

Lines changed: 68 additions & 11 deletions

File tree

__init__.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424

2525
WEB_DIRECTORY = "./web"
26-
MGPU_MM_LOG = True
26+
MGPU_MM_LOG = False
2727
DEBUG_LOG = False
2828

2929
logger = logging.getLogger("MultiGPU")
@@ -148,6 +148,7 @@ def check_module_exists(module_path):
148148

149149
current_device = mm.get_torch_device()
150150
current_text_encoder_device = mm.text_encoder_device()
151+
current_unet_offload_device = mm.unet_offload_device()
151152

152153
def set_current_device(device):
153154
"""Set the current device context for MultiGPU operations."""
@@ -161,6 +162,12 @@ def set_current_text_encoder_device(device):
161162
current_text_encoder_device = device
162163
logger.debug(f"[MultiGPU Initialization] current_text_encoder_device set to: {device}")
163164

165+
def set_current_unet_offload_device(device):
166+
"""Set the current UNet offload device context."""
167+
global current_unet_offload_device
168+
current_unet_offload_device = device
169+
logger.debug(f"[MultiGPU Initialization] current_unet_offload_device set to: {device}")
170+
164171
def get_torch_device_patched():
165172
"""Return MultiGPU-aware device selection for patched mm.get_torch_device."""
166173
device = None
@@ -183,11 +190,25 @@ def text_encoder_device_patched():
183190
logger.info(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
184191
return device
185192

186-
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device and mm.text_encoder_device")
193+
def unet_offload_device_patched():
194+
"""Return MultiGPU-aware UNet offload device for patched mm.unet_offload_device."""
195+
device = None
196+
if (not is_accelerator_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_unet_offload_device).lower()):
197+
device = torch.device("cpu")
198+
else:
199+
devs = set(get_device_list())
200+
device = torch.device(current_unet_offload_device) if str(current_unet_offload_device) in devs else torch.device("cpu")
201+
logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
202+
return device
203+
204+
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
187205
logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
188206
logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
207+
logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}")
208+
189209
mm.get_torch_device = get_torch_device_patched
190210
mm.text_encoder_device = text_encoder_device_patched
211+
mm.unet_offload_device = unet_offload_device_patched
191212

192213
from .nodes import (
193214
UnetLoaderGGUF,
@@ -235,6 +256,7 @@ def text_encoder_device_patched():
235256

236257
from .wrappers import (
237258
override_class,
259+
override_class_offload,
238260
override_class_clip,
239261
override_class_clip_no_device,
240262
override_class_with_distorch_gguf,
@@ -319,8 +341,8 @@ def register_and_count(module_names, node_map):
319341
register_and_count(["ComfyUI-LTXVideo", "comfyui-ltxvideo"], ltx_nodes)
320342

321343
florence_nodes = {
322-
"Florence2ModelLoaderMultiGPU": override_class(Florence2ModelLoader),
323-
"DownloadAndLoadFlorence2ModelMultiGPU": override_class(DownloadAndLoadFlorence2Model)
344+
"Florence2ModelLoaderMultiGPU": override_class_offload(Florence2ModelLoader),
345+
"DownloadAndLoadFlorence2ModelMultiGPU": override_class_offload(DownloadAndLoadFlorence2Model)
324346
}
325347
register_and_count(["ComfyUI-Florence2", "comfyui-florence2"], florence_nodes)
326348

nodes.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,12 @@ def _load_vae(self, weights, config=None):
176176
return original_loader._load_vae(weights, config=None)
177177

178178
class Florence2ModelLoader:
179-
@classmethod
180179
def INPUT_TYPES(s):
180+
all_llm_paths = folder_paths.get_folder_paths("LLM")
181+
s.model_paths = create_path_dict(all_llm_paths, lambda x: x.is_dir())
182+
181183
return {"required": {
182-
"model": ([item.name for item in Path(folder_paths.models_dir, "LLM").iterdir() if item.is_dir()], {"tooltip": "models are expected to be in Comfyui/models/LLM folder"}),
184+
"model": ([*s.model_paths], {"tooltip": "models are expected to be in Comfyui/models/LLM folder"}),
183185
"precision": (['fp16','bf16','fp32'],),
184186
"attention": (
185187
[ 'flash_attention_2', 'sdpa', 'eager'],
@@ -189,6 +191,7 @@ def INPUT_TYPES(s):
189191
},
190192
"optional": {
191193
"lora": ("PEFTLORA",),
194+
"convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some of the older model weights are not saved in .safetensors format, which seem to cause longer loading times, this option converts the .bin weights to .safetensors"}),
192195
}
193196
}
194197

@@ -197,10 +200,10 @@ def INPUT_TYPES(s):
197200
FUNCTION = "loadmodel"
198201
CATEGORY = "Florence2"
199202

200-
def loadmodel(self, model, precision, attention, lora=None):
203+
def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False):
201204
"""Load Florence2 vision model with specified precision and attention mode."""
202205
original_loader = NODE_CLASS_MAPPINGS["Florence2ModelLoader"]()
203-
return original_loader.loadmodel(model, precision, attention, lora)
206+
return original_loader.loadmodel(model, precision, attention, lora, convert_to_safetensors)
204207

205208
class DownloadAndLoadFlorence2Model:
206209
@classmethod
@@ -220,7 +223,8 @@ def INPUT_TYPES(s):
220223
'MiaoshouAI/Florence-2-base-PromptGen-v1.5',
221224
'MiaoshouAI/Florence-2-large-PromptGen-v1.5',
222225
'MiaoshouAI/Florence-2-base-PromptGen-v2.0',
223-
'MiaoshouAI/Florence-2-large-PromptGen-v2.0'
226+
'MiaoshouAI/Florence-2-large-PromptGen-v2.0',
227+
'PJMixers-Images/Florence-2-base-Castollux-v0.5'
224228
],
225229
{
226230
"default": 'microsoft/Florence-2-base'
@@ -237,6 +241,7 @@ def INPUT_TYPES(s):
237241
},
238242
"optional": {
239243
"lora": ("PEFTLORA",),
244+
"convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some of the older model weights are not saved in .safetensors format, which seem to cause longer loading times, this option converts the .bin weights to .safetensors"}),
240245
}
241246
}
242247

@@ -245,10 +250,10 @@ def INPUT_TYPES(s):
245250
FUNCTION = "loadmodel"
246251
CATEGORY = "Florence2"
247252

248-
def loadmodel(self, model, precision, attention, lora=None):
253+
def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False):
249254
"""Download and load Florence2 model from HuggingFace."""
250255
original_loader = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
251-
return original_loader.loadmodel(model, precision, attention, lora)
256+
return original_loader.loadmodel(model, precision, attention, lora, convert_to_safetensors)
252257

253258
class CheckpointLoaderNF4:
254259
@classmethod

wrappers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,36 @@ def override(self, *args, device=None, **kwargs):
484484

485485
return NodeOverride
486486

487+
def override_class_offload(cls):
488+
"""Standard MultiGPU device override for UNet/VAE models"""
489+
from . import set_current_device, set_current_unet_offload_device
490+
491+
class NodeOverride(cls):
492+
@classmethod
493+
def INPUT_TYPES(s):
494+
inputs = copy.deepcopy(cls.INPUT_TYPES())
495+
devices = get_device_list()
496+
default_device = devices[1] if len(devices) > 1 else devices[0]
497+
inputs["optional"] = inputs.get("optional", {})
498+
inputs["optional"]["device"] = (devices, {"default": default_device})
499+
inputs["optional"]["offload_device"] = (devices, {"default": "cpu"})
500+
return inputs
501+
502+
CATEGORY = "multigpu"
503+
FUNCTION = "override"
504+
505+
def override(self, *args, device=None, offload_device=None, **kwargs):
506+
if device is not None:
507+
set_current_device(device)
508+
if offload_device is not None:
509+
set_current_unet_offload_device(offload_device)
510+
fn = getattr(super(), cls.FUNCTION)
511+
out = fn(*args, **kwargs)
512+
return out
513+
514+
return NodeOverride
515+
516+
487517

488518
def override_class_clip(cls):
489519
"""Standard MultiGPU device override for CLIP models (with device kwarg workaround)"""

0 commit comments

Comments
 (0)