@@ -261,7 +261,6 @@ def loadmodel(self, model_name, precision, device=None, quantization="disabled")
261261
262262 return text_encoder , device
263263
264-
265264class WanVideoTextEncodeCached :
266265 @classmethod
267266 def INPUT_TYPES (s ):
@@ -409,7 +408,6 @@ def loadmodel(self, model_name, load_device=None, precision="fp16", parallel=Fal
409408
410409 return vae_model , load_device
411410
412-
413411class WanVideoBlockSwap :
414412 @classmethod
415413 def INPUT_TYPES (s ):
@@ -665,24 +663,15 @@ def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y):
665663 def decode (self , vae , load_device , samples , enable_vae_tiling , tile_x , tile_y , tile_stride_x , tile_stride_y , normalization = "default" ):
666664 from . import set_current_device
667665
666+ original_decode = NODE_CLASS_MAPPINGS ["WanVideoDecode" ]()
667+ decode_module = inspect .getmodule (original_decode )
668+ original_module_device = decode_module .device
669+
668670 set_current_device (load_device )
669671 compute_device_to_be_patched = mm .get_torch_device ()
670-
671- logger .info (f"[MultiGPU WanVideoWrapper][WanVideoDecodeMultiGPU] load device: { load_device } " )
672-
673- original_loader = NODE_CLASS_MAPPINGS ["WanVideoDecode" ]()
674- loader_module = inspect .getmodule (original_loader )
675-
676- original_module_device = loader_module .device
677-
678- loader_module .device = compute_device_to_be_patched
679-
680- result = original_loader .decode (vae [0 ], samples , enable_vae_tiling , tile_x , tile_y , tile_stride_x , tile_stride_y , normalization )
681-
682- loader_module .device = original_module_device
683-
684- decode = result [0 ]
685-
686- return (decode ,)
687-
672+ decode_module .device = compute_device_to_be_patched
688673
674+ try :
675+ return original_decode .decode (vae [0 ], samples , enable_vae_tiling , tile_x , tile_y , tile_stride_x , tile_stride_y , normalization )
676+ finally :
677+ decode_module .device = original_module_device
0 commit comments