
Commit 610c210

Add multi-GPU support for remaining planned WanVideo model loaders

1 parent 375cbd0 · commit 610c210

2 files changed: 237 additions & 0 deletions

File tree

__init__.py
wanvideo.py

__init__.py

Lines changed: 10 additions & 0 deletions
@@ -226,6 +226,11 @@ def text_encoder_device_patched():
     WanVideoEncode,
     LoadWanVideoClipTextEncoder,
     WanVideoClipVisionEncode,
+    WanVideoControlnetLoaderMultiGPU,
+    FantasyTalkingModelLoaderMultiGPU,
+    Wav2VecModelLoaderMultiGPU,
+    WanVideoUni3C_ControlnetLoaderMultiGPU,
+    DownloadAndLoadWav2VecModelMultiGPU,
 )

 from .wrappers import (
@@ -377,6 +382,11 @@ def register_and_count(module_names, node_map):
     "WanVideoEncodeMultiGPU": WanVideoEncode,
     "LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
     "WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
+    "WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoaderMultiGPU,
+    "FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoaderMultiGPU,
+    "Wav2VecModelLoaderMultiGPU": Wav2VecModelLoaderMultiGPU,
+    "WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoaderMultiGPU,
+    "DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModelMultiGPU,
 }
 register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes)
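The registration above is gated by register_and_count, whose definition appears earlier in __init__.py (only its signature is visible in the hunk header). Below is a hypothetical sketch of the pattern it presumably implements: check whether the wrapped custom-node package is installed under one of its known folder names, and register the node map only if so. NODE_CLASS_MAPPINGS is the standard ComfyUI registry export; everything else in the sketch is assumed.

    import os
    import folder_paths  # ComfyUI helper module, available at runtime

    def register_and_count(module_names, node_map):
        # Register the MultiGPU variants only when the wrapped custom-node
        # package is installed under one of the given folder names.
        installed = any(
            os.path.isdir(os.path.join(base, name))
            for base in folder_paths.get_folder_paths("custom_nodes")
            for name in module_names
        )
        if not installed:
            return 0
        NODE_CLASS_MAPPINGS.update(node_map)  # module-level dict exported to ComfyUI
        return len(node_map)

This keeps the MultiGPU node variants invisible unless ComfyUI-WanVideoWrapper itself is present.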

wanvideo.py

Lines changed: 227 additions & 0 deletions
@@ -712,3 +712,230 @@ def process(self, clip_vision, load_device, image_1, strength_1, strength_2, for
             return original_encode.process(clip_vision[0], image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2, negative_image, tiles, ratio)
         finally:
             encode_module.device = original_module_device
+
+class WanVideoControlnetLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
+
+                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
+                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
+                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
+            },
+        }
+
+    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
+    RETURN_NAMES = ("controlnet", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "Loads ControlNet model from 'https://huggingface.co/collections/TheDenk/wan21-controlnets-68302b430411dafc0d74d2fc'"
+
+    def loadmodel(self, model, base_precision, load_device, quantization):
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        transformer_load_device = device if load_device == "main_device" else offload_device
+
+        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]
+
+        model_path = folder_paths.get_full_path_or_raise("controlnet", model)
+
+        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
+
+        num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
+        out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
+        downscale_coef = 16 if out_proj_dim == 3072 else 8
+        vae_channels = 48 if out_proj_dim == 3072 else 16
+
+        if not "control_encoder.0.0.weight" in sd:
+            raise ValueError("Invalid ControlNet model")
+
+        controlnet_cfg = {
+            "added_kv_proj_dim": None,
+            "attention_head_dim": 128,
+            "cross_attn_norm": None,
+            "downscale_coef": downscale_coef,
+            "eps": 1e-06,
+            "ffn_dim": 8960,
+            "freq_dim": 256,
+            "image_dim": None,
+            "in_channels": 3,
+            "num_attention_heads": 12,
+            "num_layers": num_layers,
+            "out_proj_dim": out_proj_dim,
+            "patch_size": [
+                1,
+                2,
+                2
+            ],
+            "qk_norm": "rms_norm_across_heads",
+            "rope_max_seq_len": 1024,
+            "text_dim": 4096,
+            "vae_channels": vae_channels
+        }
+        print(f"Loading WanControlnet with config: {controlnet_cfg}")
+
+        from .wan_controlnet import WanControlnet
+
+        with init_empty_weights():
+            controlnet = WanControlnet(**controlnet_cfg)
+        controlnet.eval()
+
+        if quantization == "disabled":
+            for k, v in sd.items():
+                if isinstance(v, torch.Tensor):
+                    if v.dtype == torch.float8_e4m3fn:
+                        quantization = "fp8_e4m3fn"
+                        break
+                    elif v.dtype == torch.float8_e5m2:
+                        quantization = "fp8_e5m2"
+                        break
+
+        if "fp8_e4m3fn" in quantization:
+            dtype = torch.float8_e4m3fn
+        elif quantization == "fp8_e5m2":
+            dtype = torch.float8_e5m2
+        else:
+            dtype = base_dtype
+        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter"}
+
+        log.info("Using accelerate to load and assign controlnet model weights to device...")
+        param_count = sum(1 for _ in controlnet.named_parameters())
+        for name, param in tqdm(controlnet.named_parameters(),
+                desc=f"Loading transformer parameters to {transformer_load_device}",
+                total=param_count,
+                leave=True):
+            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
+            if "controlnet_patch_embedding" in name:
+                dtype_to_use = torch.float32
+            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
+
+        del sd
+
+        if load_device == "offload_device" and controlnet.device != offload_device:
+            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
+            controlnet.to(offload_device)
+            gc.collect()
+            mm.soft_empty_cache()
+
+        return (controlnet,)
+
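Note how WanVideoControlnetLoader infers the architecture from the checkpoint itself instead of a config file: the presence of blocks.7.scale_shift_table marks an eight-block model, and an output projection width of 3072 selects the larger variant (downscale coefficient 16, 48 VAE channels). Likewise, when quantization is left at 'disabled' but the weights are already stored in fp8, the loader adopts the checkpoint's own fp8 dtype. A standalone illustration of the shape-based inference, with dummy tensors in place of real weights:

    import torch

    def infer_cfg_fields(sd):
        # Same expressions as in loadmodel above, applied to a bare state dict.
        num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
        out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
        downscale_coef = 16 if out_proj_dim == 3072 else 8
        vae_channels = 48 if out_proj_dim == 3072 else 16
        return num_layers, out_proj_dim, downscale_coef, vae_channels

    dummy_sd = {
        "controlnet_blocks.0.bias": torch.zeros(3072),
        "blocks.7.scale_shift_table": torch.zeros(1),
    }
    print(infer_cfg_fields(dummy_sd))  # (8, 3072, 16, 48)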
+class WanVideoControlnetLoaderMultiGPU:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = get_device_list()
+        default_device = devices[1] if len(devices) > 1 else devices[0]
+        return {
+            "required": {
+                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
+                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
+                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
+                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
+                "device": (devices, {"default": default_device}),
+            },
+        }
+
+    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
+    RETURN_NAMES = ("controlnet", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "multigpu/WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware ControlNet loader for WanVideo models"
+
+    def loadmodel(self, model, base_precision, load_device, quantization, device):
+        from . import set_current_device
+
+        set_current_device(device)
+
+        original_loader = NODE_CLASS_MAPPINGS["WanVideoControlnetLoader"]()
+        return original_loader.loadmodel(model, base_precision, load_device, quantization)
+
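All of the MultiGPU wrappers in this commit share one delegation pattern: pin the target device via set_current_device, then instantiate the upstream node from NODE_CLASS_MAPPINGS and forward the call untouched. As a sketch of that pattern only (not part of this commit; it assumes the same module scope as wanvideo.py, and every name except set_current_device, get_device_list, and NODE_CLASS_MAPPINGS is invented), the boilerplate could be generated by a small factory:

    def make_multigpu_wrapper(original_name):
        # Hypothetical factory: adds a "device" input and forwards the rest.
        from . import set_current_device

        original_cls = NODE_CLASS_MAPPINGS[original_name]

        class Wrapper(original_cls):
            CATEGORY = "multigpu/WanVideoWrapper"

            @classmethod
            def INPUT_TYPES(cls):
                types = original_cls.INPUT_TYPES()
                devices = get_device_list()
                default_device = devices[1] if len(devices) > 1 else devices[0]
                types["required"]["device"] = (devices, {"default": default_device})
                return types

            def loadmodel(self, *args, device=None, **kwargs):
                set_current_device(device)
                return original_cls().loadmodel(*args, **kwargs)

        return Wrapper

The commit instead writes each wrapper out explicitly, which keeps every node's inputs self-documenting.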
+class FantasyTalkingModelLoaderMultiGPU:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = get_device_list()
+        default_device = devices[1] if len(devices) > 1 else devices[0]
+        return {
+            "required": {
+                "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
+                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
+                "device": (devices, {"default": default_device}),
+            },
+        }
+
+    RETURN_TYPES = ("FANTASYTALKINGMODEL",)
+    RETURN_NAMES = ("model", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "multigpu/WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware FantasyTalking model loader"
+
+    def loadmodel(self, model, base_precision, device):
+        from . import set_current_device
+
+        set_current_device(device)
+
+        original_loader = NODE_CLASS_MAPPINGS["FantasyTalkingModelLoader"]()
+        return original_loader.loadmodel(model, base_precision)
+
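Also note the shared default-device heuristic, devices[1] if len(devices) > 1 else devices[0]. Assuming get_device_list() enumerates devices with the primary compute device first (its exact contents depend on the MultiGPU core), every loader defaults to the second device so the primary stays free for sampling, and falls back to the only device on single-GPU machines. For illustration:

    def pick_default(devices):
        # The same expression each INPUT_TYPES method uses.
        return devices[1] if len(devices) > 1 else devices[0]

    print(pick_default(["cuda:0", "cuda:1"]))  # cuda:1
    print(pick_default(["cuda:0"]))            # cuda:0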
+class Wav2VecModelLoaderMultiGPU:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = get_device_list()
+        default_device = devices[1] if len(devices) > 1 else devices[0]
+        return {
+            "required": {
+                "model": (folder_paths.get_filename_list("wav2vec2"), {"tooltip": "These models are loaded from the 'ComfyUI/models/wav2vec2' -folder",}),
+                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
+                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
+                "device": (devices, {"default": default_device}),
+            },
+        }
+
+    RETURN_TYPES = ("WAV2VECMODEL",)
+    RETURN_NAMES = ("wav2vec_model", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "multigpu/WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware Wav2Vec model loader"
+
+    def loadmodel(self, model, base_precision, load_device, device):
+        from . import set_current_device
+
+        set_current_device(device)
+
+        original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]()
+        return original_loader.loadmodel(model, base_precision, load_device)
+
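set_current_device is imported from the package root, and the hunk header in __init__.py references text_encoder_device_patched, which hints at the mechanism: the MultiGPU core patches ComfyUI's model-management device queries so that mm.get_torch_device() inside a wrapped loader returns the selected device. A rough sketch of that mechanism, with every name other than set_current_device assumed:

    import torch
    import comfy.model_management as mm

    _current_device = None  # set by the MultiGPU nodes before delegating

    def set_current_device(device):
        global _current_device
        _current_device = torch.device(device)

    _original_get_torch_device = mm.get_torch_device

    def get_torch_device_patched():
        # Wrapped loaders call mm.get_torch_device(); return the selected
        # device when one is active, otherwise defer to ComfyUI's default.
        if _current_device is not None:
            return _current_device
        return _original_get_torch_device()

    mm.get_torch_device = get_torch_device_patched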
+class DownloadAndLoadWav2VecModelMultiGPU:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = get_device_list()
+        default_device = devices[1] if len(devices) > 1 else devices[0]
+        return {
+            "required": {
+                "model": (
+                    [
+                        "TencentGameMate/chinese-wav2vec2-base",
+                        "facebook/wav2vec2-base-960h"
+                    ],
+                ),
+                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
+                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
+                "device": (devices, {"default": default_device}),
+            },
+        }
+
+    RETURN_TYPES = ("WAV2VECMODEL",)
+    RETURN_NAMES = ("wav2vec_model", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "multigpu/WanVideoWrapper"
+    DESCRIPTION = "MultiGPU-aware downloadable Wav2Vec model loader"
+
+    def loadmodel(self, model, base_precision, load_device, device):
+        from . import set_current_device
+
+        set_current_device(device)
+
+        original_loader = NODE_CLASS_MAPPINGS["DownloadAndLoadWav2VecModel"]()
+        return original_loader.loadmodel(model, base_precision, load_device)
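As a usage sketch, here is a hypothetical headless invocation. In practice these run as graph nodes inside ComfyUI; the device string must be one of the entries returned by get_device_list(), and the one-tuple return follows ComfyUI node convention:

    loader = DownloadAndLoadWav2VecModelMultiGPU()
    (wav2vec_model,) = loader.loadmodel(
        model="facebook/wav2vec2-base-960h",
        base_precision="fp16",
        load_device="main_device",
        device="cuda:1",  # assumed entry from get_device_list()
    )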
