Commit b04ae1e

Converted WanVideoImageToVideoEncode to shim
1 parent f0313e4 commit b04ae1e

1 file changed: 16 additions & 145 deletions

wanvideo.py

@@ -472,156 +472,27 @@ def process(self, width, height, num_frames, force_offload, noise_aug_strength,
                 temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None, load_device=None):
         from . import set_current_device
 
+        original_encoder = NODE_CLASS_MAPPINGS["WanVideoImageToVideoEncode"]()
+        encoder_module = inspect.getmodule(original_encoder)
+
+        original_module_device = encoder_module.device
+        original_module_offload = encoder_module.offload_device
+
         set_current_device(load_device)
-
-        logger.info(f"[MultiGPU WanVideoWrapper][WanVideoImageToVideoEncodeMultiGPU] load device: {load_device}")
 
-        device = mm.get_torch_device()
-        PATCH_SIZE = (1, 2, 2)
-        offload_device = mm.unet_offload_device()
+        compute_device_to_be_patched = mm.get_torch_device()
+        encoder_module.device = compute_device_to_be_patched
 
-        logger.info(f"[MultiGPU WanVideoWrapper][WanVideoImageToVideoEncodeMultiGPU] torch device: {device}")
+        encoder_module.offload_device = mm.unet_offload_device()
 
-        if vae is not None:
-            vae = vae[0]
-
-        if start_image is None and end_image is None and add_cond_latents is None:
-            return WanVideoEmptyEmbeds().process(
-                num_frames, width, height, control_embeds=control_embeds, extra_latents=extra_latents,
-            )
-        if vae is None:
-            raise ValueError("VAE is required for image encoding.")
-        H = height
-        W = width
-
-        lat_h = H // vae.upsampling_factor
-        lat_w = W // vae.upsampling_factor
-
-        num_frames = ((num_frames - 1) // 4) * 4 + 1
-        two_ref_images = start_image is not None and end_image is not None
-
-        if start_image is None and end_image is not None:
-            fun_or_fl2v_model = True  # end image alone only works with this option
-
-        base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0)
-        if temporal_mask is None:
-            mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device, dtype=vae.dtype)
-            if start_image is not None:
-                mask[:, 0:start_image.shape[0]] = 1  # First frame
-            if end_image is not None:
-                mask[:, -end_image.shape[0]:] = 1  # End frame if exists
-        else:
-            mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1)
-            if mask.shape[0] > base_frames:
-                mask = mask[:base_frames]
-            elif mask.shape[0] < base_frames:
-                mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)])
-            mask = mask.unsqueeze(0).to(device, vae.dtype)
-
-        # Repeat first frame and optionally end frame
-        start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)  # T, C, H, W
-        if end_image is not None and not fun_or_fl2v_model:
-            end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1)  # T, C, H, W
-            mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1)
-        else:
-            mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1)
-
-        # Reshape mask into groups of 4 frames
-        mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w)  # 1, T, C, H, W
-        mask = mask.movedim(1, 2)[0]  # C, T, H, W
-
-        # Resize and rearrange the input image dimensions
-        if start_image is not None:
-            start_image = start_image[..., :3]
-            if start_image.shape[1] != H or start_image.shape[2] != W:
-                resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
-            else:
-                resized_start_image = start_image.permute(3, 0, 1, 2)  # C, T, H, W
-            resized_start_image = resized_start_image * 2 - 1
-            if noise_aug_strength > 0.0:
-                resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength)
-
-        if end_image is not None:
-            end_image = end_image[..., :3]
-            if end_image.shape[1] != H or end_image.shape[2] != W:
-                resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1)
-            else:
-                resized_end_image = end_image.permute(3, 0, 1, 2)  # C, T, H, W
-            resized_end_image = resized_end_image * 2 - 1
-            if noise_aug_strength > 0.0:
-                resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength)
-
-        # Concatenate image with zero frames and encode
-        if temporal_mask is None:
-            if start_image is not None and end_image is None:
-                zero_frames = torch.zeros(3, num_frames - start_image.shape[0], H, W, device=device, dtype=vae.dtype)
-                concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames], dim=1)
-                del resized_start_image, zero_frames
-            elif start_image is None and end_image is not None:
-                zero_frames = torch.zeros(3, num_frames - end_image.shape[0], H, W, device=device, dtype=vae.dtype)
-                concatenated = torch.cat([zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1)
-                del zero_frames
-            elif start_image is None and end_image is None:
-                concatenated = torch.zeros(3, num_frames, H, W, device=device, dtype=vae.dtype)
-            else:
-                if fun_or_fl2v_model:
-                    zero_frames = torch.zeros(3, num_frames - (start_image.shape[0] + end_image.shape[0]), H, W, device=device, dtype=vae.dtype)
-                else:
-                    zero_frames = torch.zeros(3, num_frames - 1, H, W, device=device, dtype=vae.dtype)
-                concatenated = torch.cat([resized_start_image.to(device, dtype=vae.dtype), zero_frames, resized_end_image.to(device, dtype=vae.dtype)], dim=1)
-                del resized_start_image, zero_frames
-        else:
-            temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1)
-            concatenated = resized_start_image[:, :num_frames].to(vae.dtype) * temporal_mask[:num_frames].unsqueeze(0).to(vae.dtype)
-            del resized_start_image, temporal_mask
-
-        mm.soft_empty_cache()
-        gc.collect()
-
-        vae.to(device)
-        y = vae.encode([concatenated], device, end_=(end_image is not None and not fun_or_fl2v_model), tiled=tiled_vae)[0]
-        del concatenated
-
-        has_ref = False
-        if extra_latents is not None:
-            samples = extra_latents["samples"].squeeze(0)
-            y = torch.cat([samples, y], dim=1)
-            mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1)
-            num_frames += samples.shape[1] * 4
-            has_ref = True
-        y[:, :1] *= start_latent_strength
-        y[:, -1:] *= end_latent_strength
-
-        # Calculate maximum sequence length
-        patches_per_frame = lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2])
-        frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1)
-        max_seq_len = frames_per_stride * patches_per_frame
-
-        if add_cond_latents is not None:
-            add_cond_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device)
-
-        if force_offload:
-            vae.model.to(offload_device)
-            mm.soft_empty_cache()
-            gc.collect()
-
-        image_embeds = {
-            "image_embeds": y,
-            "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None,
-            "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None,
-            "max_seq_len": max_seq_len,
-            "num_frames": num_frames,
-            "lat_h": lat_h,
-            "lat_w": lat_w,
-            "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None,
-            "end_image": resized_end_image if end_image is not None else None,
-            "fun_or_fl2v_model": fun_or_fl2v_model,
-            "has_ref": has_ref,
-            "add_cond_latents": add_cond_latents,
-            "mask": mask
-        }
+        inner_vae = vae[0]
 
-        return (image_embeds,)
+        try:
+            return original_encoder.process(width, height, num_frames, force_offload, noise_aug_strength, start_latent_strength, end_latent_strength, start_image,
+                                            end_image, control_embeds, fun_or_fl2v_model, temporal_mask, extra_latents, clip_embeds, tiled_vae, add_cond_latents, inner_vae,)
+        finally:
+            encoder_module.device = original_module_device
+            encoder_module.offload_device = original_module_offload
 
 class WanVideoDecode:
     @classmethod
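
The shim keeps only the device bookkeeping local and delegates all encoding work to the wrapped node: capture the wrapped module's device/offload_device globals, patch them, call the original process, and restore the saved values in finally so an exception cannot leave the module pointing at the wrong GPU. A minimal, self-contained sketch of that pattern follows; the run_with_patched_devices helper and its arguments are illustrative, not code from this commit.

# Illustrative sketch only -- helper name and signature are hypothetical.
import inspect

def run_with_patched_devices(node, compute_device, offload_device, *args, **kwargs):
    """Call node.process() with its module's device globals temporarily patched."""
    module = inspect.getmodule(node)           # module whose globals the node reads
    saved_device = module.device               # save originals for restoration
    saved_offload = module.offload_device
    module.device = compute_device             # patch the compute device
    module.offload_device = offload_device     # patch the offload target
    try:
        return node.process(*args, **kwargs)   # delegate to the unmodified original
    finally:
        module.device = saved_device           # always restore, even on error
        module.offload_device = saved_offload

As in the diff above, restoration sits in a finally block, so a failure inside process() does not leave the patched devices visible to other nodes that share the module.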
