|
18 | 18 | logger = logging.getLogger("MultiGPU") |
19 | 19 |
|
20 | 20 |
|
| 21 | +scheduler_list = [ |
| 22 | + "unipc", "unipc/beta", |
| 23 | + "dpm++", "dpm++/beta", |
| 24 | + "dpm++_sde", "dpm++_sde/beta", |
| 25 | + "euler", "euler/beta", |
| 26 | + "deis", |
| 27 | + "lcm", "lcm/beta", |
| 28 | + "res_multistep", |
| 29 | + "flowmatch_causvid", |
| 30 | + "flowmatch_distill", |
| 31 | + "flowmatch_pusa", |
| 32 | + "multitalk", |
| 33 | + "sa_ode_stable" |
| 34 | +] |
| 35 | + |
| 36 | +rope_functions = ["default", "comfy", "comfy_chunked"] |
| 37 | + |
21 | 38 | class WanVideoModelLoader: |
22 | 39 | @classmethod |
23 | 40 | def INPUT_TYPES(s): |
@@ -59,31 +76,109 @@ def INPUT_TYPES(s): |
59 | 76 | FUNCTION = "loadmodel" |
60 | 77 | CATEGORY = "multigpu/WanVideoWrapper" |
61 | 78 |
|
62 | | - def loadmodel(self, model, base_precision, compute_device, quantization, load_device, |
63 | | - compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, |
64 | | - vram_management_args=None, extra_model=None, vace_model=None, |
65 | | - fantasytalking_model=None, multitalk_model=None, fantasyportrait_model=None, |
66 | | - rms_norm_function="default"): |
| 79 | + def loadmodel(self, model, base_precision, compute_device, quantization, load_device, **kwargs): |
67 | 80 | from . import set_current_device |
68 | 81 |
|
69 | | - set_current_device(compute_device) |
70 | | - compute_device_to_be_patched = mm.get_torch_device() |
71 | | - |
72 | 82 | original_loader = NODE_CLASS_MAPPINGS["WanVideoModelLoader"]() |
73 | 83 | loader_module = inspect.getmodule(original_loader) |
74 | | - |
75 | 84 | original_module_device = loader_module.device |
76 | 85 |
|
77 | | - loader_module.device = compute_device_to_be_patched |
| 86 | + set_current_device(compute_device) |
| 87 | + compute_device_to_be_patched = mm.get_torch_device() |
78 | 88 |
|
79 | | - result = original_loader.loadmodel(model, base_precision, load_device, quantization, compile_args, attention_mode, block_swap_args, lora, vram_management_args, extra_model=extra_model, |
80 | | - vace_model=vace_model, fantasytalking_model=fantasytalking_model, multitalk_model=multitalk_model, fantasyportrait_model=fantasyportrait_model, rms_norm_function=rms_norm_function,) |
| 89 | + loader_module.device = compute_device_to_be_patched |
81 | 90 |
|
82 | | - loader_module.device = original_module_device |
| 91 | + result = original_loader.loadmodel(model, base_precision, load_device, quantization, **kwargs,) |
83 | 92 |
|
84 | 93 | patcher = result[0] |
85 | 94 |
|
86 | | - return (patcher, compute_device) |
| 95 | + try: |
| 96 | + return (patcher, compute_device) |
| 97 | + |
| 98 | + finally: |
| 99 | + loader_module.device = original_module_device |
| 100 | + |
| 101 | +class WanVideoSampler: |
| 102 | + @classmethod |
| 103 | + def INPUT_TYPES(s): |
| 104 | + return { |
| 105 | + "required": { |
| 106 | + "model": ("WANVIDEOMODEL",), |
| 107 | + "compute_device": ("MULTIGPUDEVICE",), |
| 108 | + "image_embeds": ("WANVIDIMAGE_EMBEDS", ), |
| 109 | + "steps": ("INT", {"default": 30, "min": 1}), |
| 110 | + "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), |
| 111 | + "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), |
| 112 | + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), |
| 113 | + "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}), |
| 114 | + "scheduler": (scheduler_list, {"default": "unipc",}), |
| 115 | + "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}), |
| 116 | + }, |
| 117 | + "optional": { |
| 118 | + "text_embeds": ("WANVIDEOTEXTEMBEDS", ), |
| 119 | + "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), |
| 120 | + "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), |
| 121 | + "feta_args": ("FETAARGS", ), |
| 122 | + "context_options": ("WANVIDCONTEXT", ), |
| 123 | + "cache_args": ("CACHEARGS", ), |
| 124 | + "flowedit_args": ("FLOWEDITARGS", ), |
| 125 | + "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}), |
| 126 | + "slg_args": ("SLGARGS", ), |
| 127 | + "rope_function": (rope_functions, {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. Chunked version has reduced peak VRAM usage when not using torch.compile"}), |
| 128 | + "loop_args": ("LOOPARGS", ), |
| 129 | + "experimental_args": ("EXPERIMENTALARGS", ), |
| 130 | + "sigmas": ("SIGMAS", ), |
| 131 | + "unianimate_poses": ("UNIANIMATE_POSE", ), |
| 132 | + "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ), |
| 133 | + "uni3c_embeds": ("UNI3C_EMBEDS", ), |
| 134 | + "multitalk_embeds": ("MULTITALK_EMBEDS", ), |
| 135 | + "freeinit_args": ("FREEINITARGS", ), |
| 136 | + "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}), |
| 137 | + "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}), |
| 138 | + "add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}), |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + RETURN_TYPES = ("LATENT", "LATENT",) |
| 143 | + RETURN_NAMES = ("samples", "denoised_samples",) |
| 144 | + FUNCTION = "process" |
| 145 | + CATEGORY = "multigpu/WanVideoWrapper" |
| 146 | + DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model" |
| 147 | + |
| 148 | + def process(self, model, compute_device, **kwargs): |
| 149 | + from . import set_current_device |
| 150 | + |
| 151 | + original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]() |
| 152 | + sampler_module = inspect.getmodule(original_sampler) |
| 153 | + |
| 154 | + original_module_device = sampler_module.device |
| 155 | + original_module_offload_device = sampler_module.offload_device |
| 156 | + |
| 157 | + set_current_device(compute_device) |
| 158 | + compute_device_to_be_patched = mm.get_torch_device() |
| 159 | + sampler_module.device = compute_device_to_be_patched |
| 160 | + |
| 161 | + transformer = model.model.diffusion_model |
| 162 | + transformer_options = model.model_options.get("transformer_options", {}) |
| 163 | + block_swap_args = transformer_options.get("block_swap_args") |
| 164 | + |
| 165 | + multi_gpu_block_swap = block_swap_args is not None and "swap_device" in block_swap_args |
| 166 | + offload_device_to_be_patched = None |
| 167 | + if multi_gpu_block_swap: |
| 168 | + swap_label = block_swap_args.get("swap_device") |
| 169 | + logger.info(f"[MultiGPU WanVideoWrapper][WanVideoSamplerMultiGPU] block swap enabled, swap device: {swap_label}") |
| 170 | + offload_device_to_be_patched = torch.device(str(swap_label)) |
| 171 | + sampler_module.offload_device = offload_device_to_be_patched |
| 172 | + |
| 173 | + if transformer is not None and offload_device_to_be_patched is not None: |
| 174 | + transformer.offload_device = offload_device_to_be_patched |
| 175 | + transformer.cache_device = offload_device_to_be_patched |
| 176 | + |
| 177 | + try: |
| 178 | + return original_sampler.process(model, **kwargs) |
| 179 | + finally: |
| 180 | + sampler_module.device = original_module_device |
| 181 | + sampler_module.offload_device = original_module_offload_device |
87 | 182 |
|
88 | 183 | class WanVideoTextEncode: |
89 | 184 | @classmethod |
@@ -117,20 +212,13 @@ def process(self, positive_prompt, negative_prompt, t5=None, load_device=None,fo |
117 | 212 | else: |
118 | 213 | device = "gpu" |
119 | 214 |
|
120 | | - if t5 is not None: |
121 | | - text_encoder = t5[0] |
122 | | - else: |
123 | | - text_encoder = None |
124 | | - |
125 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeMulitiGPU] current_device set to: {load_device}") |
126 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeMulitiGPU] device set to: {device}") |
| 215 | + text_encoder = t5[0] |
127 | 216 |
|
128 | 217 | original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncode"]() |
129 | 218 | prompt_embeds_dict = original_encoder.process(positive_prompt, negative_prompt, text_encoder, force_offload, model_to_offload, use_disk_cache, device) |
130 | 219 | return (prompt_embeds_dict) |
131 | 220 |
|
132 | 221 | def parse_prompt_weights(self, prompt): |
133 | | - """Extract text and weights from prompts with (text:weight) format""" |
134 | 222 | original_parser = NODE_CLASS_MAPPINGS["WanVideoTextEncode"]() |
135 | 223 | return original_parser.parse_prompt_weights(prompt) |
136 | 224 |
|
@@ -161,17 +249,13 @@ def INPUT_TYPES(s): |
161 | 249 | def loadmodel(self, model_name, precision, device=None, quantization="disabled"): |
162 | 250 | from . import set_current_device |
163 | 251 |
|
164 | | - if device is not None: |
165 | | - set_current_device(device) |
| 252 | + set_current_device(device) |
166 | 253 |
|
167 | 254 | if device == "cpu": |
168 | 255 | load_device = "offload_device" |
169 | 256 | else: |
170 | 257 | load_device = "main_device" |
171 | 258 |
|
172 | | - logger.info(f"[MultiGPU WanVideoWrapper][LoadWanVideoT5TextEncoder] current_device set to: {device}") |
173 | | - logger.info(f"[MultiGPU WanVideoWrapper][LoadWanVideoT5TextEncoder] load_device set to: {load_device}") |
174 | | - |
175 | 259 | original_loader = NODE_CLASS_MAPPINGS["LoadWanVideoT5TextEncoder"]() |
176 | 260 | text_encoder = original_loader.loadmodel(model_name, precision, load_device, quantization) |
177 | 261 |
|
@@ -210,17 +294,13 @@ def INPUT_TYPES(s): |
210 | 294 | def process(self, model_name, precision, positive_prompt, negative_prompt, quantization='disabled', use_disk_cache=True, load_device=None, extender_args=None): |
211 | 295 | from . import set_current_device |
212 | 296 |
|
213 | | - if load_device is not None: |
214 | | - set_current_device(load_device) |
| 297 | + set_current_device(load_device) |
215 | 298 |
|
216 | 299 | if load_device == "cpu": |
217 | 300 | device = "cpu" |
218 | 301 | else: |
219 | 302 | device = "gpu" |
220 | 303 |
|
221 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeCachedMulitiGPU] current_device set to: {load_device}") |
222 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeCachedMulitiGPU] device set to: {device}") |
223 | | - |
224 | 304 | original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncodeCached"]() |
225 | 305 | prompt_embeds_dict, negative_text_embeds, positive_prompt_out = original_encoder.process(model_name, precision, positive_prompt, negative_prompt, quantization, use_disk_cache, device, extender_args) |
226 | 306 |
|
@@ -250,21 +330,14 @@ def INPUT_TYPES(s): |
250 | 330 | def process(self, prompt, t5=None, load_device=None, force_offload=True, model_to_offload=None, use_disk_cache=False): |
251 | 331 | from . import set_current_device |
252 | 332 |
|
253 | | - if load_device is not None: |
254 | | - set_current_device(load_device) |
| 333 | + set_current_device(load_device) |
255 | 334 |
|
256 | 335 | if load_device == "cpu": |
257 | 336 | device = "cpu" |
258 | 337 | else: |
259 | 338 | device = "gpu" |
260 | 339 |
|
261 | | - if t5 is not None: |
262 | | - text_encoder = t5[0] |
263 | | - else: |
264 | | - text_encoder = None |
265 | | - |
266 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeSingleMulitiGPU] current_device set to: {load_device}") |
267 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTextEncodeSingleMulitiGPU] device set to: {device}") |
| 340 | + text_encoder = t5[0] |
268 | 341 |
|
269 | 342 | original_encoder = NODE_CLASS_MAPPINGS["WanVideoTextEncodeSingle"]() |
270 | 343 | prompt_embeds_dict = original_encoder.process(prompt, text_encoder, force_offload, model_to_offload, use_disk_cache, device) |
@@ -297,15 +370,11 @@ def INPUT_TYPES(s): |
297 | 370 | def loadmodel(self, model_name, load_device=None, precision="fp16", compile_args=None): |
298 | 371 | from . import set_current_device |
299 | 372 |
|
300 | | - if load_device is not None: |
301 | | - set_current_device(load_device) |
302 | | - |
303 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoVAELoaderMultiGPU] load_device set to: {load_device}") |
| 373 | + set_current_device(load_device) |
304 | 374 |
|
305 | 375 | original_loader = NODE_CLASS_MAPPINGS["WanVideoVAELoader"]() |
306 | 376 | vae_model = original_loader.loadmodel(model_name, precision, compile_args) |
307 | 377 |
|
308 | | - # Return both the VAE model AND the selected device for device propagation |
309 | 378 | return vae_model, load_device |
310 | 379 |
|
311 | 380 | class WanVideoTinyVAELoader: |
@@ -333,15 +402,11 @@ def INPUT_TYPES(s): |
333 | 402 | def loadmodel(self, model_name, load_device=None, precision="fp16", parallel=False): |
334 | 403 | from . import set_current_device |
335 | 404 |
|
336 | | - if load_device is not None: |
337 | | - set_current_device(load_device) |
338 | | - |
339 | | - logger.info(f"[MultiGPU WanVideoWrapper][WanVideoTinyVAELoader] load_device set to: {load_device}") |
| 405 | + set_current_device(load_device) |
340 | 406 |
|
341 | 407 | original_loader = NODE_CLASS_MAPPINGS["WanVideoTinyVAELoader"]() |
342 | 408 | vae_model = original_loader.loadmodel(model_name, precision, parallel) |
343 | 409 |
|
344 | | - # Return both the VAE model AND the selected device for device propagation |
345 | 410 | return vae_model, load_device |
346 | 411 |
|
347 | 412 |
|
@@ -369,8 +434,7 @@ def INPUT_TYPES(s): |
369 | 434 |
|
370 | 435 | def setargs(self, swap_device=None, **kwargs): |
371 | 436 | block_swap_config = dict(kwargs) |
372 | | - if swap_device is not None: |
373 | | - block_swap_config["swap_device"] = str(swap_device) |
| 437 | + block_swap_config["swap_device"] = str(swap_device) |
374 | 438 | return (block_swap_config,) |
375 | 439 |
|
376 | 440 | class WanVideoImageToVideoEncode: |
@@ -622,86 +686,3 @@ def decode(self, vae, load_device, samples, enable_vae_tiling, tile_x, tile_y, t |
622 | 686 | return (decode,) |
623 | 687 |
|
624 | 688 |
|
625 | | -class WanVideoSampler: |
626 | | - @classmethod |
627 | | - def INPUT_TYPES(s): |
628 | | - # Get original inputs and add our device input |
629 | | - original_types = NODE_CLASS_MAPPINGS["WanVideoSampler"].INPUT_TYPES() |
630 | | - original_types["required"]["compute_device"] = ("MULTIGPUDEVICE",) |
631 | | - return original_types |
632 | | - |
633 | | - RETURN_TYPES = ("LATENT", "LATENT",) |
634 | | - RETURN_NAMES = ("samples", "denoised_samples",) |
635 | | - FUNCTION = "process" |
636 | | - CATEGORY = "multigpu/WanVideoWrapper" |
637 | | - DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model" |
638 | | - |
639 | | - def process(self, model, compute_device, **kwargs): |
640 | | - from . import set_current_device |
641 | | - logger.info( |
642 | | - f"[MultiGPU WanVideoSampler] Received request to process on: {compute_device}" |
643 | | - ) |
644 | | - |
645 | | - patcher = model |
646 | | - transformer = None |
647 | | - if hasattr(patcher, "model"): |
648 | | - transformer = getattr(patcher.model, "diffusion_model", None) |
649 | | - |
650 | | - if compute_device: |
651 | | - target_device = torch.device(compute_device) |
652 | | - set_current_device(target_device) |
653 | | - else: |
654 | | - target_device = mm.get_torch_device() |
655 | | - |
656 | | - normalized_swap_device = None |
657 | | - transformer_options = {} |
658 | | - if hasattr(patcher, "model_options"): |
659 | | - transformer_options = patcher.model_options.get("transformer_options", {}) |
660 | | - block_swap_args = transformer_options.get("block_swap_args") if transformer_options else None |
661 | | - if block_swap_args: |
662 | | - swap_label = block_swap_args.get("resolved_swap_device") or block_swap_args.get("swap_device") |
663 | | - if swap_label: |
664 | | - try: |
665 | | - normalized_swap_device = torch.device(str(swap_label)) |
666 | | - except (TypeError, ValueError): |
667 | | - logger.warning( |
668 | | - "[MultiGPU WanVideoSampler] Invalid swap device '%s', leaving sampler offload unchanged", |
669 | | - swap_label, |
670 | | - ) |
671 | | - normalized_swap_device = None |
672 | | - |
673 | | - original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]() |
674 | | - sampler_module = inspect.getmodule(original_sampler) |
675 | | - |
676 | | - original_module_device = None |
677 | | - original_module_offload = None |
678 | | - had_sampler_offload_attr = False |
679 | | - if sampler_module is not None: |
680 | | - original_module_device = getattr(sampler_module, "device", None) |
681 | | - had_sampler_offload_attr = hasattr(sampler_module, "offload_device") |
682 | | - original_module_offload = getattr(sampler_module, "offload_device", None) |
683 | | - setattr(sampler_module, "device", target_device) |
684 | | - if normalized_swap_device is not None: |
685 | | - setattr(sampler_module, "offload_device", normalized_swap_device) |
686 | | - elif compute_device == "cpu": |
687 | | - setattr(sampler_module, "offload_device", target_device) |
688 | | - else: |
689 | | - logger.error("[MultiGPU WanVideoSampler] Unable to resolve sampler module for device patching.") |
690 | | - |
691 | | - if transformer is not None and normalized_swap_device is not None: |
692 | | - transformer.offload_device = normalized_swap_device |
693 | | - transformer.cache_device = normalized_swap_device |
694 | | - |
695 | | - try: |
696 | | - return original_sampler.process(model=patcher, **kwargs) |
697 | | - finally: |
698 | | - if sampler_module is not None: |
699 | | - if original_module_device is not None: |
700 | | - setattr(sampler_module, "device", original_module_device) |
701 | | - if had_sampler_offload_attr: |
702 | | - setattr(sampler_module, "offload_device", original_module_offload) |
703 | | - else: |
704 | | - try: |
705 | | - delattr(sampler_module, "offload_device") |
706 | | - except AttributeError: |
707 | | - pass |
0 commit comments