Skip to content

Commit c0ec424

Browse files
Merge pull request #128 from vllm-project/lwilkinson/update-fa4
Update FA4
2 parents 2921022 + 95e93d2 commit c0ec424

85 files changed

Lines changed: 15013 additions & 8696 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/scripts/bump_beta_tag.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python3
2+
"""Resolve the next FA4 beta tag and optionally create + push it.
3+
4+
Usage:
5+
python bump_beta_tag.py # dry-run by default
6+
python bump_beta_tag.py --push # create and push the tag
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import argparse
12+
import os
13+
import re
14+
import subprocess
15+
import sys
16+
17+
TAG_PATTERN = re.compile(r"^(fa4-v.+\.beta)(\d+)$")
18+
19+
20+
def git(*args: str) -> str:
21+
result = subprocess.run(
22+
["git", *args], capture_output=True, text=True, check=True
23+
)
24+
return result.stdout.strip()
25+
26+
27+
def get_beta_tags() -> list[tuple[str, int]]:
28+
raw = git("tag", "-l", "fa4-v*.beta*")
29+
if not raw:
30+
return []
31+
tags = []
32+
for line in raw.splitlines():
33+
m = TAG_PATTERN.match(line.strip())
34+
if m:
35+
tags.append((line.strip(), int(m.group(2))))
36+
return sorted(tags, key=lambda t: t[1])
37+
38+
39+
def tag_exists(tag: str) -> bool:
40+
result = subprocess.run(
41+
["git", "rev-parse", tag], capture_output=True, text=True
42+
)
43+
return result.returncode == 0
44+
45+
46+
def set_github_output(key: str, value: str) -> None:
47+
path = os.environ.get("GITHUB_OUTPUT")
48+
if path:
49+
with open(path, "a") as f:
50+
f.write(f"{key}={value}\n")
51+
52+
53+
def main() -> None:
54+
parser = argparse.ArgumentParser()
55+
parser.add_argument("--push", action="store_true", help="Create and push the tag (default: dry-run)")
56+
args = parser.parse_args()
57+
58+
tags = get_beta_tags()
59+
if not tags:
60+
print("::error::No existing fa4-v*.beta* tags found", file=sys.stderr)
61+
sys.exit(1)
62+
63+
latest_tag, latest_num = tags[-1]
64+
next_num = latest_num + 1
65+
prefix = TAG_PATTERN.match(latest_tag)
66+
if prefix is None:
67+
print(f"::error::Latest tag {latest_tag!r} no longer matches pattern", file=sys.stderr)
68+
sys.exit(1)
69+
next_tag = f"{prefix.group(1)}{next_num}"
70+
71+
already_exists = tag_exists(next_tag)
72+
73+
if already_exists:
74+
print(f"Tag {next_tag} already exists, reusing it")
75+
else:
76+
print(f"Bumping: {latest_tag} -> {next_tag}")
77+
78+
set_github_output("next_tag", next_tag)
79+
80+
if args.push and not already_exists:
81+
try:
82+
git("tag", next_tag)
83+
git("push", "origin", next_tag)
84+
except subprocess.CalledProcessError:
85+
if tag_exists(next_tag):
86+
print(f"Tag {next_tag} was created by a concurrent run, reusing it")
87+
else:
88+
raise
89+
else:
90+
print(f"Pushed {next_tag}")
91+
elif not args.push:
92+
print(f"Dry-run: would create and push {next_tag}")
93+
94+
95+
if __name__ == "__main__":
96+
main()

.github/workflows/README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# GitHub Workflow Tagging Flow
2+
3+
This repository uses separate tag lanes so FA2 and FA4 publishing do not collide.
4+
5+
## Release lanes
6+
7+
| Tag pattern | Workflow | Package target | Version source |
8+
| --- | --- | --- | --- |
9+
| `v*` | `.github/workflows/publish.yml` | Root package (`flash-attn`) | Root package version metadata |
10+
| `fa4-v*` | `.github/workflows/publish-fa4.yml` | `flash_attn/cute` package (`flash-attn-4`) | `setuptools-scm` with `fa4-v*` tags |
11+
12+
## How to publish
13+
14+
### FA2 / root package lane
15+
16+
1. Create a tag matching `v*` (example: `v2.9.0`).
17+
2. Push that tag.
18+
3. `publish.yml` creates a release, builds wheel matrix artifacts, and publishes to PyPI.
19+
20+
### FA4 / CUTE package lane
21+
22+
**Manual release**: create and push a tag matching `fa4-v*` (example: `fa4-v4.0.0`).
23+
24+
**Weekly beta**: `publish-fa4.yml` also runs every Wednesday at 08:00 UTC via cron. The scheduled or manual run creates and pushes the next `fa4-v*.beta*` tag, then continues in the same workflow run to build and publish that beta. Manual dispatch is restricted to the repository default branch so it cannot tag a feature branch commit. The pushed tag matches the `fa4-v*` trigger, but GitHub suppresses workflow runs for events created by `GITHUB_TOKEN`, so no recursive run is triggered.
25+
26+
| Week | Tag created | PyPI version |
27+
| --- | --- | --- |
28+
| 1 | `fa4-v4.0.0.beta5` | `4.0.0b5` |
29+
| 2 | `fa4-v4.0.0.beta6` | `4.0.0b6` |
30+
31+
To stop weekly betas: GitHub repo → Actions → "Publish flash-attn-4 to PyPI" → `···` menu → **Disable workflow**. Re-enable when ready to resume, or switch to manual tag pushes only by removing the `schedule` trigger. Users can still push a `fa4-v*.beta*` tag directly when they need to cut a beta outside the schedule.
32+
33+
## Guardrails
34+
35+
- Do not use `v*` tags for FA4 releases.
36+
- Do not use `fa4-v*` tags for FA2 releases.
37+
- Keep `flash_attn/cute/pyproject.toml` tag parsing in sync with the FA4 tag prefix.
38+
- The workflow filename (`publish-fa4.yml`) is part of the PyPI trusted publishing OIDC identity — do not rename without updating PyPI.

.github/workflows/_build.yml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ jobs:
7777

7878
- name: Install CUDA ${{ inputs.cuda-version }}
7979
if: ${{ inputs.cuda-version != 'cpu' }}
80-
uses: Jimver/cuda-toolkit@v0.2.29
80+
uses: Jimver/cuda-toolkit@v0.2.30
8181
id: cuda-toolkit
8282
with:
8383
cuda: ${{ inputs.cuda-version }}
@@ -93,14 +93,17 @@ jobs:
9393
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
9494
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
9595
pip install typing-extensions==4.12.2
96-
# We want to figure out the CUDA version to download pytorch
97-
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
98-
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
99-
# This code is ugly, maybe there's a better way to do this.
96+
# Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA
10097
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
101-
minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
102-
maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
103-
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
98+
available = { \
99+
'2.6': [118, 124, 126], \
100+
'2.7': [118, 126, 128], \
101+
'2.8': [126, 128, 129], \
102+
'2.9': [126, 128, 130], \
103+
'2.10': [126, 128, 130], \
104+
}[env['MATRIX_TORCH_VERSION']]; \
105+
sys_cuda = int(env['MATRIX_CUDA_VERSION']); \
106+
print(max(v for v in available if v <= sys_cuda))" \
104107
)
105108
# detect if we're on ARM
106109
if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then

.github/workflows/publish-fa4.yml

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
name: Publish flash-attn-4 to PyPI
2+
3+
on:
4+
push:
5+
tags:
6+
- 'fa4-v*'
7+
schedule:
8+
- cron: '0 8 * * 3' # Wednesday 08:00 UTC
9+
workflow_dispatch:
10+
11+
permissions:
12+
contents: write
13+
14+
jobs:
15+
prepare-release:
16+
runs-on: ubuntu-latest
17+
outputs:
18+
release_tag: ${{ steps.resolve-tag.outputs.release_tag }}
19+
steps:
20+
- name: Require default branch for manual runs
21+
if: github.event_name == 'workflow_dispatch'
22+
run: |
23+
if [ "${{ github.ref_name }}" != "${{ github.event.repository.default_branch }}" ]; then
24+
echo "::error::Run this workflow from ${{ github.event.repository.default_branch }} only"
25+
exit 1
26+
fi
27+
- uses: actions/checkout@v4
28+
if: github.event_name != 'push'
29+
with:
30+
ref: ${{ github.event.repository.default_branch }}
31+
fetch-depth: 0
32+
- uses: actions/setup-python@v5
33+
if: github.event_name != 'push'
34+
with:
35+
python-version: '3.12'
36+
- name: Bump beta tag
37+
if: github.event_name != 'push'
38+
id: bump
39+
run: python .github/scripts/bump_beta_tag.py --push
40+
- name: Resolve release tag
41+
id: resolve-tag
42+
run: |
43+
if [ "${{ github.event_name }}" = "push" ]; then
44+
echo "release_tag=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
45+
else
46+
echo "release_tag=${{ steps.bump.outputs.next_tag }}" >> "$GITHUB_OUTPUT"
47+
fi
48+
49+
build:
50+
needs: prepare-release
51+
runs-on: ubuntu-latest
52+
steps:
53+
- uses: actions/checkout@v4
54+
with:
55+
ref: ${{ needs.prepare-release.outputs.release_tag }}
56+
fetch-depth: 0
57+
- uses: actions/setup-python@v5
58+
with:
59+
python-version: '3.12'
60+
- name: Install build dependencies
61+
run: pip install build twine
62+
- name: Extract version from tag
63+
id: strip-prefix
64+
run: |
65+
TAG="${{ needs.prepare-release.outputs.release_tag }}"
66+
echo "version=${TAG#fa4-v}" >> "$GITHUB_OUTPUT"
67+
- name: Build package
68+
run: python -m build
69+
working-directory: flash_attn/cute
70+
env:
71+
SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.strip-prefix.outputs.version }}
72+
- name: Check package metadata
73+
run: twine check dist/*
74+
working-directory: flash_attn/cute
75+
- name: Store distribution packages
76+
uses: actions/upload-artifact@v4
77+
with:
78+
name: python-package-distributions
79+
path: flash_attn/cute/dist/
80+
81+
github-release:
82+
needs: [prepare-release, build]
83+
runs-on: ubuntu-latest
84+
steps:
85+
- name: Download distribution packages
86+
uses: actions/download-artifact@v4
87+
with:
88+
name: python-package-distributions
89+
path: dist/
90+
- name: Create GitHub Release
91+
uses: softprops/action-gh-release@v2
92+
with:
93+
tag_name: ${{ needs.prepare-release.outputs.release_tag }}
94+
files: dist/*
95+
generate_release_notes: true
96+
prerelease: ${{ contains(needs.prepare-release.outputs.release_tag, '.beta') }}
97+
98+
publish-to-pypi:
99+
needs: build
100+
runs-on: ubuntu-latest
101+
environment:
102+
name: pypi
103+
url: https://pypi.org/p/flash-attn-4
104+
permissions:
105+
id-token: write
106+
steps:
107+
- name: Download distribution packages
108+
uses: actions/download-artifact@v4
109+
with:
110+
name: python-package-distributions
111+
path: dist/
112+
- name: Publish to PyPI
113+
uses: pypa/gh-action-pypi-publish@release/v1

.github/workflows/publish.yml

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,13 @@ jobs:
2121
steps:
2222
- name: Get the tag version
2323
id: extract_branch
24-
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
24+
run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
2525
shell: bash
2626
- name: Create Release
27-
id: create_release
28-
uses: actions/create-release@v1
2927
env:
3028
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31-
with:
32-
tag_name: ${{ steps.extract_branch.outputs.branch }}
33-
release_name: ${{ steps.extract_branch.outputs.branch }}
29+
run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes
30+
shell: bash
3431

3532
build_wheels:
3633
name: Build Wheel
@@ -42,24 +39,33 @@ jobs:
4239
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4340
os: [ubuntu-22.04, ubuntu-22.04-arm]
4441
python-version: ["3.10", "3.11", "3.12", "3.13"]
45-
torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.1"]
46-
cuda-version: ["12.9.1"]
42+
torch-version: ["2.6.0", "2.7.1", "2.8.0", "2.9.1", "2.10.0"]
43+
cuda-version: ["12.9.1", "13.0.1"]
4744
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
4845
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
4946
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
5047
# when building without C++11 ABI and using it on nvcr images.
5148
cxx11_abi: ["FALSE", "TRUE"]
52-
include:
53-
- torch-version: "2.9.1"
54-
cuda-version: "13.0.2"
55-
python-version: "3.14"
56-
- torch-version: "2.10.0.dev20251108"
57-
cuda-version: "13.0.2"
5849
exclude:
59-
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
60-
# Pytorch < 2.5 does not support Python 3.13
61-
- torch-version: "2.4.0"
62-
python-version: "3.13"
50+
# CUDA 13.0 is only supported by PyTorch 2.9+
51+
- torch-version: "2.6.0"
52+
cuda-version: "13.0.1"
53+
- torch-version: "2.7.1"
54+
cuda-version: "13.0.1"
55+
- torch-version: "2.8.0"
56+
cuda-version: "13.0.1"
57+
# No aarch64 PyTorch wheels for 2.6.0
58+
- torch-version: "2.6.0"
59+
os: ubuntu-22.04-arm
60+
# PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE
61+
- torch-version: "2.7.1"
62+
cxx11_abi: "FALSE"
63+
- torch-version: "2.8.0"
64+
cxx11_abi: "FALSE"
65+
- torch-version: "2.9.1"
66+
cxx11_abi: "FALSE"
67+
- torch-version: "2.10.0"
68+
cxx11_abi: "FALSE"
6369
uses: ./.github/workflows/_build.yml
6470
with:
6571
runs-on: ${{ matrix.os }}

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
*.ncu-rep
2+
*.sass
3+
*.ptx
4+
*.cubin
5+
*.plk
26
.DS_store
37
.vscode
8+
worktrees/
49

510
# Byte-compiled / optimized / DLL files
611
__pycache__/
@@ -33,4 +38,4 @@ var/
3338
venv
3439

3540
# compile-time generated file
36-
flash_attn_config.py
41+
flash_attn_config.py

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55
path = csrc/composable_kernel
66
url = https://github.com/ROCm/composable_kernel.git
77
branch = amd-master
8+
[submodule "third_party/aiter"]
9+
path = third_party/aiter
10+
url = https://github.com/ROCm/aiter.git

0 commit comments

Comments
 (0)