Skip to content

Commit 15ff29c

Browse files
committed
methodological convergence and cleanup
1 parent 7c87983 commit 15ff29c

1 file changed

Lines changed: 117 additions & 136 deletions

File tree

wanvideo.py

Lines changed: 117 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,23 @@
1818
logger = logging.getLogger("MultiGPU")
1919

2020

# Scheduler identifiers accepted by the WanVideoSampler node. Entries with a
# "/beta" suffix select the beta-sigma variant of the same scheduler family.
scheduler_list = [
    "unipc",
    "unipc/beta",
    "dpm++",
    "dpm++/beta",
    "dpm++_sde",
    "dpm++_sde/beta",
    "euler",
    "euler/beta",
    "deis",
    "lcm",
    "lcm/beta",
    "res_multistep",
    "flowmatch_causvid",
    "flowmatch_distill",
    "flowmatch_pusa",
    "multitalk",
    "sa_ode_stable",
]

# Available RoPE (rotary position embedding) implementations exposed to the UI.
rope_functions = [
    "default",
    "comfy",
    "comfy_chunked",
]
2138
class WanVideoModelLoader:
2239
@classmethod
2340
def INPUT_TYPES(s):
@@ -59,31 +76,109 @@ def INPUT_TYPES(s):
5976
FUNCTION = "loadmodel"
6077
CATEGORY = "multigpu/WanVideoWrapper"
6178

def loadmodel(self, model, base_precision, compute_device, quantization, load_device, **kwargs):
    """Load a WanVideo model via the wrapped WanVideoModelLoader node,
    temporarily patching that module's `device` so weights land on the
    user-selected `compute_device`.

    Returns a `(patcher, compute_device)` tuple so downstream MultiGPU
    nodes can follow the chosen device.
    """
    from . import set_current_device

    # Instantiate the wrapped loader and locate its defining module so we
    # can patch the module-level `device` it allocates against.
    original_loader = NODE_CLASS_MAPPINGS["WanVideoModelLoader"]()
    loader_module = inspect.getmodule(original_loader)
    original_module_device = loader_module.device

    set_current_device(compute_device)
    compute_device_to_be_patched = mm.get_torch_device()

    loader_module.device = compute_device_to_be_patched

    # The delegated call itself can raise (missing file, OOM, ...), so it
    # must sit inside the try: otherwise the wrapped module would be left
    # permanently patched to our device.
    try:
        result = original_loader.loadmodel(
            model, base_precision, load_device, quantization, **kwargs
        )
        patcher = result[0]
        return (patcher, compute_device)
    finally:
        # Always restore the wrapped module's original device.
        loader_module.device = original_module_device
101+
class WanVideoSampler:
    """MultiGPU-aware wrapper around WanVideoWrapper's WanVideoSampler.

    Patches the wrapped sampler module's `device` (and, when block swap is
    configured with an explicit swap device, its `offload_device`) for the
    duration of one `process` call, restoring the originals afterwards.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL",),
                "compute_device": ("MULTIGPUDEVICE",),
                "image_embeds": ("WANVIDIMAGE_EMBEDS", ),
                "steps": ("INT", {"default": 30, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
                "scheduler": (scheduler_list, {"default": "unipc",}),
                "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}),
            },
            "optional": {
                "text_embeds": ("WANVIDEOTEXTEMBEDS", ),
                "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "feta_args": ("FETAARGS", ),
                "context_options": ("WANVIDCONTEXT", ),
                "cache_args": ("CACHEARGS", ),
                "flowedit_args": ("FLOWEDITARGS", ),
                "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}),
                "slg_args": ("SLGARGS", ),
                "rope_function": (rope_functions, {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. Chunked version has reduced peak VRAM usage when not using torch.compile"}),
                "loop_args": ("LOOPARGS", ),
                "experimental_args": ("EXPERIMENTALARGS", ),
                "sigmas": ("SIGMAS", ),
                "unianimate_poses": ("UNIANIMATE_POSE", ),
                "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ),
                "uni3c_embeds": ("UNI3C_EMBEDS", ),
                "multitalk_embeds": ("MULTITALK_EMBEDS", ),
                "freeinit_args": ("FREEINITARGS", ),
                "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}),
                "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}),
                "add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}),
            }
        }

    RETURN_TYPES = ("LATENT", "LATENT",)
    RETURN_NAMES = ("samples", "denoised_samples",)
    FUNCTION = "process"
    CATEGORY = "multigpu/WanVideoWrapper"
    DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model"

    def process(self, model, compute_device, **kwargs):
        from . import set_current_device

        original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]()
        sampler_module = inspect.getmodule(original_sampler)

        # Save the wrapped module's device state before touching anything.
        original_module_device = sampler_module.device
        original_module_offload_device = sampler_module.offload_device

        # Everything after the originals are saved runs inside the try, so
        # ANY failure (device setup, attribute access on the model, an
        # invalid swap device string, or sampling itself) still restores
        # the wrapped module's state in the finally block.
        try:
            set_current_device(compute_device)
            compute_device_to_be_patched = mm.get_torch_device()
            sampler_module.device = compute_device_to_be_patched

            transformer = model.model.diffusion_model
            transformer_options = model.model_options.get("transformer_options", {})
            block_swap_args = transformer_options.get("block_swap_args")

            # Block swap is MultiGPU-aware only when an explicit swap
            # device was recorded by the block-swap args node.
            multi_gpu_block_swap = block_swap_args is not None and "swap_device" in block_swap_args
            offload_device_to_be_patched = None
            if multi_gpu_block_swap:
                swap_label = block_swap_args.get("swap_device")
                logger.info(f"[MultiGPU WanVideoWrapper][WanVideoSamplerMultiGPU] block swap enabled, swap device: {swap_label}")
                offload_device_to_be_patched = torch.device(str(swap_label))
                sampler_module.offload_device = offload_device_to_be_patched

            if transformer is not None and offload_device_to_be_patched is not None:
                # NOTE(review): transformer.offload_device/cache_device are
                # not restored afterwards — presumably they belong to this
                # model and should persist; confirm against the wrapper.
                transformer.offload_device = offload_device_to_be_patched
                transformer.cache_device = offload_device_to_be_patched

            return original_sampler.process(model, **kwargs)
        finally:
            sampler_module.device = original_module_device
            sampler_module.offload_device = original_module_offload_device
87182

88183
class WanVideoTextEncode:
89184
@classmethod
@@ -117,20 +212,13 @@ def process(self, positive_prompt, negative_prompt, t5=None, load_device=None,fo
117212
else:
118213
device = "gpu"
119214

120-
if t5 is not None:
121-
text_encoder = t5[0]
122-
else:
123-
text_encoder = None
124-
125-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeMulitiGPU] current_device set to: {load_device}")
126-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeMulitiGPU] device set to: {device}")
215+
text_encoder = t5[0]
127216

128217
original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncode"]()
129218
prompt_embeds_dict = original_encoder.process(positive_prompt, negative_prompt, text_encoder, force_offload, model_to_offload, use_disk_cache, device)
130219
return (prompt_embeds_dict)
131220

132221
def parse_prompt_weights(self, prompt):
    """Extract text and weights from prompts with (text:weight) format."""
    # Pure delegation to the wrapped WanVideoTextEncode node's parser.
    original_parser = NODE_CLASS_MAPPINGS["WanVideoTextEncode"]()
    return original_parser.parse_prompt_weights(prompt)
136224

@@ -161,17 +249,13 @@ def INPUT_TYPES(s):
161249
def loadmodel(self, model_name, precision, device=None, quantization="disabled"):
162250
from . import set_current_device
163251

164-
if device is not None:
165-
set_current_device(device)
252+
set_current_device(device)
166253

167254
if device == "cpu":
168255
load_device = "offload_device"
169256
else:
170257
load_device = "main_device"
171258

172-
logger.info(f"[MultiGPU WanVideoWrapper][LoadWanVideoT5TextEncoder] current_device set to: {device}")
173-
logger.info(f"[MultiGPU WanVideoWrapper][LoadWanVideoT5TextEncoder] load_device set to: {load_device}")
174-
175259
original_loader = NODE_CLASS_MAPPINGS["LoadWanVideoT5TextEncoder"]()
176260
text_encoder = original_loader.loadmodel(model_name, precision, load_device, quantization)
177261

@@ -210,17 +294,13 @@ def INPUT_TYPES(s):
210294
def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, load_device=None, extender_args=None):
211295
from . import set_current_device
212296

213-
if load_device is not None:
214-
set_current_device(load_device)
297+
set_current_device(load_device)
215298

216299
if load_device == "cpu":
217300
device = "cpu"
218301
else:
219302
device = "gpu"
220303

221-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeCachedMulitiGPU] current_device set to: {load_device}")
222-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeCachedMulitiGPU] device set to: {device}")
223-
224304
original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncodeCached"]()
225305
prompt_embeds_dict, negative_text_embeds, positive_prompt_out = original_encoder.process(model_name, precision, positive_prompt, negative_prompt, quantization, use_disk_cache, device, extender_args)
226306

@@ -250,21 +330,14 @@ def INPUT_TYPES(s):
250330
def process(self, prompt, t5=None, load_device=None, force_offload=True, model_to_offload=None, use_disk_cache=False):
251331
from . import set_current_device
252332

253-
if load_device is not None:
254-
set_current_device(load_device)
333+
set_current_device(load_device)
255334

256335
if load_device == "cpu":
257336
device = "cpu"
258337
else:
259338
device = "gpu"
260339

261-
if t5 is not None:
262-
text_encoder = t5[0]
263-
else:
264-
text_encoder = None
265-
266-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeSingleMulitiGPU] current_device set to: {load_device}")
267-
logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeSingleMulitiGPU] device set to: {device}")
340+
text_encoder = t5[0]
268341

269342
original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncodeSingle"]()
270343
prompt_embeds_dict = original_encoder.process(prompt, text_encoder, force_offload, model_to_offload, use_disk_cache, device)
@@ -297,15 +370,11 @@ def INPUT_TYPES(s):
297370
def loadmodel(self, model_name, load_device=None, precision="fp16", compile_args=None):
    """Load a WanVideo VAE via the wrapped node and report the chosen device.

    Returns the wrapped loader's result together with `load_device` so
    downstream nodes can propagate the device selection.
    """
    from . import set_current_device

    # Route subsequent allocations to the user-selected device.
    # NOTE(review): this now runs even when load_device is None (the
    # default) — confirm set_current_device tolerates None.
    set_current_device(load_device)

    # Delegate the actual loading to the wrapped WanVideoWrapper node.
    wrapped_loader = NODE_CLASS_MAPPINGS["WanVideoVAELoader"]()
    loaded_vae = wrapped_loader.loadmodel(model_name, precision, compile_args)

    return loaded_vae, load_device
310379

311380
class WanVideoTinyVAELoader:
@@ -333,15 +402,11 @@ def INPUT_TYPES(s):
333402
def loadmodel(self, model_name, load_device=None, precision="fp16", parallel=False):
    """Load a WanVideo tiny VAE via the wrapped node and report the chosen device.

    Returns the wrapped loader's result together with `load_device` so
    downstream nodes can propagate the device selection.
    """
    from . import set_current_device

    # Route subsequent allocations to the user-selected device.
    # NOTE(review): this now runs even when load_device is None (the
    # default) — confirm set_current_device tolerates None.
    set_current_device(load_device)

    # Delegate the actual loading to the wrapped WanVideoWrapper node.
    wrapped_loader = NODE_CLASS_MAPPINGS["WanVideoTinyVAELoader"]()
    loaded_vae = wrapped_loader.loadmodel(model_name, precision, parallel)

    return loaded_vae, load_device
346411

347412

@@ -369,8 +434,7 @@ def INPUT_TYPES(s):
369434

370435
def setargs(self, swap_device=None, **kwargs):
    """Collect block-swap options into a config dict.

    Records "swap_device" only when one was actually selected. The
    unconditional `str(swap_device)` introduced by this commit stores the
    literal string "None" for the default, which makes the MultiGPU
    sampler treat block swap as multi-GPU-enabled and then fail on
    `torch.device("None")` — so the None guard is required.
    """
    block_swap_config = dict(kwargs)
    if swap_device is not None:
        block_swap_config["swap_device"] = str(swap_device)
    return (block_swap_config,)
375439

376440
class WanVideoImageToVideoEncode:
@@ -622,86 +686,3 @@ def decode(self, vae, load_device, samples, enable_vae_tiling, tile_x, tile_y, t
622686
return (decode,)
623687

624688

625-
class WanVideoSampler:
626-
@classmethod
627-
def INPUT_TYPES(s):
628-
# Get original inputs and add our device input
629-
original_types = NODE_CLASS_MAPPINGS["WanVideoSampler"].INPUT_TYPES()
630-
original_types["required"]["compute_device"] = ("MULTIGPUDEVICE",)
631-
return original_types
632-
633-
RETURN_TYPES = ("LATENT", "LATENT",)
634-
RETURN_NAMES = ("samples", "denoised_samples",)
635-
FUNCTION = "process"
636-
CATEGORY = "multigpu/WanVideoWrapper"
637-
DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model"
638-
639-
def process(self, model, compute_device, **kwargs):
640-
from . import set_current_device
641-
logger.info(
642-
f"[MultiGPU WanVideoSampler] Received request to process on: {compute_device}"
643-
)
644-
645-
patcher = model
646-
transformer = None
647-
if hasattr(patcher, "model"):
648-
transformer = getattr(patcher.model, "diffusion_model", None)
649-
650-
if compute_device:
651-
target_device = torch.device(compute_device)
652-
set_current_device(target_device)
653-
else:
654-
target_device = mm.get_torch_device()
655-
656-
normalized_swap_device = None
657-
transformer_options = {}
658-
if hasattr(patcher, "model_options"):
659-
transformer_options = patcher.model_options.get("transformer_options", {})
660-
block_swap_args = transformer_options.get("block_swap_args") if transformer_options else None
661-
if block_swap_args:
662-
swap_label = block_swap_args.get("resolved_swap_device") or block_swap_args.get("swap_device")
663-
if swap_label:
664-
try:
665-
normalized_swap_device = torch.device(str(swap_label))
666-
except (TypeError, ValueError):
667-
logger.warning(
668-
"[MultiGPU WanVideoSampler] Invalid swap device '%s', leaving sampler offload unchanged",
669-
swap_label,
670-
)
671-
normalized_swap_device = None
672-
673-
original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]()
674-
sampler_module = inspect.getmodule(original_sampler)
675-
676-
original_module_device = None
677-
original_module_offload = None
678-
had_sampler_offload_attr = False
679-
if sampler_module is not None:
680-
original_module_device = getattr(sampler_module, "device", None)
681-
had_sampler_offload_attr = hasattr(sampler_module, "offload_device")
682-
original_module_offload = getattr(sampler_module, "offload_device", None)
683-
setattr(sampler_module, "device", target_device)
684-
if normalized_swap_device is not None:
685-
setattr(sampler_module, "offload_device", normalized_swap_device)
686-
elif compute_device == "cpu":
687-
setattr(sampler_module, "offload_device", target_device)
688-
else:
689-
logger.error("[MultiGPU WanVideoSampler] Unable to resolve sampler module for device patching.")
690-
691-
if transformer is not None and normalized_swap_device is not None:
692-
transformer.offload_device = normalized_swap_device
693-
transformer.cache_device = normalized_swap_device
694-
695-
try:
696-
return original_sampler.process(model=patcher, **kwargs)
697-
finally:
698-
if sampler_module is not None:
699-
if original_module_device is not None:
700-
setattr(sampler_module, "device", original_module_device)
701-
if had_sampler_offload_attr:
702-
setattr(sampler_module, "offload_device", original_module_offload)
703-
else:
704-
try:
705-
delattr(sampler_module, "offload_device")
706-
except AttributeError:
707-
pass

0 commit comments

Comments
 (0)