fix(finetuning-asr): auto-detect device and attention implementation in LoRA inference #353
Open
voidborne-d wants to merge 1 commit into
Conversation
…in LoRA inference

The inference_lora.py script hardcoded flash_attention_2 and defaulted dtype based only on a cpu/not-cpu split. This made it impossible to run on MPS (Apple Silicon) or XPU without manual code edits, and non-CUDA users hit a hard failure at model load because flash_attention_2 requires CUDA.

Changes:
- Add --attn_implementation flag with 'auto' default that picks flash_attention_2 when CUDA + flash_attn are available, otherwise sdpa.
- Broaden --device choices to include mps/xpu/auto and pick the best available device by default (cuda > xpu > mps > cpu), matching the pattern already used in demo/vibevoice_asr_inference_from_file.py.
- Use float32 for cpu/mps/xpu (required) and bfloat16 for cuda.
- Thread attn_implementation through load_lora_model so callers can override explicitly.

No behaviour change for users on CUDA with flash_attn installed.
Problem
`finetuning-asr/inference_lora.py` hardcodes `attn_implementation="flash_attention_2"` and splits dtype as `bfloat16` vs `float32` on a cpu-only check. Two concrete symptoms for non-CUDA users running a fine-tuned LoRA:

- On Apple Silicon (`--device mps`): `from_pretrained(..., attn_implementation="flash_attention_2")` blows up at model load, since flash attention 2 requires CUDA (see the snippet below).
- The `--device` argparse option has no `choices=`, so typos silently make it through to `.to(device)` with confusing tracebacks.

The sibling batch-inference demo at `demo/vibevoice_asr_inference_from_file.py` already solves this with an `auto` path. This PR brings the LoRA inference script to parity.
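For concreteness, a minimal reproduction of the failing pattern on a non-CUDA machine. The model class and checkpoint path are placeholders, not the exact lines from the script:

```python
import torch
from transformers import AutoModelForCausalLM  # stand-in for the script's model class

# The hardcoded pattern: flash_attention_2 is requested unconditionally,
# so on an MPS/XPU/CPU-only machine from_pretrained raises at load time
# (flash-attn kernels require CUDA) before .to(device) is ever reached.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/base-checkpoint",                # placeholder
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
model.to("mps")
```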
Changes

- `load_lora_model()` now accepts `attn_implementation` and threads it through `from_pretrained`; an MPS load-then-move comment is added inline.
- `main()` gains an `--attn_implementation` flag (`flash_attention_2|sdpa|eager|auto`, default `auto`). When `auto`, CUDA plus an importable `flash_attn` resolves to `flash_attention_2`; otherwise it falls back to `sdpa` with a log line.
- `--device` now restricts to `{cuda, cpu, mps, xpu, auto}` and defaults to the best available (cuda > xpu > mps > cpu), the same pattern as `vibevoice_asr_inference_from_file.py`.
- Dtype follows device: `torch.float32` for `cpu`/`mps`/`xpu` (required), `torch.bfloat16` for `cuda`. A condensed sketch of the combined selection logic follows this list.
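A condensed sketch of that selection logic, under the assumption that it mirrors the `auto` path in `vibevoice_asr_inference_from_file.py`; helper names here are illustrative, and the resolved values are then passed through `load_lora_model` into `from_pretrained`:

```python
import importlib.util
import logging

import torch

logger = logging.getLogger(__name__)

def pick_device(requested: str = "auto") -> str:
    # Best-available order from this PR: cuda > xpu > mps > cpu.
    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def pick_attn_implementation(requested: str, device: str) -> str:
    # 'auto' resolves to flash_attention_2 only when CUDA is in use and
    # the flash_attn package is importable; otherwise fall back to sdpa.
    if requested != "auto":
        return requested
    if device == "cuda" and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    logger.info("flash_attention_2 unavailable, falling back to sdpa")
    return "sdpa"

def pick_dtype(device: str) -> torch.dtype:
    # float32 is required on cpu/mps/xpu; bfloat16 is used on cuda.
    return torch.bfloat16 if device == "cuda" else torch.float32
```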
Behaviour for CUDA users with `flash_attn` installed is unchanged.

Verification
Local smoke test with `load_lora_model` stubbed out (no GPU needed): `python finetuning-asr/inference_lora.py --help` reflects the new flag/choices. A hypothetical scripted version of that check is sketched below.
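For reference, a one-file version of that smoke test (assumes it is run from the repo root; no weights or GPU are touched):

```python
import subprocess
import sys

# Invoke the script's --help and confirm the new flag and device
# choices are advertised in the argparse output.
out = subprocess.run(
    [sys.executable, "finetuning-asr/inference_lora.py", "--help"],
    capture_output=True, text=True, check=True,
).stdout
assert "--attn_implementation" in out
assert all(dev in out for dev in ("mps", "xpu", "auto"))
print("OK: help text advertises the new flags/choices")
```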