Skip to content

Commit faef324

Browse files
committed
correcting implementation of salvageable nodes and obliterating one abomination
1 parent 610c210 commit faef324

2 files changed

Lines changed: 13 additions & 123 deletions

File tree

__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ def text_encoder_device_patched():
226226
WanVideoEncode,
227227
LoadWanVideoClipTextEncoder,
228228
WanVideoClipVisionEncode,
229-
WanVideoControlnetLoaderMultiGPU,
230-
FantasyTalkingModelLoaderMultiGPU,
231-
Wav2VecModelLoaderMultiGPU,
232-
WanVideoUni3C_ControlnetLoaderMultiGPU,
233-
DownloadAndLoadWav2VecModelMultiGPU,
229+
WanVideoControlnetLoader,
230+
FantasyTalkingModelLoader,
231+
Wav2VecModelLoader,
232+
WanVideoUni3C_ControlnetLoader,
233+
DownloadAndLoadWav2VecModel,
234234
)
235235

236236
from .wrappers import (
@@ -382,11 +382,11 @@ def register_and_count(module_names, node_map):
382382
"WanVideoEncodeMultiGPU": WanVideoEncode,
383383
"LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
384384
"WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
385-
"WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoaderMultiGPU,
386-
"FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoaderMultiGPU,
387-
"Wav2VecModelLoaderMultiGPU": Wav2VecModelLoaderMultiGPU,
388-
"WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoaderMultiGPU,
389-
"DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModelMultiGPU,
385+
"WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader,
386+
"FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader,
387+
"Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader,
388+
"WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader,
389+
"DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel,
390390
}
391391
register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes)
392392

wanvideo.py

Lines changed: 3 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -714,116 +714,6 @@ def process(self, clip_vision, load_device, image_1, strength_1, strength_2, for
714714
encode_module.device = original_module_device
715715

716716
class WanVideoControlnetLoader:
717-
@classmethod
718-
def INPUT_TYPES(s):
719-
return {
720-
"required": {
721-
"model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
722-
723-
"base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
724-
"quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
725-
"load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
726-
},
727-
}
728-
729-
RETURN_TYPES = ("WANVIDEOCONTROLNET",)
730-
RETURN_NAMES = ("controlnet", )
731-
FUNCTION = "loadmodel"
732-
CATEGORY = "WanVideoWrapper"
733-
DESCRIPTION = "Loads ControlNet model from 'https://huggingface.co/collections/TheDenk/wan21-controlnets-68302b430411dafc0d74d2fc'"
734-
735-
def loadmodel(self, model, base_precision, load_device, quantization):
736-
737-
device = mm.get_torch_device()
738-
offload_device = mm.unet_offload_device()
739-
740-
transformer_load_device = device if load_device == "main_device" else offload_device
741-
742-
base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]
743-
744-
model_path = folder_paths.get_full_path_or_raise("controlnet", model)
745-
746-
sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
747-
748-
num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
749-
out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
750-
downscale_coef = 16 if out_proj_dim == 3072 else 8
751-
vae_channels = 48 if out_proj_dim == 3072 else 16
752-
753-
if not "control_encoder.0.0.weight" in sd:
754-
raise ValueError("Invalid ControlNet model")
755-
756-
controlnet_cfg = {
757-
"added_kv_proj_dim": None,
758-
"attention_head_dim": 128,
759-
"cross_attn_norm": None,
760-
"downscale_coef": downscale_coef,
761-
"eps": 1e-06,
762-
"ffn_dim": 8960,
763-
"freq_dim": 256,
764-
"image_dim": None,
765-
"in_channels": 3,
766-
"num_attention_heads": 12,
767-
"num_layers": num_layers,
768-
"out_proj_dim": out_proj_dim,
769-
"patch_size": [
770-
1,
771-
2,
772-
2
773-
],
774-
"qk_norm": "rms_norm_across_heads",
775-
"rope_max_seq_len": 1024,
776-
"text_dim": 4096,
777-
"vae_channels": vae_channels
778-
}
779-
print(f"Loading WanControlnet with config: {controlnet_cfg}")
780-
781-
from .wan_controlnet import WanControlnet
782-
783-
with init_empty_weights():
784-
controlnet = WanControlnet(**controlnet_cfg)
785-
controlnet.eval()
786-
787-
if quantization == "disabled":
788-
for k, v in sd.items():
789-
if isinstance(v, torch.Tensor):
790-
if v.dtype == torch.float8_e4m3fn:
791-
quantization = "fp8_e4m3fn"
792-
break
793-
elif v.dtype == torch.float8_e5m2:
794-
quantization = "fp8_e5m2"
795-
break
796-
797-
if "fp8_e4m3fn" in quantization:
798-
dtype = torch.float8_e4m3fn
799-
elif quantization == "fp8_e5m2":
800-
dtype = torch.float8_e5m2
801-
else:
802-
dtype = base_dtype
803-
params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter"}
804-
805-
log.info("Using accelerate to load and assign controlnet model weights to device...")
806-
param_count = sum(1 for _ in controlnet.named_parameters())
807-
for name, param in tqdm(controlnet.named_parameters(),
808-
desc=f"Loading transformer parameters to {transformer_load_device}",
809-
total=param_count,
810-
leave=True):
811-
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
812-
if "controlnet_patch_embedding" in name:
813-
dtype_to_use = torch.float32
814-
set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
815-
816-
del sd
817-
818-
if load_device == "offload_device" and controlnet.device != offload_device:
819-
log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
820-
controlnet.to(offload_device)
821-
gc.collect()
822-
mm.soft_empty_cache()
823-
824-
return (controlnet,)
825-
826-
class WanVideoControlnetLoaderMultiGPU:
827717
@classmethod
828718
def INPUT_TYPES(s):
829719
devices = get_device_list()
@@ -852,7 +742,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, device):
852742
original_loader = NODE_CLASS_MAPPINGS["WanVideoControlnetLoader"]()
853743
return original_loader.loadmodel(model, base_precision, load_device, quantization)
854744

855-
class FantasyTalkingModelLoaderMultiGPU:
745+
class FantasyTalkingModelLoader:
856746
@classmethod
857747
def INPUT_TYPES(s):
858748
devices = get_device_list()
@@ -879,7 +769,7 @@ def loadmodel(self, model, base_precision, device):
879769
original_loader = NODE_CLASS_MAPPINGS["FantasyTalkingModelLoader"]()
880770
return original_loader.loadmodel(model, base_precision)
881771

882-
class Wav2VecModelLoaderMultiGPU:
772+
class Wav2VecModelLoader:
883773
@classmethod
884774
def INPUT_TYPES(s):
885775
devices = get_device_list()
@@ -907,7 +797,7 @@ def loadmodel(self, model, base_precision, load_device, device):
907797
original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]()
908798
return original_loader.loadmodel(model, base_precision, load_device)
909799

910-
class DownloadAndLoadWav2VecModelMultiGPU:
800+
class DownloadAndLoadWav2VecModel:
911801
@classmethod
912802
def INPUT_TYPES(s):
913803
devices = get_device_list()

0 commit comments

Comments (0)