@@ -472,156 +472,27 @@ def process(self, width, height, num_frames, force_offload, noise_aug_strength,
472472 temporal_mask = None , extra_latents = None , clip_embeds = None , tiled_vae = False , add_cond_latents = None , vae = None , load_device = None ):
473473 from . import set_current_device
474474
475+ original_encoder = NODE_CLASS_MAPPINGS ["WanVideoImageToVideoEncode" ]()
476+ encoder_module = inspect .getmodule (original_encoder )
477+
478+ original_module_device = encoder_module .device
479+ original_module_offload = encoder_module .offload_device
480+
475481 set_current_device (load_device )
476-
477- logger .info (f"[MultiGPU WanVideoWrapper][WanVideoImageToVideoEncodeMultiGPU] load device: { load_device } " )
478482
479- device = mm .get_torch_device ()
480- PATCH_SIZE = (1 , 2 , 2 )
481- offload_device = mm .unet_offload_device ()
483+ compute_device_to_be_patched = mm .get_torch_device ()
484+ encoder_module .device = compute_device_to_be_patched
482485
483- logger . info ( f"[MultiGPU WanVideoWrapper][WanVideoImageToVideoEncodeMultiGPU] torch device: { device } " )
486+ encoder_module . offload_device = mm . unet_offload_device ( )
484487
485- if vae is not None :
486- vae = vae [0 ]
487-
488- if start_image is None and end_image is None and add_cond_latents is None :
489- return WanVideoEmptyEmbeds ().process (
490- num_frames , width , height , control_embeds = control_embeds , extra_latents = extra_latents ,
491- )
492- if vae is None :
493- raise ValueError ("VAE is required for image encoding." )
494- H = height
495- W = width
496-
497- lat_h = H // vae .upsampling_factor
498- lat_w = W // vae .upsampling_factor
499-
500- num_frames = ((num_frames - 1 ) // 4 ) * 4 + 1
501- two_ref_images = start_image is not None and end_image is not None
502-
503- if start_image is None and end_image is not None :
504- fun_or_fl2v_model = True # end image alone only works with this option
505-
506- base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0 )
507- if temporal_mask is None :
508- mask = torch .zeros (1 , base_frames , lat_h , lat_w , device = device , dtype = vae .dtype )
509- if start_image is not None :
510- mask [:, 0 :start_image .shape [0 ]] = 1 # First frame
511- if end_image is not None :
512- mask [:, - end_image .shape [0 ]:] = 1 # End frame if exists
513- else :
514- mask = common_upscale (temporal_mask .unsqueeze (1 ).to (device ), lat_w , lat_h , "nearest" , "disabled" ).squeeze (1 )
515- if mask .shape [0 ] > base_frames :
516- mask = mask [:base_frames ]
517- elif mask .shape [0 ] < base_frames :
518- mask = torch .cat ([mask , torch .zeros (base_frames - mask .shape [0 ], lat_h , lat_w , device = device )])
519- mask = mask .unsqueeze (0 ).to (device , vae .dtype )
520-
521- # Repeat first frame and optionally end frame
522- start_mask_repeated = torch .repeat_interleave (mask [:, 0 :1 ], repeats = 4 , dim = 1 ) # T, C, H, W
523- if end_image is not None and not fun_or_fl2v_model :
524- end_mask_repeated = torch .repeat_interleave (mask [:, - 1 :], repeats = 4 , dim = 1 ) # T, C, H, W
525- mask = torch .cat ([start_mask_repeated , mask [:, 1 :- 1 ], end_mask_repeated ], dim = 1 )
526- else :
527- mask = torch .cat ([start_mask_repeated , mask [:, 1 :]], dim = 1 )
528-
529- # Reshape mask into groups of 4 frames
530- mask = mask .view (1 , mask .shape [1 ] // 4 , 4 , lat_h , lat_w ) # 1, T, C, H, W
531- mask = mask .movedim (1 , 2 )[0 ]# C, T, H, W
532-
533- # Resize and rearrange the input image dimensions
534- if start_image is not None :
535- start_image = start_image [..., :3 ]
536- if start_image .shape [1 ] != H or start_image .shape [2 ] != W :
537- resized_start_image = common_upscale (start_image .movedim (- 1 , 1 ), W , H , "lanczos" , "disabled" ).movedim (0 , 1 )
538- else :
539- resized_start_image = start_image .permute (3 , 0 , 1 , 2 ) # C, T, H, W
540- resized_start_image = resized_start_image * 2 - 1
541- if noise_aug_strength > 0.0 :
542- resized_start_image = add_noise_to_reference_video (resized_start_image , ratio = noise_aug_strength )
543-
544- if end_image is not None :
545- end_image = end_image [..., :3 ]
546- if end_image .shape [1 ] != H or end_image .shape [2 ] != W :
547- resized_end_image = common_upscale (end_image .movedim (- 1 , 1 ), W , H , "lanczos" , "disabled" ).movedim (0 , 1 )
548- else :
549- resized_end_image = end_image .permute (3 , 0 , 1 , 2 ) # C, T, H, W
550- resized_end_image = resized_end_image * 2 - 1
551- if noise_aug_strength > 0.0 :
552- resized_end_image = add_noise_to_reference_video (resized_end_image , ratio = noise_aug_strength )
553-
554- # Concatenate image with zero frames and encode
555- if temporal_mask is None :
556- if start_image is not None and end_image is None :
557- zero_frames = torch .zeros (3 , num_frames - start_image .shape [0 ], H , W , device = device , dtype = vae .dtype )
558- concatenated = torch .cat ([resized_start_image .to (device , dtype = vae .dtype ), zero_frames ], dim = 1 )
559- del resized_start_image , zero_frames
560- elif start_image is None and end_image is not None :
561- zero_frames = torch .zeros (3 , num_frames - end_image .shape [0 ], H , W , device = device , dtype = vae .dtype )
562- concatenated = torch .cat ([zero_frames , resized_end_image .to (device , dtype = vae .dtype )], dim = 1 )
563- del zero_frames
564- elif start_image is None and end_image is None :
565- concatenated = torch .zeros (3 , num_frames , H , W , device = device , dtype = vae .dtype )
566- else :
567- if fun_or_fl2v_model :
568- zero_frames = torch .zeros (3 , num_frames - (start_image .shape [0 ]+ end_image .shape [0 ]), H , W , device = device , dtype = vae .dtype )
569- else :
570- zero_frames = torch .zeros (3 , num_frames - 1 , H , W , device = device , dtype = vae .dtype )
571- concatenated = torch .cat ([resized_start_image .to (device , dtype = vae .dtype ), zero_frames , resized_end_image .to (device , dtype = vae .dtype )], dim = 1 )
572- del resized_start_image , zero_frames
573- else :
574- temporal_mask = common_upscale (temporal_mask .unsqueeze (1 ), W , H , "nearest" , "disabled" ).squeeze (1 )
575- concatenated = resized_start_image [:,:num_frames ].to (vae .dtype ) * temporal_mask [:num_frames ].unsqueeze (0 ).to (vae .dtype )
576- del resized_start_image , temporal_mask
577-
578- mm .soft_empty_cache ()
579- gc .collect ()
580-
581- vae .to (device )
582- y = vae .encode ([concatenated ], device , end_ = (end_image is not None and not fun_or_fl2v_model ),tiled = tiled_vae )[0 ]
583- del concatenated
584-
585- has_ref = False
586- if extra_latents is not None :
587- samples = extra_latents ["samples" ].squeeze (0 )
588- y = torch .cat ([samples , y ], dim = 1 )
589- mask = torch .cat ([torch .ones_like (mask [:, 0 :samples .shape [1 ]]), mask ], dim = 1 )
590- num_frames += samples .shape [1 ] * 4
591- has_ref = True
592- y [:, :1 ] *= start_latent_strength
593- y [:, - 1 :] *= end_latent_strength
594-
595- # Calculate maximum sequence length
596- patches_per_frame = lat_h * lat_w // (PATCH_SIZE [1 ] * PATCH_SIZE [2 ])
597- frames_per_stride = (num_frames - 1 ) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1 )
598- max_seq_len = frames_per_stride * patches_per_frame
599-
600- if add_cond_latents is not None :
601- add_cond_latents ["ref_latent_neg" ] = vae .encode (torch .zeros (1 , 3 , 1 , H , W , device = device , dtype = vae .dtype ), device )
602-
603- if force_offload :
604- vae .model .to (offload_device )
605- mm .soft_empty_cache ()
606- gc .collect ()
607-
608- image_embeds = {
609- "image_embeds" : y ,
610- "clip_context" : clip_embeds .get ("clip_embeds" , None ) if clip_embeds is not None else None ,
611- "negative_clip_context" : clip_embeds .get ("negative_clip_embeds" , None ) if clip_embeds is not None else None ,
612- "max_seq_len" : max_seq_len ,
613- "num_frames" : num_frames ,
614- "lat_h" : lat_h ,
615- "lat_w" : lat_w ,
616- "control_embeds" : control_embeds ["control_embeds" ] if control_embeds is not None else None ,
617- "end_image" : resized_end_image if end_image is not None else None ,
618- "fun_or_fl2v_model" : fun_or_fl2v_model ,
619- "has_ref" : has_ref ,
620- "add_cond_latents" : add_cond_latents ,
621- "mask" : mask
622- }
488+ inner_vae = vae [0 ]
623489
624- return (image_embeds ,)
490+ try :
491+ return original_encoder .process (width , height , num_frames , force_offload , noise_aug_strength , start_latent_strength , end_latent_strength , start_image ,
492+ end_image , control_embeds , fun_or_fl2v_model , temporal_mask , extra_latents , clip_embeds , tiled_vae , add_cond_latents , inner_vae ,)
493+ finally :
494+ encoder_module .device = original_module_device
495+ encoder_module .offload_device = original_module_offload
625496
626497class WanVideoDecode :
627498 @classmethod
0 commit comments