
Commit 7a08dd9

feat: Experimental XPU support
Add guarded Intel XPU support alongside CUDA:

- get_device_list now includes xpu:N when available
- device selection (model/text encoder) considers CUDA or XPU and validates devices
- DisTorch donor/offload selection includes xpu devices

Also: remove unused MergeFluxLoRAs node and mapping; delete tools/ and precompiled_binaries/; bump project version to 1.8.1.
1 parent 34a1559 commit 7a08dd9

6 files changed

Lines changed: 4824 additions & 1535 deletions

File tree

__init__.py

Lines changed: 26 additions & 146 deletions
@@ -6,13 +6,8 @@
 from pathlib import Path
 import logging
 import folder_paths
-import shutil
 from collections import defaultdict
 import hashlib
-import tempfile
-import subprocess
-import gc
-from safetensors.torch import save_file, load_file
 import comfy.utils
 from typing import Dict, List
 
@@ -36,21 +31,29 @@
 current_text_encoder_device = mm.text_encoder_device()
 model_allocation_store = {}
 
+def _has_xpu():
+    try:
+        return hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available()
+    except Exception:
+        return False
+
 def get_torch_device_patched():
     device = None
-    if (not torch.cuda.is_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()):
+    if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_device).lower()):
         device = torch.device("cpu")
     else:
-        device = torch.device(current_device)
+        devs = set(get_device_list())
+        device = torch.device(current_device) if str(current_device) in devs else torch.device("cpu")
     logging.info(f"[MultiGPU get_torch_device_patched] Returning device: {device} (current_device={current_device})")
     return device
 
 def text_encoder_device_patched():
     device = None
-    if (not torch.cuda.is_available() or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()):
+    if (not (torch.cuda.is_available() or _has_xpu()) or mm.cpu_state == mm.CPUState.CPU or "cpu" in str(current_text_encoder_device).lower()):
         device = torch.device("cpu")
     else:
-        device = torch.device(current_text_encoder_device)
+        devs = set(get_device_list())
+        device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu")
     logging.info(f"[MultiGPU text_encoder_device_patched] Returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
     return device
 
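
The rewritten else branches validate the stored device name against get_device_list() before constructing a torch.device, falling back to CPU instead of handing torch an unknown device. A minimal standalone sketch of that pattern (the helper name and device inventories here are illustrative, not part of the commit):

    import torch

    def validate_device(requested, known_devices):
        # Accept the requested device only if it survived enumeration;
        # otherwise fall back to CPU rather than failing later in torch.
        return torch.device(requested) if str(requested) in set(known_devices) else torch.device("cpu")

    # A stale "xpu:1" saved in a workflow, replayed on a CUDA-only host:
    print(validate_device("xpu:1", ["cpu", "cuda:0"]))   # -> cpu
    print(validate_device("cuda:0", ["cpu", "cuda:0"]))  # -> cuda:0
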

@@ -331,8 +334,18 @@ def calculate_vvram_allocation_string(model, virtual_vram_str):
     return allocation_string
 
 def get_device_list():
-    import torch
-    return ["cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+    devs = ["cpu"]
+    try:
+        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
+            devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+    except Exception:
+        pass
+    try:
+        if _has_xpu():
+            devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())]
+    except Exception:
+        pass
+    return devs
 
 class DeviceSelectorMultiGPU:
     @classmethod
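
get_device_list() now enumerates both back ends under separate try/except guards, so a failure probing one back end cannot hide the other. A standalone sketch of the same enumeration with illustrative outputs (the hardware combinations are hypothetical):

    import torch

    def list_accelerators():
        # Same guards and ordering as get_device_list() above.
        devs = ["cpu"]
        try:
            if torch.cuda.is_available():
                devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        except Exception:
            pass
        try:
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())]
        except Exception:
            pass
        return devs

    print(list_accelerators())
    # e.g. ["cpu"] on a CPU-only host,
    #      ["cpu", "cuda:0", "cuda:1"] with two NVIDIA GPUs,
    #      ["cpu", "cuda:0", "xpu:0"] with one NVIDIA and one Intel Arc GPU.
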
@@ -385,138 +398,6 @@ def adapt_embeddings(self, hyvid_embeds):
         return ([[cond, pooled_dict]],)
 
 
-class MergeFluxLoRAsQuantizeAndLoad:
-    @classmethod
-    def INPUT_TYPES(cls):
-        unet_name = folder_paths.get_filename_list("diffusion_models")
-        loras = ["None"] + folder_paths.get_filename_list("loras")
-        inputs = {
-            "required": {
-                "unet_name": (unet_name,),
-                "switch_1": (["Off", "On"],),
-                "lora_name_1": (loras,),
-                "lora_weight_1": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                "switch_2": (["Off", "On"],),
-                "lora_name_2": (loras,),
-                "lora_weight_2": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                "switch_3": (["Off", "On"],),
-                "lora_name_3": (loras,),
-                "lora_weight_3": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                "switch_4": (["Off", "On"],),
-                "lora_name_4": (loras,),
-                "lora_weight_4": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
-                "quantization": (["Q2_K", "Q3_K_S", "Q4_0", "Q4_1", "Q4_K_S", "Q5_0", "Q5_1", "Q5_K_S", "Q6_K", "Q8_0", "FP16"], {"default": "Q4_K_S"}),
-                "delete_final_gguf": ("BOOLEAN", {"default": False}),
-                "new_model_name": ("STRING", {"default": "merged_model"}),
-            }
-        }
-        return inputs
-
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "load_and_quantize"
-    CATEGORY = "loaders"
-
-    def merge_flux_loras(self, model_sd: dict, lora_paths: list, weights: list, device="cuda") -> dict:
-        for lora_path, weight in zip(lora_paths, weights):
-            logging.info(f"[DEBUG] Merging LoRA file: {lora_path} with weight: {weight}")
-            lora_sd = load_file(lora_path, device=device)
-            for key in list(lora_sd.keys()):
-                if "lora_down" not in key:
-                    continue
-                base_name = key[: key.rfind(".lora_down")]
-                up_key = key.replace("lora_down", "lora_up")
-                module_name = base_name.replace("_", ".")
-                alpha_key = f"{base_name}.alpha"
-                if module_name not in model_sd:
-                    logging.info(f"[DEBUG] Module {module_name} not found in model_sd; skipping key {key}")
-                    continue
-                down_weight = lora_sd[key].float()
-                up_weight = lora_sd[up_key].float()
-                alpha = float(lora_sd.get(alpha_key, up_weight.shape[0]))
-                scale = weight * alpha / up_weight.shape[0]
-                logging.info(f"[DEBUG] Merging module: {module_name} with alpha: {alpha}, scale: {scale}")
-                target_weight = model_sd[module_name]
-                if len(target_weight.shape) == 2:
-                    update = (up_weight @ down_weight) * scale
-                else:
-                    if down_weight.shape[2:4] == (1, 1):
-                        update = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
-                        update = update.unsqueeze(2).unsqueeze(3) * scale
-                    else:
-                        update = torch.nn.functional.conv2d(
-                            down_weight.permute(1, 0, 2, 3), up_weight
-                        ).permute(1, 0, 2, 3) * scale
-                model_sd[module_name] = target_weight + update.to(target_weight.dtype)
-                logging.info(f"[DEBUG] Updated module: {module_name}")
-                del up_weight, down_weight, update
-            del lora_sd
-            torch.cuda.empty_cache()
-        return model_sd
-
-    def convert_to_gguf(self, model_path, working_dir):
-        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-        convert_script = os.path.join(base_path, "ComfyUI-GGUF", "tools", "convert.py")
-        temp_gguf = os.path.join(working_dir, "temp_converted.gguf")
-        logging.info("[DEBUG] Running conversion script: " + convert_script)
-        subprocess.run([sys.executable, convert_script, "--src", model_path, "--dst", temp_gguf], check=True)
-        logging.info("[DEBUG] Conversion complete.")
-        return temp_gguf
-
-    def load_and_quantize(self, unet_name, quantization, delete_final_gguf, new_model_name, **kwargs):
-        mapping = {"FP16": "F16"}
-        logging.info(f"[DEBUG] Starting load_and_quantize: {new_model_name} | Quantization: {quantization}")
-        with tempfile.TemporaryDirectory() as merge_dir:
-            merged_model_path = os.path.join(merge_dir, "merged_model.safetensors")
-            model_path = folder_paths.get_full_path("diffusion_models", unet_name)
-            lora_list = []
-            for i in range(1, 5):
-                name = kwargs.get(f"lora_name_{i}", "None")
-                switch = kwargs.get(f"switch_{i}", "Off")
-                logging.info(f"[DEBUG] Processing LoRA slot {i}: name = {name}, switch = {switch}")
-                if switch == "On" and name and name != "None":
-                    lora_file_path = folder_paths.get_full_path("loras", name)
-                    weight = kwargs.get(f"lora_weight_{i}", 1.0)
-                    lora_list.append((lora_file_path, weight))
-                    logging.info(f"[DEBUG] Slot {i} active: path = {lora_file_path}, weight = {weight}")
-                else:
-                    logging.info(f"[DEBUG] Slot {i} is inactive")
-            logging.info(f"[DEBUG] Total active LoRAs: {len(lora_list)}")
-            if lora_list:
-                model_sd = load_file(model_path, device="cuda")
-                model_sd = self.merge_flux_loras(
-                    model_sd,
-                    [lp for lp, _ in lora_list],
-                    [w for _, w in lora_list]
-                )
-                save_file(model_sd, merged_model_path)
-                del model_sd
-                torch.cuda.empty_cache()
-            else:
-                shutil.copy2(model_path, merged_model_path)
-            initial_gguf = self.convert_to_gguf(merged_model_path, merge_dir)
-            logging.info("[DEBUG] Initial GGUF file created.")
-            if quantization == "FP16":
-                final_gguf = os.path.join(merge_dir, f"{new_model_name}-{mapping.get(quantization, quantization)}.gguf")
-                shutil.copy2(initial_gguf, final_gguf)
-                logging.info("[DEBUG] FP16 selected; conversion skipped.")
-            else:
-                binary = os.path.join(os.path.dirname(os.path.abspath(__file__)), "binaries", "linux", "llama-quantize")
-                final_gguf = os.path.join(merge_dir, f"quantized_{quantization}.gguf")
-                subprocess.run([binary, initial_gguf, final_gguf, quantization], check=True)
-                logging.info("[DEBUG] Quantization completed.")
-            models_dir = os.path.join(folder_paths.models_dir, "unet")
-            os.makedirs(models_dir, exist_ok=True)
-            final_name = f"{new_model_name}-{mapping.get(quantization, quantization)}.gguf"
-            final_path = os.path.join(models_dir, final_name)
-            shutil.copy2(final_gguf, final_path)
-            logging.info("[DEBUG] Final model file copied to: " + final_path)
-            logging.info("[DEBUG] Loading final model.")
-            loader = UnetLoaderGGUF()
-            result = loader.load_unet(final_name)
-            logging.info("[DEBUG] Final model loaded.")
-            if delete_final_gguf:
-                os.unlink(final_path)
-            return result
 
 
 def override_class(cls):
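
For reference, the removed node's 2-D merge path folded each LoRA pair into the base weight as W += (up @ down) * (weight * alpha / up.shape[0]), with a conv2d path for 4-D weights. A tiny self-contained check of the 2-D path, mirroring the deleted code's scaling exactly (shapes are arbitrary):

    import torch

    out_dim, rank, in_dim = 8, 4, 16
    W = torch.zeros(out_dim, in_dim)   # base weight
    up = torch.randn(out_dim, rank)    # lora_up
    down = torch.randn(rank, in_dim)   # lora_down
    weight, alpha = 1.0, 4.0

    scale = weight * alpha / up.shape[0]  # deleted code divided by up_weight.shape[0]
    W_merged = W + (up @ down) * scale
    print(W_merged.shape)  # torch.Size([8, 16])
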
@@ -611,7 +492,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v
         vram_string = ""
         if virtual_vram_gb > 0:
             if use_other_vram:
-                available_devices = [d for d in get_device_list() if d.startswith('cuda')]
+                available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))]
                 other_devices = [d for d in available_devices if d != device]
                 other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
                 device_string = ','.join(other_devices + ['cpu'])
@@ -667,7 +548,7 @@ def override(self, *args, device=None, expert_mode_allocations=None, use_other_v
         vram_string = ""
         if virtual_vram_gb > 0:
             if use_other_vram:
-                available_devices = [d for d in get_device_list() if d.startswith('cuda')]
+                available_devices = [d for d in get_device_list() if d.startswith(("cuda", "xpu"))]
                 other_devices = [d for d in available_devices if d != device]
                 other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
                 device_string = ','.join(other_devices + ['cpu'])
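
Both DisTorch overrides now admit XPU donors. Note that the sort key orders donors purely by device index, so CUDA and XPU devices interleave. A worked example of the resulting donor string (the device inventory is hypothetical):

    device = "cuda:0"                                   # the compute device
    device_list = ["cpu", "cuda:0", "cuda:1", "xpu:0"]  # as from get_device_list()

    available_devices = [d for d in device_list if d.startswith(("cuda", "xpu"))]
    other_devices = [d for d in available_devices if d != device]
    other_devices.sort(key=lambda x: int(x.split(':')[1] if ':' in x else x[-1]), reverse=False)
    print(','.join(other_devices + ['cpu']))  # -> xpu:0,cuda:1,cpu
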
@@ -704,7 +585,6 @@ def check_module_exists(module_path):
     "HunyuanVideoEmbeddingsAdapter": HunyuanVideoEmbeddingsAdapter,
 }
 
-NODE_CLASS_MAPPINGS["MergeFluxLoRAsQuantizeAndLoaddMultiGPU"] = override_class(MergeFluxLoRAsQuantizeAndLoad)
 
 NODE_CLASS_MAPPINGS["UNETLoaderMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["UNETLoader"])
 NODE_CLASS_MAPPINGS["VAELoaderMultiGPU"] = override_class(GLOBAL_NODE_CLASS_MAPPINGS["VAELoader"])
