6 | 6 | from pathlib import Path
7 | 7 | import logging
8 | 8 | import folder_paths
9 | | -import shutil
10 | 9 | from collections import defaultdict
11 | 10 | import hashlib
12 | | -import tempfile
13 | | -import subprocess
14 | | -import gc
15 | | -from safetensors.torch import save_file, load_file
16 | 11 | import comfy.utils
17 | 12 | from typing import Dict, List
18 | 13 |
|
|
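The dropped imports (shutil, tempfile, subprocess, and safetensors' save_file/load_file) were only used by the MergeFluxLoRAsQuantizeAndLoad node deleted further down; gc is likewise unreferenced in the code shown in this diff.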
36 | 31 | current_text_encoder_device = mm.text_encoder_device()
37 | 32 | model_allocation_store = {} |
38 | 33 |
|
| 34 | +def _has_xpu(): |
| 35 | + try: |
| 36 | + return hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available() |
| 37 | + except Exception: |
| 38 | + return False |
| 39 | + |
39 | 40 | def get_torch_device_patched(): |
40 | 41 | device = None |
41 | | - if (not torch.cuda.is_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()): |
| 42 | + if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()): |
42 | 43 | device = torch.device("cpu") |
43 | 44 | else: |
44 | | - device = torch.device(current_device) |
| 45 | + devs = set(get_device_list()) |
| 46 | + device = torch.device(current_device) if str(current_device) in devs else torch.device("cpu") |
45 | 47 | logging.info(f"[MultiGPU get_torch_device_patched] Returning device: {device} (current_device={current_device})") |
46 | 48 | return device |
47 | 49 |
|
48 | 50 | def text_encoder_device_patched(): |
49 | 51 | device = None |
50 | | - if (not torch.cuda.is_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()): |
| 52 | + if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()): |
51 | 53 | device = torch.device("cpu") |
52 | 54 | else: |
53 | | - device = torch.device(current_text_encoder_device) |
| 55 | + devs = set(get_device_list()) |
| 56 | + device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu") |
54 | 57 | logging.info(f"[MultiGPU text_encoder_device_patched] Returning device: {device} (current_text_encoder_device={current_text_encoder_device})") |
55 | 58 | return device |
56 | 59 |
|
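Both patched getters now apply the same guard: the configured device is honored only if it still appears in get_device_list(), and anything unknown falls back to CPU, while _has_xpu() keeps the XPU probe from raising on torch builds without Intel support. A minimal standalone sketch of the resolution rule (the function and variable names here are illustrative, not part of the module):

    import torch

    def resolve_device(requested: str, available: set) -> torch.device:
        # Honor the request only when the device is actually enumerable;
        # a stale config or a removed GPU silently degrades to CPU.
        return torch.device(requested) if requested in available else torch.device("cpu")

    # resolve_device("cuda:1", {"cpu", "cuda:0"}) -> device(type='cpu')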
@@ -331,8 +334,18 @@ def calculate_vvram_allocation_string(model, virtual_vram_str): |
331 | 334 | return allocation_string |
332 | 335 |
|
333 | 336 | def get_device_list(): |
334 | | - import torch |
335 | | - return ["cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())] |
| 337 | + devs = ["cpu"] |
| 338 | + try: |
| 339 | + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available(): |
| 340 | + devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())] |
| 341 | + except Exception: |
| 342 | + pass |
| 343 | + try: |
| 344 | + if _has_xpu(): |
| 345 | + devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())] |
| 346 | + except Exception: |
| 347 | + pass |
| 348 | + return devs |
336 | 349 |
|
337 | 350 | class DeviceSelectorMultiGPU: |
338 | 351 | @classmethod |
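get_device_list() now enumerates the CPU, CUDA, and XPU back ends, with each probe wrapped in try/except so a failing driver cannot raise out of the function. On a hypothetical machine with two NVIDIA GPUs and one Intel GPU it would return:

    >>> get_device_list()
    ['cpu', 'cuda:0', 'cuda:1', 'xpu:0']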
@@ -385,138 +398,6 @@ def adapt_embeddings(self, hyvid_embeds): |
385 | 398 | return ([[cond, pooled_dict]],) |
386 | 399 |
|
387 | 400 |
|
388 | | -class MergeFluxLoRAsQuantizeAndLoad: |
389 | | - @classmethod |
390 | | - def INPUT_TYPES(cls): |
391 | | - unet_name = folder_paths.get_filename_list("diffusion_models") |
392 | | - loras = ["None"] + folder_paths.get_filename_list("loras") |
393 | | - inputs = { |
394 | | - "required": { |
395 | | - "unet_name": (unet_name,), |
396 | | - "switch_1": (["Off", "On"],), |
397 | | - "lora_name_1": (loras,), |
398 | | - "lora_weight_1": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), |
399 | | - "switch_2": (["Off", "On"],), |
400 | | - "lora_name_2": (loras,), |
401 | | - "lora_weight_2": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), |
402 | | - "switch_3": (["Off", "On"],), |
403 | | - "lora_name_3": (loras,), |
404 | | - "lora_weight_3": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), |
405 | | - "switch_4": (["Off", "On"],), |
406 | | - "lora_name_4": (loras,), |
407 | | - "lora_weight_4": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), |
408 | | - "quantization": (["Q2_K", "Q3_K_S", "Q4_0", "Q4_1", "Q4_K_S", "Q5_0", "Q5_1", "Q5_K_S", "Q6_K", "Q8_0", "FP16"], {"default": "Q4_K_S"}), |
409 | | - "delete_final_gguf": ("BOOLEAN", {"default": False}), |
410 | | - "new_model_name": ("STRING", {"default": "merged_model"}), |
411 | | - } |
412 | | - } |
413 | | - return inputs |
414 | | - |
415 | | - RETURN_TYPES = ("MODEL",) |
416 | | - FUNCTION = "load_and_quantize" |
417 | | - CATEGORY = "loaders" |
418 | | - |
419 | | - def merge_flux_loras(self, model_sd: dict, lora_paths: list, weights: list, device="cuda") -> dict: |
420 | | - for lora_path, weight in zip(lora_paths, weights): |
421 | | - logging.info(f"[DEBUG] Merging LoRA file: {lora_path} with weight: {weight}") |
422 | | - lora_sd = load_file(lora_path, device=device) |
423 | | - for key in list(lora_sd.keys()): |
424 | | - if "lora_down" not in key: |
425 | | - continue |
426 | | - base_name = key[: key.rfind(".lora_down")] |
427 | | - up_key = key.replace("lora_down", "lora_up") |
428 | | - module_name = base_name.replace("_", ".") |
429 | | - alpha_key = f"{base_name}.alpha" |
430 | | - if module_name not in model_sd: |
431 | | - logging.info(f"[DEBUG] Module {module_name} not found in model_sd; skipping key {key}") |
432 | | - continue |
433 | | - down_weight = lora_sd[key].float() |
434 | | - up_weight = lora_sd[up_key].float() |
435 | | - alpha = float(lora_sd.get(alpha_key, up_weight.shape[0])) |
436 | | - scale = weight * alpha / up_weight.shape[0] |
437 | | - logging.info(f"[DEBUG] Merging module: {module_name} with alpha: {alpha}, scale: {scale}") |
438 | | - target_weight = model_sd[module_name] |
439 | | - if len(target_weight.shape) == 2: |
440 | | - update = (up_weight @ down_weight) * scale |
441 | | - else: |
442 | | - if down_weight.shape[2:4] == (1, 1): |
443 | | - update = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)) |
444 | | - update = update.unsqueeze(2).unsqueeze(3) * scale |
445 | | - else: |
446 | | - update = torch.nn.functional.conv2d( |
447 | | - down_weight.permute(1, 0, 2, 3), up_weight |
448 | | - ).permute(1, 0, 2, 3) * scale |
449 | | - model_sd[module_name] = target_weight + update.to(target_weight.dtype) |
450 | | - logging.info(f"[DEBUG] Updated module: {module_name}") |
451 | | - del up_weight, down_weight, update |
452 | | - del lora_sd |
453 | | - torch.cuda.empty_cache() |
454 | | - return model_sd |
455 | | - |
456 | | - def convert_to_gguf(self, model_path, working_dir): |
457 | | - base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
458 | | - convert_script = os.path.join(base_path, "ComfyUI-GGUF", "tools", "convert.py") |
459 | | - temp_gguf = os.path.join(working_dir, "temp_converted.gguf") |
460 | | - logging.info("[DEBUG] Running conversion script: " + convert_script) |
461 | | - subprocess.run([sys.executable, convert_script, "--src", model_path, "--dst", temp_gguf], check=True) |
462 | | - logging.info("[DEBUG] Conversion complete.") |
463 | | - return temp_gguf |
464 | | - |
465 | | - def load_and_quantize(self, unet_name, quantization, delete_final_gguf, new_model_name, **kwargs): |
466 | | - mapping = {"FP16": "F16"} |
467 | | - logging.info(f"[DEBUG] Starting load_and_quantize: {new_model_name} | Quantization: {quantization}") |
468 | | - with tempfile.TemporaryDirectory() as merge_dir: |
469 | | - merged_model_path = os.path.join(merge_dir, "merged_model.safetensors") |
470 | | - model_path = folder_paths.get_full_path("diffusion_models", unet_name) |
471 | | - lora_list = [] |
472 | | - for i in range(1, 5): |
473 | | - name = kwargs.get(f"lora_name_{i}", "None") |
474 | | - switch = kwargs.get(f"switch_{i}", "Off") |
475 | | - logging.info(f"[DEBUG] Processing LoRA slot {i}: name = {name}, switch = {switch}") |
476 | | - if switch == "On" and name and name != "None": |
477 | | - lora_file_path = folder_paths.get_full_path("loras", name) |
478 | | - weight = kwargs.get(f"lora_weight_{i}", 1.0) |
479 | | - lora_list.append((lora_file_path, weight)) |
480 | | - logging.info(f"[DEBUG] Slot {i} active: path = {lora_file_path}, weight = {weight}") |
481 | | - else: |
482 | | - logging.info(f"[DEBUG] Slot {i} is inactive") |
483 | | - logging.info(f"[DEBUG] Total active LoRAs: {len(lora_list)}") |
484 | | - if lora_list: |
485 | | - model_sd = load_file(model_path, device="cuda") |
486 | | - model_sd = self.merge_flux_loras( |
487 | | - model_sd, |
488 | | - [lp for lp, _ in lora_list], |
489 | | - [w for _, w in lora_list] |
490 | | - ) |
491 | | - save_file(model_sd, merged_model_path) |
492 | | - del model_sd |
493 | | - torch.cuda.empty_cache() |
494 | | - else: |
495 | | - shutil.copy2(model_path, merged_model_path) |
496 | | - initial_gguf = self.convert_to_gguf(merged_model_path, merge_dir) |
497 | | - logging.info("[DEBUG] Initial GGUF file created.") |
498 | | - if quantization == "FP16": |
499 | | - final_gguf = os.path.join(merge_dir, f"{new_model_name}-{mapping.get(quantization, quantization)}.gguf") |
500 | | - shutil.copy2(initial_gguf, final_gguf) |
501 | | - logging.info("[DEBUG] FP16 selected; conversion skipped.") |
502 | | - else: |
503 | | - binary = os.path.join(os.path.dirname(os.path.abspath(__file__)), "binaries", "linux", "llama-quantize") |
504 | | - final_gguf = os.path.join(merge_dir, f"quantized_{quantization}.gguf") |
505 | | - subprocess.run([binary, initial_gguf, final_gguf, quantization], check=True) |
506 | | - logging.info("[DEBUG] Quantization completed.") |
507 | | - models_dir = os.path.join(folder_paths.models_dir, "unet") |
508 | | - os.makedirs(models_dir, exist_ok=True) |
509 | | - final_name = f"{new_model_name}-{mapping.get(quantization, quantization)}.gguf" |
510 | | - final_path = os.path.join(models_dir, final_name) |
511 | | - shutil.copy2(final_gguf, final_path) |
512 | | - logging.info("[DEBUG] Final model file copied to: " + final_path) |
513 | | - logging.info("[DEBUG] Loading final model.") |
514 | | - loader = UnetLoaderGGUF() |
515 | | - result = loader.load_unet(final_name) |
516 | | - logging.info("[DEBUG] Final model loaded.") |
517 | | - if delete_final_gguf: |
518 | | - os.unlink(final_path) |
519 | | - return result |
520 | 401 |
|
521 | 402 |
|
522 | 403 | def override_class(cls): |
@@ -611,7 +492,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v |
611 | 492 | vram_string = "" |
612 | 493 | if virtual_vram_gb > 0: |
613 | 494 | if use_other_vram: |
614 | | - available_devices = [d for d in get_device_list() if d.startswith('cuda')] |
| 495 | + available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))] |
615 | 496 | other_devices = [d for d in available_devices if d != device] |
616 | 497 | other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False) |
617 | 498 | device_string = ','.join(other_devices + ['cpu']) |
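The one-line change relies on str.startswith accepting a tuple of prefixes, so XPU cards join the donor pool in the same pass. A quick illustration with assumed device names:

    devices = ['cpu', 'cuda:0', 'cuda:1', 'xpu:0']
    donors = [d for d in devices if d.startswith(('cuda', 'xpu'))]
    # donors == ['cuda:0', 'cuda:1', 'xpu:0']; 'cpu' is appended separately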
@@ -667,7 +548,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v |
667 | 548 | vram_string = "" |
668 | 549 | if virtual_vram_gb > 0: |
669 | 550 | if use_other_vram: |
670 | | - available_devices = [d for d in get_device_list() if d.startswith('cuda')] |
| 551 | + available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))] |
671 | 552 | other_devices = [d for d in available_devices if d != device] |
672 | 553 | other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False) |
673 | 554 | device_string = ','.join(other_devices + ['cpu']) |
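Same change as above; note that the unchanged sort key uses only the numeric index, so with mixed back ends the donor order interleaves CUDA and XPU devices, for example (device names assumed):

    donors = ['xpu:0', 'cuda:1', 'cuda:0']
    donors.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]))
    # stable sort on the index alone -> ['xpu:0', 'cuda:0', 'cuda:1']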
@@ -704,7 +585,6 @@ def check_module_exists(module_path): |
704 | 585 | "HunyuanVideoEmbeddingsAdapter": HunyuanVideoEmbeddingsAdapter, |
705 | 586 | } |
706 | 587 |
|
707 | | -NODE_CLASS_MAPPINGS["MergeFluxLoRAsQuantizeAndLoaddMultiGPU"] = override_class(MergeFluxLoRAsQuantizeAndLoad) |
708 | 588 |
|
709 | 589 | NODE_CLASS_MAPPINGS["UNETLoaderMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["UNETLoader"]) |
710 | 590 | NODE_CLASS_MAPPINGS["VAELoaderMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["VAELoader"]) |
|