@@ -714,116 +714,6 @@ def process(self, clip_vision, load_device, image_1, strength_1, strength_2, for
         encode_module.device = original_module_device
 
 class WanVideoControlnetLoader:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' -folder",}),
-
-                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
-                "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}),
-                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
-            },
-        }
-
-    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
-    RETURN_NAMES = ("controlnet",)
-    FUNCTION = "loadmodel"
-    CATEGORY = "WanVideoWrapper"
-    DESCRIPTION = "Loads ControlNet model from 'https://huggingface.co/collections/TheDenk/wan21-controlnets-68302b430411dafc0d74d2fc'"
-
-    def loadmodel(self, model, base_precision, load_device, quantization):
-
-        device = mm.get_torch_device()
-        offload_device = mm.unet_offload_device()
-
-        transformer_load_device = device if load_device == "main_device" else offload_device
-
-        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]
-
-        model_path = folder_paths.get_full_path_or_raise("controlnet", model)
-
-        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
-
-        num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
-        out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
-        downscale_coef = 16 if out_proj_dim == 3072 else 8
-        vae_channels = 48 if out_proj_dim == 3072 else 16
-
-        if not "control_encoder.0.0.weight" in sd:
-            raise ValueError("Invalid ControlNet model")
-
-        controlnet_cfg = {
-            "added_kv_proj_dim": None,
-            "attention_head_dim": 128,
-            "cross_attn_norm": None,
-            "downscale_coef": downscale_coef,
-            "eps": 1e-06,
-            "ffn_dim": 8960,
-            "freq_dim": 256,
-            "image_dim": None,
-            "in_channels": 3,
-            "num_attention_heads": 12,
-            "num_layers": num_layers,
-            "out_proj_dim": out_proj_dim,
-            "patch_size": [
-                1,
-                2,
-                2
-            ],
-            "qk_norm": "rms_norm_across_heads",
-            "rope_max_seq_len": 1024,
-            "text_dim": 4096,
-            "vae_channels": vae_channels
-        }
-        print(f"Loading WanControlnet with config: {controlnet_cfg}")
-
-        from .wan_controlnet import WanControlnet
-
-        with init_empty_weights():
-            controlnet = WanControlnet(**controlnet_cfg)
-        controlnet.eval()
-
-        if quantization == "disabled":
-            for k, v in sd.items():
-                if isinstance(v, torch.Tensor):
-                    if v.dtype == torch.float8_e4m3fn:
-                        quantization = "fp8_e4m3fn"
-                        break
-                    elif v.dtype == torch.float8_e5m2:
-                        quantization = "fp8_e5m2"
-                        break
-
-        if "fp8_e4m3fn" in quantization:
-            dtype = torch.float8_e4m3fn
-        elif quantization == "fp8_e5m2":
-            dtype = torch.float8_e5m2
-        else:
-            dtype = base_dtype
-        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter"}
-
-        log.info("Using accelerate to load and assign controlnet model weights to device...")
-        param_count = sum(1 for _ in controlnet.named_parameters())
-        for name, param in tqdm(controlnet.named_parameters(),
-                desc=f"Loading transformer parameters to {transformer_load_device}",
-                total=param_count,
-                leave=True):
-            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
-            if "controlnet_patch_embedding" in name:
-                dtype_to_use = torch.float32
-            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])
-
-        del sd
-
-        if load_device == "offload_device" and controlnet.device != offload_device:
-            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
-            controlnet.to(offload_device)
-            gc.collect()
-            mm.soft_empty_cache()
-
-        return (controlnet,)
-
-class WanVideoControlnetLoaderMultiGPU:
     @classmethod
     def INPUT_TYPES(s):
         devices = get_device_list()
@@ -852,7 +742,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, device):
         original_loader = NODE_CLASS_MAPPINGS["WanVideoControlnetLoader"]()
         return original_loader.loadmodel(model, base_precision, load_device, quantization)
 
-class FantasyTalkingModelLoaderMultiGPU:
+class FantasyTalkingModelLoader:
     @classmethod
     def INPUT_TYPES(s):
         devices = get_device_list()
@@ -879,7 +769,7 @@ def loadmodel(self, model, base_precision, device):
         original_loader = NODE_CLASS_MAPPINGS["FantasyTalkingModelLoader"]()
         return original_loader.loadmodel(model, base_precision)
 
-class Wav2VecModelLoaderMultiGPU:
+class Wav2VecModelLoader:
     @classmethod
     def INPUT_TYPES(s):
         devices = get_device_list()
@@ -907,7 +797,7 @@ def loadmodel(self, model, base_precision, load_device, device):
         original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]()
         return original_loader.loadmodel(model, base_precision, load_device)
 
-class DownloadAndLoadWav2VecModelMultiGPU:
+class DownloadAndLoadWav2VecModel:
     @classmethod
     def INPUT_TYPES(s):
         devices = get_device_list()
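
The unchanged context lines in the hunks above all follow the same delegation pattern: each node builds its input spec around get_device_list(), then instantiates the upstream loader from NODE_CLASS_MAPPINGS and forwards everything except the extra device selector to that loader's loadmodel. Below is a minimal, hypothetical sketch of that pattern; the stubbed NODE_CLASS_MAPPINGS, the get_device_list body, and the INPUT_TYPES details are illustrative assumptions, not taken from this diff.

```python
# Hypothetical, self-contained sketch of the delegation pattern shown by the
# context lines above. In the real extension NODE_CLASS_MAPPINGS comes from the
# wrapped WanVideoWrapper package; here it is stubbed so the sketch runs alone.
import torch


def get_device_list():
    # Assumed helper: enumerate devices for the extra "device" input.
    cuda = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    return cuda or ["cpu"]


class _UpstreamWav2VecLoader:
    # Stand-in for the upstream node normally looked up via NODE_CLASS_MAPPINGS.
    def loadmodel(self, model, base_precision, load_device):
        print(f"loading {model} as {base_precision}, initial device: {load_device}")
        return (object(),)


NODE_CLASS_MAPPINGS = {"Wav2VecModelLoader": _UpstreamWav2VecLoader}


class Wav2VecModelLoaderExample:
    @classmethod
    def INPUT_TYPES(s):
        devices = get_device_list()
        return {"required": {
            "model": ("STRING", {"default": ""}),
            "base_precision": (["fp32", "bf16", "fp16"],),
            "load_device": (["main_device", "offload_device"],),
            "device": (devices, {"default": devices[0]}),  # extra per-node device selector
        }}

    RETURN_TYPES = ("WAV2VECMODEL",)
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model, base_precision, load_device, device):
        # Delegate to the upstream loader, forwarding everything except "device",
        # exactly as the context lines above do; applying "device" is out of scope here.
        original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]()
        return original_loader.loadmodel(model, base_precision, load_device)


if __name__ == "__main__":
    node = Wav2VecModelLoaderExample()
    node.loadmodel("wav2vec2-base.safetensors", "fp16", "main_device", get_device_list()[0])
```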