
fix(finetuning-asr): auto-detect device and attention implementation in LoRA inference #353

Open

voidborne-d wants to merge 1 commit into microsoft:main from voidborne-d:fix/lora-inference-device-compat

Conversation

@voidborne-d
Contributor

Problem

finetuning-asr/inference_lora.py hardcodes attn_implementation="flash_attention_2" and picks dtype with a bare cpu/not-cpu split (float32 on CPU, bfloat16 everywhere else). Concrete symptoms for non-CUDA users running a fine-tuned LoRA:

  • On Apple Silicon (--device mps): from_pretrained(..., attn_implementation="flash_attention_2") blows up at model load, because FlashAttention-2 requires CUDA.
  • On Intel XPU or CPU: same failure.
  • The --device argparse option has no choices=, so typos silently make it through to .to(device) with confusing tracebacks.

The sibling batch-inference demo at demo/vibevoice_asr_inference_from_file.py already solves this with an auto path. This PR brings the LoRA inference script to parity.

Changes

  • load_lora_model() now accepts attn_implementation and threads it through from_pretrained (see the loader sketch below); an inline comment documents the MPS load-then-move pattern.
  • main() gains a --attn_implementation flag (flash_attention_2|sdpa|eager|auto, default auto). When auto, CUDA plus an importable flash_attn resolves to flash_attention_2; otherwise it falls back to sdpa with a log line (see the detection sketch after this list).
  • --device now restricts to {cuda, cpu, mps, xpu, auto} and defaults to the best available (cuda > xpu > mps > cpu), same pattern as vibevoice_asr_inference_from_file.py.
  • dtype is torch.float32 for cpu/mps/xpu (required), torch.bfloat16 for cuda.
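
For readers skimming the diff, a minimal sketch of the selection logic described above, assuming only torch and the standard library. Function names here are illustrative, not the exact code in inference_lora.py:

    import importlib.util
    import torch

    def pick_device() -> str:
        # Best available device, in the same priority order as the script:
        # cuda > xpu > mps > cpu.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def pick_attn_implementation(device: str) -> str:
        # flash_attention_2 only when CUDA is up and flash_attn imports;
        # everything else falls back to sdpa.
        if device == "cuda" and importlib.util.find_spec("flash_attn"):
            return "flash_attention_2"
        return "sdpa"

    device = pick_device()
    attn = pick_attn_implementation(device)
    # bfloat16 only on cuda; cpu/mps/xpu need float32.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32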

Behaviour for CUDA users with flash_attn installed is unchanged.
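
And a hedged sketch of how the new parameter threads through model loading; AutoModel and PeftModel are stand-ins here, since the actual model class used by the script may differ:

    from peft import PeftModel
    from transformers import AutoModel

    def load_lora_model(base_path, lora_path, device, dtype,
                        attn_implementation="sdpa"):
        # attn_implementation is forwarded verbatim to from_pretrained,
        # so an explicit --attn_implementation overrides auto-detection.
        model = AutoModel.from_pretrained(
            base_path,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
        )
        model = PeftModel.from_pretrained(model, lora_path)
        # MPS: load on CPU first, then move, to avoid ops that fail when
        # weights are materialized directly on the mps device.
        return model.to(device)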

Verification

Local smoke test with load_lora_model stubbed out (no GPU needed):

=== device=cpu ===
Auto-detected attention implementation: sdpa
load_lora_model called with device=cpu dtype=torch.float32 attn=sdpa

=== device=mps ===
Auto-detected attention implementation: sdpa
load_lora_model called with device=mps dtype=torch.float32 attn=sdpa

=== device=cuda (no flash_attn) ===
flash_attn not installed, falling back to sdpa
Auto-detected attention implementation: sdpa

=== device=cpu --attn_implementation eager ===
(no auto-detect print; explicit eager flows through)

python finetuning-asr/inference_lora.py --help reflects the new flag/choices.
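
One way to reproduce that smoke test without a GPU is to monkeypatch the loader before calling main(). A rough sketch, assuming inference_lora can be imported as a module and that main() reads sys.argv (the stub's signature is an assumption, not the script's exact one):

    import sys
    sys.path.insert(0, "finetuning-asr")  # make inference_lora importable
    import inference_lora

    def _stub(base_path, lora_path, device, dtype, attn_implementation="sdpa"):
        # Stand-in for the real loader: just report what it was called with.
        print(f"load_lora_model called with device={device} "
              f"dtype={dtype} attn={attn_implementation}")

    inference_lora.load_lora_model = _stub
    sys.argv = ["inference_lora.py", "--device", "cpu"]  # plus the usual args
    inference_lora.main()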

fix(finetuning-asr): auto-detect device and attention implementation in LoRA inference

The inference_lora.py script hardcoded flash_attention_2 and defaulted
dtype based only on a cpu/not-cpu split. This made it impossible to run
on MPS (Apple Silicon) or XPU without manual code edits, and non-CUDA
users hit a hard failure at model load because flash_attention_2
requires CUDA.

Changes:
- Add --attn_implementation flag with 'auto' default that picks
  flash_attention_2 when CUDA + flash_attn are available, otherwise
  sdpa.
- Broaden --device choices to include mps/xpu/auto and pick the best
  available device by default (cuda > xpu > mps > cpu), matching the
  pattern already used in demo/vibevoice_asr_inference_from_file.py.
- Use float32 for cpu/mps/xpu (required) and bfloat16 for cuda.
- Thread attn_implementation through load_lora_model so callers can
  override explicitly.

No behaviour change for users on CUDA with flash_attn installed.
