Skip to content

Commit f0313e4

Browse files
committed
WanVideoDecode cleanup
1 parent 15ff29c commit f0313e4

1 file changed

Lines changed: 9 additions & 20 deletions

File tree

wanvideo.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def loadmodel(self, model_name, precision, device=None, quantization="disabled")
261261

262262
return text_encoder, device
263263

264-
265264
class WanVideoTextEncodeCached:
266265
@classmethod
267266
def INPUT_TYPES(s):
@@ -409,7 +408,6 @@ def loadmodel(self, model_name, load_device=None, precision="fp16", parallel=Fal
409408

410409
return vae_model, load_device
411410

412-
413411
class WanVideoBlockSwap:
414412
@classmethod
415413
def INPUT_TYPES(s):
@@ -665,24 +663,15 @@ def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y):
665663
def decode(self, vae, load_device, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization="default"):
666664
from . import set_current_device
667665

666+
original_decode = NODE_CLASS_MAPPINGS["WanVideoDecode"]()
667+
decode_module = inspect.getmodule(original_decode)
668+
original_module_device = decode_module.device
669+
668670
set_current_device(load_device)
669671
compute_device_to_be_patched = mm.get_torch_device()
670-
671-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoDecodeMultiGPU] load device: {load_device}")
672-
673-
original_loader = NODE_CLASS_MAPPINGS["WanVideoDecode"]()
674-
loader_module = inspect.getmodule(original_loader)
675-
676-
original_module_device = loader_module.device
677-
678-
loader_module.device = compute_device_to_be_patched
679-
680-
result = original_loader.decode(vae[0], samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization)
681-
682-
loader_module.device = original_module_device
683-
684-
decode = result[0]
685-
686-
return (decode,)
687-
672+
decode_module.device = compute_device_to_be_patched
688673

674+
try:
675+
return original_decode.decode(vae[0], samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, normalization)
676+
finally:
677+
decode_module.device = original_module_device

0 commit comments

Comments
 (0)