@@ -712,3 +712,230 @@ def process(self, clip_vision, load_device, image_1, strength_1, strength_2, for
            return original_encode.process(clip_vision[0], image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2, negative_image, tiles, ratio)
        finally:
            encode_module.device = original_module_device

class WanVideoControlnetLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' folder"}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn"], {"default": "disabled", "tooltip": "Optional quantization method"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
            },
        }

    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
    RETURN_NAMES = ("controlnet",)
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Loads ControlNet model from 'https://huggingface.co/collections/TheDenk/wan21-controlnets-68302b430411dafc0d74d2fc'"

    def loadmodel(self, model, base_precision, load_device, quantization):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        transformer_load_device = device if load_device == "main_device" else offload_device

        base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision]

        model_path = folder_paths.get_full_path_or_raise("controlnet", model)

        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)

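        # Infer the model variant from checkpoint tensor shapes: the block count
        # and the controlnet projection width select the downscale factor and
        # VAE channel count that match this checkpoint.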
        num_layers = 8 if "blocks.7.scale_shift_table" in sd else 6
        out_proj_dim = sd["controlnet_blocks.0.bias"].shape[0]
        downscale_coef = 16 if out_proj_dim == 3072 else 8
        vae_channels = 48 if out_proj_dim == 3072 else 16

        if "control_encoder.0.0.weight" not in sd:
            raise ValueError("Invalid ControlNet model: 'control_encoder.0.0.weight' not found in state dict")

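        # Static parts of this config follow TheDenk's Wan2.1 ControlNet
        # releases; only the fields derived above vary between checkpoints.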
        controlnet_cfg = {
            "added_kv_proj_dim": None,
            "attention_head_dim": 128,
            "cross_attn_norm": None,
            "downscale_coef": downscale_coef,
            "eps": 1e-06,
            "ffn_dim": 8960,
            "freq_dim": 256,
            "image_dim": None,
            "in_channels": 3,
            "num_attention_heads": 12,
            "num_layers": num_layers,
            "out_proj_dim": out_proj_dim,
            "patch_size": [1, 2, 2],
            "qk_norm": "rms_norm_across_heads",
            "rope_max_seq_len": 1024,
            "text_dim": 4096,
            "vae_channels": vae_channels,
        }
        log.info(f"Loading WanControlnet with config: {controlnet_cfg}")

        from .wan_controlnet import WanControlnet

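        # accelerate's init_empty_weights builds the module on the meta device,
        # so no real memory is allocated until weights are assigned below.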
        with init_empty_weights():
            controlnet = WanControlnet(**controlnet_cfg)
            controlnet.eval()

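        # If the checkpoint was saved in fp8, adopt the stored dtype even when
        # quantization is set to 'disabled'.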
        if quantization == "disabled":
            for k, v in sd.items():
                if isinstance(v, torch.Tensor):
                    if v.dtype == torch.float8_e4m3fn:
                        quantization = "fp8_e4m3fn"
                        break
                    elif v.dtype == torch.float8_e5m2:
                        quantization = "fp8_e5m2"
                        break

        if "fp8_e4m3fn" in quantization:
            dtype = torch.float8_e4m3fn
        elif quantization == "fp8_e5m2":
            dtype = torch.float8_e5m2
        else:
            dtype = base_dtype
        params_to_keep = {"norm", "head", "time_in", "vector_in", "controlnet_patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter"}

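        # Assign weights parameter-by-parameter: names matching params_to_keep
        # stay in base precision, the patch embedding is pinned to fp32, and
        # everything else is cast to the (possibly fp8) compute dtype.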
        log.info("Using accelerate to load and assign controlnet model weights to device...")
        param_count = sum(1 for _ in controlnet.named_parameters())
        for name, param in tqdm(controlnet.named_parameters(),
                                desc=f"Loading controlnet parameters to {transformer_load_device}",
                                total=param_count,
                                leave=True):
            dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
            if "controlnet_patch_embedding" in name:
                dtype_to_use = torch.float32
            set_module_tensor_to_device(controlnet, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name])

        del sd

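        # Optionally park the model on the offload device until it is needed.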
        if load_device == "offload_device" and controlnet.device != offload_device:
            log.info(f"Moving controlnet model from {controlnet.device} to {offload_device}")
            controlnet.to(offload_device)
            gc.collect()
            mm.soft_empty_cache()

        return (controlnet,)

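# The MultiGPU variants below add a 'device' input, set it as the active
# device, and then delegate to the corresponding single-device loader node.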
class WanVideoControlnetLoaderMultiGPU:
    @classmethod
    def INPUT_TYPES(s):
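        # Default to the second device in the list (typically the first
        # secondary GPU) when more than one device is available.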
        devices = get_device_list()
        default_device = devices[1] if len(devices) > 1 else devices[0]
        return {
            "required": {
                "model": (folder_paths.get_filename_list("controlnet"), {"tooltip": "These models are loaded from the 'ComfyUI/models/controlnet' folder"}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}),
                "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp8_e4m3fn_fast_no_ffn"], {"default": "disabled", "tooltip": "Optional quantization method"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
                "device": (devices, {"default": default_device}),
            },
        }

    RETURN_TYPES = ("WANVIDEOCONTROLNET",)
    RETURN_NAMES = ("controlnet",)
    FUNCTION = "loadmodel"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "MultiGPU-aware ControlNet loader for WanVideo models"

    def loadmodel(self, model, base_precision, load_device, quantization, device):
        from . import set_current_device

        set_current_device(device)

        original_loader = NODE_CLASS_MAPPINGS["WanVideoControlnetLoader"]()
        return original_loader.loadmodel(model, base_precision, load_device, quantization)

class FantasyTalkingModelLoaderMultiGPU:
    @classmethod
    def INPUT_TYPES(s):
        devices = get_device_list()
        default_device = devices[1] if len(devices) > 1 else devices[0]
        return {
            "required": {
                "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' folder"}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
                "device": (devices, {"default": default_device}),
            },
        }

    RETURN_TYPES = ("FANTASYTALKINGMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "loadmodel"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "MultiGPU-aware FantasyTalking model loader"

    def loadmodel(self, model, base_precision, device):
        from . import set_current_device

        set_current_device(device)

        original_loader = NODE_CLASS_MAPPINGS["FantasyTalkingModelLoader"]()
        return original_loader.loadmodel(model, base_precision)

class Wav2VecModelLoaderMultiGPU:
    @classmethod
    def INPUT_TYPES(s):
        devices = get_device_list()
        default_device = devices[1] if len(devices) > 1 else devices[0]
        return {
            "required": {
                "model": (folder_paths.get_filename_list("wav2vec2"), {"tooltip": "These models are loaded from the 'ComfyUI/models/wav2vec2' folder"}),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
                "device": (devices, {"default": default_device}),
            },
        }

    RETURN_TYPES = ("WAV2VECMODEL",)
    RETURN_NAMES = ("wav2vec_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "MultiGPU-aware Wav2Vec model loader"

    def loadmodel(self, model, base_precision, load_device, device):
        from . import set_current_device

        set_current_device(device)

        original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]()
        return original_loader.loadmodel(model, base_precision, load_device)

class DownloadAndLoadWav2VecModelMultiGPU:
    @classmethod
    def INPUT_TYPES(s):
        devices = get_device_list()
        default_device = devices[1] if len(devices) > 1 else devices[0]
        return {
            "required": {
                "model": ([
                    "TencentGameMate/chinese-wav2vec2-base",
                    "facebook/wav2vec2-base-960h",
                ],),
                "base_precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
                "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
                "device": (devices, {"default": default_device}),
            },
        }

    RETURN_TYPES = ("WAV2VECMODEL",)
    RETURN_NAMES = ("wav2vec_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "MultiGPU-aware downloadable Wav2Vec model loader"

    def loadmodel(self, model, base_precision, load_device, device):
        from . import set_current_device

        set_current_device(device)

        original_loader = NODE_CLASS_MAPPINGS["DownloadAndLoadWav2VecModel"]()
        return original_loader.loadmodel(model, base_precision, load_device)
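
# Hypothetical registration sketch (ComfyUI convention; the actual
# NODE_CLASS_MAPPINGS updates for these classes live elsewhere in this module):
# NODE_CLASS_MAPPINGS["WanVideoControlnetLoaderMultiGPU"] = WanVideoControlnetLoaderMultiGPU
# NODE_CLASS_MAPPINGS["FantasyTalkingModelLoaderMultiGPU"] = FantasyTalkingModelLoaderMultiGPU
# NODE_CLASS_MAPPINGS["Wav2VecModelLoaderMultiGPU"] = Wav2VecModelLoaderMultiGPU
# NODE_CLASS_MAPPINGS["DownloadAndLoadWav2VecModelMultiGPU"] = DownloadAndLoadWav2VecModelMultiGPU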