Important: Do NOT commit or push changes without explicit user permission.
First time building TXL wheel for Linux:

```bash
# 1. Initialize git submodules
git submodule update --init --recursive

# 2. Download LLVM (if not already present)
# Place llvm-7d5de303-ubuntu-x64/ in project root

# 3. Build Docker image and create wheel
./tools/build-wheel-docker.sh -n

# Wheel will be in: output/txl-3.5.1-cp312-cp312-manylinux_2_35_x86_64.whl
```

Build flags:
- `-n` - Apply TXL patches to Triton (run once, first time only)
- `-j N` - Number of parallel jobs (default: 8)
- `-c` - Clean build directories before build
- `--no-cache` - Rebuild Docker image without cache
After making code changes, use incremental rebuild (much faster):

```bash
# Fast incremental rebuild (uses ninja, only rebuilds changed files)
./tools/build-wheel-docker.sh -r

# Rebuild with clean (full rebuild)
./tools/build-wheel-docker.sh -r -c
```

Notes:
- Conda environment is persisted in the `txl-conda/` directory
- Build artifacts are persisted in `thirdparty/triton/build/`
- Uses clang by default (less memory than gcc)
Test TXL wheel on Modal's cloud H100 GPUs:

```bash
# Run test with default volume (txl-dump)
./tools/modal_tests.sh flash_attention.py
./tools/modal_tests.sh mla_decoding.py
./tools/modal_tests.sh nsa_prefill.py

# Run with custom test name
./tools/modal_tests.sh flash_attention.py my-test

# Run with custom volume name
./tools/modal_tests.sh flash_attention.py my-test txl-dump

# Output files:
# - docker/dumps/{test_name}_{timestamp}.log - Console output
# - docker/dumps/{test_name}_{timestamp}/ - Dump files (kernel caches)
```

Available Modal test scripts:
- `docker/flash_attention.py` - Flash attention benchmark
- `docker/mla_decoding.py` - MLA decoding benchmark
- `docker/nsa_prefill.py` - NSA prefill benchmark (1800s timeout)
Pass debug environment variable to Modal container:

```bash
# Run with TXLGPU pipeliner debug
TRITON_LLVM_DEBUG_ONLY=txlgpu-pipeliner ./tools/modal_tests.sh nsa_prefill.py debug-test txl-dump
```

The debug output will be in `docker/dumps/{test_name}_{timestamp}.log`.

Notes:
- All tests save dump files to the Modal volume `txl-dump`
- Dump files are automatically downloaded after the test completes
- Use the `--force` flag in `volume get` to overwrite existing local directories
```bash
TRITON_LLVM_DEBUG_ONLY="triton-gpu-taskid-propagate" \
TRITON_KERNEL_DUMP=1 \
TRITON_DUMP_DIR=dump \
TRITON_ALWAYS_COMPILE=1 \
python python/txl/tests/fused-attention.py
```

```bash
CUDA_COREDUMP_SHOW_PROGRESS=1 \
CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \
CUDA_LAUNCH_BLOCKING=1 \
cuda-gdb
```

If build runs out of memory:
```bash
# Reduce parallel jobs
./tools/build-wheel-docker.sh -j 4
```

If you get 'GLIBCXX_3.4.30' not found:

```bash
conda install -c conda-forge gcc=12.1.0
```

- Don't push for every change - Only push when user explicitly asks
- Use `./tools/build-wheel-docker.sh -r` for incremental rebuilds after code changes
- Use `./tools/build-wheel-docker.sh -r -c` for clean rebuilds when build issues occur
- The `-n` flag should only be used once when setting up the project for the first time
- LLVM and conda directories are excluded from git (too large)
This repo has two copies of Triton code:
- `thirdparty/triton/` - Git submodule (main Triton codebase) - MODIFY HERE
- `patch/triton/` - TXL patches applied to Triton - DO NOT MODIFY
Correct workflow:
1. Edit code in `thirdparty/triton/` (the submodule)
2. Build and test with `./tools/build-wheel-docker.sh -r`
3. Repeat steps 1-2 until fix is verified
4. Only when explicitly requested by the user, copy changes to `patch/triton`: `bash tools/cp_from_triton.sh`
5. Commit the patch changes
1. Edit in submodule (`thirdparty/triton`):

   ```bash
   # Make changes to code in thirdparty/triton/
   ```

2. Build and test:

   ```bash
   ./tools/build-wheel-docker.sh -r
   # Run tests
   ```

3. Repeat until fix is verified

4. Copy to `patch/triton` (only when user requests):

   ```bash
   bash tools/cp_from_triton.sh
   ```

5. Commit (only when user requests)
- `tools/cp_to_triton.sh` - Copy patch/triton → thirdparty/triton
- `tools/cp_from_triton.sh` - Copy thirdparty/triton → patch/triton
- `tools/diff_triton.py` - Compare patch/triton vs thirdparty/triton

Notes:
- Remove trailing slashes from `cp -r` commands to avoid issues
- Submodule changes are tracked separately from the main repo
- Use `.gitignore` patterns like `llvm-*`, `txl-conda/` for large build artifacts
- Never modify patch/triton directly - only copy from thirdparty/triton
When encountering `RuntimeError: PassManager::run failed`, follow this workflow:

Run the test with the TXLGPU pipeliner debug flag to see which stage fails:

```bash
# Set the debug env var BEFORE running the test script
TRITON_LLVM_DEBUG_ONLY=txlgpu-pipeliner ./tools/modal_tests.sh nsa_prefill.py debug-test txl-dump
```

The log will show stages like:

```
[txlgpu-pipeliner]: SoftwarePipeliner After SmemAllocs
[txlgpu-pipeliner]: DONE
[txlgpu-pipeliner]: SoftwarePipeliner After TmemAllocs
[txlgpu-pipeliner]: DONE
...
[txlgpu-pipeliner]: SoftwarePipeliner After lowerLoads
[txlgpu-pipeliner]: DONE
python: ...Assertion failed...
```

The crash happens AFTER the last DONE printed.
TXLGPU SoftwarePipeliner.cpp stages (in order):
- After SmemAllocs
- After TmemAllocs
- After Removing RedundantTMEMAllocs
- After Mbars
- After lowerLoads ← Crash happens after this
- After lowerSmemLoadStores
- After MemDesc
- After lowerDotXOps
- After wgmma
The bug is in the pass between the last printed stage and the next stage.
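Given log lines in the format above, finding the last stage that printed DONE can be scripted rather than read by eye. A minimal sketch (the helper name is hypothetical, and the log format is assumed from the example output, not from repo tooling):

```python
import re

def last_completed_stage(log_text):
    """Return the last pipeliner stage that was followed by a DONE line."""
    stage, last_done = None, None
    for line in log_text.splitlines():
        m = re.search(r"\[txlgpu-pipeliner\]: SoftwarePipeliner (After .+)", line)
        if m:
            stage = m.group(1)
        elif "[txlgpu-pipeliner]: DONE" in line:
            last_done = stage
    return last_done

log = """[txlgpu-pipeliner]: SoftwarePipeliner After SmemAllocs
[txlgpu-pipeliner]: DONE
[txlgpu-pipeliner]: SoftwarePipeliner After lowerLoads
[txlgpu-pipeliner]: DONE
python: ...Assertion failed..."""
print(last_completed_stage(log))  # After lowerLoads
```

The stage after the returned one is the failing pass to inspect.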
Edit the failing pass in thirdparty/triton/third_party/nvidia/lib/Dialect/TXLGPU/Transforms/SoftwarePipeliner.cpp:
```cpp
void lowerSmemLoadStores(ModuleOp moduleOp) {
  LDBG("[DEBUG] lowerSmemLoadStores: Processing SmemLoadOps\n");
  moduleOp->walk([&](tt::SmemLoadOp op) {
    lowerSmemLoad(op);
  });
  LDBG("[DEBUG] lowerSmemLoadStores: Processing SmemStoreOps\n");
  // ... add more debug prints
}
```

Then rebuild:

```bash
./tools/build-wheel-docker.sh -r
```

The log file will be at `docker/dumps/{test_name}_{timestamp}.log`. Look for:
- `[txlgpu-pipeliner]:` - pipeliner stages
- `[DEBUG] lowerSmemLoadStores:` - our custom debug prints
- `python: ...Assertion failed` - the actual error
To analyze the IR that causes the crash:
1. Find line numbers in the log:

   ```bash
   grep -n "SoftwarePipeliner After lowerLoads\|DONE" docker/dumps/{test_name}.log
   ```

2. Extract the MLIR between stages:

   ```bash
   # Extract lines between "After lowerLoads" and "DONE"
   sed -n '5737,6843p' docker/dumps/{test_name}.log > docker/dumps/lowerLoads.mlir
   ```

3. Upload to a gist for analysis:

   ```bash
   gh gist create docker/dumps/lowerLoads.mlir --public -d "MLIR after lowerLoads"
   ```

Common issue: DenseElementsAttr type mismatch - MLIR tries to create a constant whose float attribute type doesn't match the tensor element type.

Look for:
- Operations with operands of mixed types (bf16 vs f32)
- Constants where the value type doesn't match the tensor element type
- e.g., `dense<0xFF800000>` with type `tensor<64xf32>` - the hex value is negative infinity in the f32 bit pattern
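The bit-pattern claim is easy to verify outside the compiler with a quick Python check:

```python
import math
import struct

# Reinterpret the 32-bit pattern 0xFF800000 as an IEEE-754 float32:
# sign=1, exponent=0xFF (all ones), mantissa=0 -> negative infinity
value = struct.unpack(">f", bytes.fromhex("FF800000"))[0]
print(value)  # -inf
```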
In practice:
- Crash happened after `SoftwarePipeliner After lowerLoads` → `DONE`
- This means the bug is in the `lowerSmemLoadStores` function
- Error: `floatAttr.getType() == eltType` assertion failed
- The issue was in how `getRegType()` determines types for SmemLoadOp - the wrong type caused DenseElementsAttr creation to fail
TXL provides debug utilities in TXLUtility.h:
```cpp
#include "triton/Analysis/TXLUtility.h"

txlDebugMsg("message", operation);
txlDebugMsg("message", value);
txlDebugMsg("message", type);
txlDebugMsg("message", SmallVector<Value>{...});
txlDebugMsg("message", SmallVector<Operation*>{...});
```

When replacing frag_smem_load with smem_load in the NSA kernel, compilation fails with encoding mismatch errors.
1. Python API difference:
   - `smem_load(mem_desc, layout=None)` - layout is optional, can be None
   - `frag_smem_load(mem_desc, shape, layout)` - layout is required

2. C++ backend behavior:
   - When `layout` is None: Python creates a dummy `regType` (`tensor<1x1xi32>`) and sets `txl.with_reg_type = 0`
   - When `layout` is provided: Python creates the real regType and sets `txl.with_reg_type = 1`

3. RemoveLayoutConversions pass:
   - `FragSmemLoadOp` has a canonicalization in Ops.cpp: `cvt(frag_smem_load) -> frag_smem_load`
   - `SmemLoadOp` did NOT have this canonicalization
   - This causes redundant layout conversion paths to be created

4. TXLGPUPipeline crash:
   - When `with_reg_type = 0`, `lowerSmemLoad` uses the dummy regType, which has NO encoding
   - This causes an `llvm::dyn_cast<DistributedEncodingTrait>` assertion failure
Fixes:

1. Add canonicalization in Ops.cpp (fold ConvertLayout into SmemLoad):

   ```cpp
   // In lib/Dialect/TritonGPU/IR/Ops.cpp
   // cvt(smem_load) -> smem_load.
   if (auto smemLoad = dyn_cast<SmemLoadOp>(arg)) {
     rewriter.setInsertionPoint(arg);
     auto newOp = rewriter.replaceOpWithNewOp<SmemLoadOp>(
         op, op->getResult(0).getType(), smemLoad.getSrc(),
         smemLoad.getRegType(), smemLoad.getCtaId());
     if (auto attr = smemLoad->getAttrOfType<IntegerAttr>("txl.with_reg_type")) {
       newOp->setAttr("txl.with_reg_type", attr);
     }
     return success();
   }
   ```

2. Fix lowerSmemLoad in SoftwarePipeliner.cpp:

   ```cpp
   // When with_reg_type = 0, use result type instead of dummy regType
   if (withRegType == 0) {
     retType = op.getResult().getType(); // Has actual encoding
   }
   ```
Use `diff_mode="ttgir"` and `diff_select=N` in `txl.jit` to see the IR at each pass:

```python
@txl.jit(diff_mode="ttgir", diff_select=10, log_dir="/workspace/dump/smem/")
def txl_mla0(...):
    ...
```

- `diff_select=10` shows pass 10's diff
- SoftwarePipeliner is pass 22
- Binary search between passes to find which one introduces the bug
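The binary search over pass indices can be sketched as follows; `pass_is_broken` is a hypothetical predicate standing in for "rerun with `diff_select=N` and inspect the diff":

```python
def first_broken_pass(lo, hi, pass_is_broken):
    """Find the first pass index whose output IR shows the bug.

    Assumes the bug, once introduced by some pass, persists in all
    later passes (a monotonic predicate), so binary search applies.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pass_is_broken(mid):
            hi = mid        # bug already present at mid
        else:
            lo = mid + 1    # bug introduced after mid
    return lo

# Mock predicate: pretend the bug first appears at pass 22
print(first_broken_pass(0, 30, lambda n: n >= 22))  # 22
```

Each probe costs one compile-and-inspect cycle, so log2(num_passes) probes instead of a linear scan.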
- `thirdparty/triton/lib/Dialect/TritonGPU/IR/Ops.cpp` - Add SmemLoadOp canonicalization
- `thirdparty/triton/third_party/nvidia/lib/Dialect/TXLGPU/Transforms/SoftwarePipeliner.cpp` - Fix lowerSmemLoad
This section documents the conventions and notation used in CuTeDSL dense GEMM kernels for Blackwell B200 GPU.
In NVIDIA CUTLASS (especially CuTe), variable names like `tCrC` follow a rigorous naming convention. The name encodes four pieces of information: `<prefix><partitioner><storage><data_role>`.

| Component | Symbol | Meaning |
|---|---|---|
| Prefix | `t` | Thread-level tensor (this thread's private slice) |
| Partitioner | `C`, `A`, `B` | Thread mapping rules (tC, tA, tB) |
| Storage | `r`, `s`, `g` | r=Register, s=Shared Memory, g=Global Memory |
| Data Role | `A`, `B`, `C` | GEMM operands: A=first factor, B=second factor, C=accumulator |
| Variable | Meaning |
|---|---|
| `tCsA` | Thread's view of A matrix in Shared Memory (partitioned by C rules) |
| `tCrA` | Thread's fragment of A matrix in Registers (for MMA compute) |
| `tAgA` | Thread's view of A matrix in Global Memory (for TMA loads) |
| `tCrC` | Thread's accumulator fragment in Registers |
| `tCgA` | Thread's partition of A for MMA (Global memory view) |
| `tDtC` | TMEM source tensor for epilogue copy |
| `tDgC` | Global memory destination tensor for epilogue copy |
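The decoding rule can be illustrated with a small Python sketch. The lookup tables paraphrase the tables above; the `D` partitioner and `t` storage letters are inferred from `tDtC`/`tDgC`, and the helper itself is hypothetical, for illustration only:

```python
PARTITIONER = {"C": "MMA C thread-mapping rules", "A": "TMA A rules",
               "B": "TMA B rules", "D": "epilogue copy rules"}
STORAGE = {"r": "registers", "s": "shared memory",
           "g": "global memory", "t": "tensor memory (TMEM)"}
ROLE = {"A": "operand A", "B": "operand B", "C": "accumulator C"}

def decode_cute_name(name):
    """Decode a CuTe-style name t<partitioner><storage><role>, e.g. tCrA."""
    assert len(name) == 4 and name[0] == "t", "expected t<partitioner><storage><role>"
    _, part, store, role = name
    return (f"thread-level slice, partitioned by {PARTITIONER[part]}, "
            f"stored in {STORAGE[store]}, holding {ROLE[role]}")

print(decode_cute_name("tCrA"))
# thread-level slice, partitioned by MMA C thread-mapping rules,
# stored in registers, holding operand A
```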
This section documents the custom notation used in the dense_gemm_2.py tutorial to describe tensor shapes at different levels.
| Term | Meaning |
|---|---|
| `per_mma_atom` | Elements within one MMA instruction (spatial) |
| `per_mma_tile` | Tile in the MMA grid (spatial) |
| `per_tma_atom` | Elements within one TMA copy instruction |
| `per_tma_tile` | Tile for TMA operations (spatial) |
| `per_wave` | Pipeline stages that can run in parallel (stages dimension) |
| `per_tide` | Full K loop (runtime iterations) |
| `per_tmem_atom` | Elements within one TMEM copy instruction |
| `per_tmem_tile` | Tile for TMEM operations |
```python
# SMEM tensor A: ((128,16),1,4,1)
# My Notation: per_mma_atom((128,16)), per_tma_tile(1,4), per_wave(1)
# - (128,16) = per_mma_atom (MMA instruction shape)
# - (1,4) = per_tma_tile (1 MMA_M_tile, 4 MMA_K_tiles)
# - 1 = per_wave (1 pipeline stage)

# MMA fragment A: (1,1,4,1)
# My Notation: per_mma_atom(1), per_tma_tile(1,4), per_wave(1)
# - 1 = 1 smem descriptor for whole block

# Accumulator: ((128,256),1,1)
# My Notation: per_mma_atom((128,256)), per_tma_tile(1,1)
# - (128,256) = MMA instruction produces 128x256 accumulator

# TMEM source: (((64,32),1),1,((1,4),1,1))
# My Notation: per_tmem_atom((64,32)), per_tmem_tile(1), per_mma_tile((1,4)), per_tma_tile(1,1)

# Register tensor: ((64,1),1)
# My Notation: per_tmem_atom((64,1)), per_tmem_tile(1)
```

The typical dense GEMM kernel follows this structure:
```python
# Allocate SMEM tensors
sA = smem.allocate_tensor(element_type=io_dtype, layout=a_smem_layout.outer, ...)
sB = smem.allocate_tensor(element_type=io_dtype, layout=b_smem_layout.outer, ...)

# Allocate TMEM for accumulator
tmem = utils.TmemAllocator(...)
tmem.allocate(num_tmem_cols)

# Tiled MMA
tiled_mma = cute.make_tiled_mma(op)

# TMA atoms
tma_atom_a = cute.make_tma_atom(...)
tma_atom_b = cute.make_tma_atom(...)

# Partition GMEM tensors for MMA
tCgA = thr_mma.partition_A(gA)
tCgB = thr_mma.partition_B(gB)
tCgC = thr_mma.partition_C(gC)

# Create MMA fragments (from SMEM)
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
tCtAcc = tiled_mma.make_fragment_C(acc_shape)

# Create TMA descriptors
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(tma_atom_a, ...)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(tma_atom_b, ...)

for k_tile_idx in range(num_k_tiles):
    # Issue TMA loads
    ab_empty = ab_producer.acquire_and_advance()
    cute.copy(tma_atom_a, tAgA[(None, ab_empty.count)], tAsA[(None, ab_empty.index)], ...)
    cute.copy(tma_atom_b, ...)

    # Execute MMA
    ab_full = ab_consumer.wait_and_advance()
    for k_block_idx in cutlass.range_constexpr(num_k_blocks):
        cute.gemm(tiled_mma, tCtAcc, tCrA[k_block_coord], tCrB[k_block_coord], tCtAcc)

# Sub-tiling for ILP
for i in cutlass.range(cute.size(tDtC, mode=[2])):
    cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)  # TMEM -> Reg
    tCrC.store(tCrAcc.load().to(io_dtype))                   # Convert dtype
    cute.autovec_copy(tCrC, tDgC[None, None, i])             # Reg -> GMEM
```

```bash
# Run default version (dense_gemm_2.py - simplified)
./tools/modal_tests.sh tutorials/cuteDSL/01_dense_gemm.py

# Available versions:
# - 0: dense_gemm_0.py (4-stage pipelining)
# - 1: dense_gemm_1.py (1-stage pipelining)
# - 2: dense_gemm_2.py (simplified with detailed comments)
# - original: dense_gemm.py (full-featured)
```

```
docker/tutorials/cuteDSL/
├── 01_dense_gemm.py        # Modal test script
└── blackwell/
    ├── dense_gemm.py       # Original full-featured
    ├── dense_gemm_0.py     # 4-stage pipelining
    ├── dense_gemm_1.py     # 1-stage pipelining
    └── dense_gemm_2.py     # Simplified with detailed comments

thirdparty/cutlass/python/CuTeDSL/cutlass/
├── utils/blackwell_helpers.py     # make_smem_layout_a/b
└── cute/nvgpu/cpasync/helpers.py  # tma_partition
```
Use `cute.printf` to verify tensor shapes during kernel execution:

```python
if tidx == 0:
    cute.printf("tensor shape: {}\n", tensor.shape)
```

Notes:
- `smem_desc` represents the whole block, not decomposed like MMA fragments
- `tma_partition` requires input tensors folded in shape `(Each_Iter, Num_Iters)`
- `subtile_cnt` in epilogue controls ILP (typically 4 for fp16)
This section documents the low-level mbarrier implementation in dense_gemm_2.py, which replaces the high-level PipelineTmaUmma API.
| PipelineOp | Mbarrier Type | Operation | expect_tx? |
|---|---|---|---|
| `PipelineOp.TmaLoad` | Full | `mbarrier_arrive_and_expect_tx(bytes)` | YES |
| `PipelineOp.TCGen05Mma` | Empty | `tcgen05.commit()` | NO |
```
ab_mbar_ptr: [ab_mbar_full, ab_mbar_empty]
                   │              │
                   ▼              ▼
               index 0        index 1
```
- Full mbarrier: Signals data is ready (TMA→MMA), uses expect_tx with bytes
- Empty mbarrier: Signals buffer is consumed (MMA→TMA), just sync signal
```python
phase = 1  # Toggle between 0 and 1

for k_tile_idx in range(num_k_tiles):
    # TMA Load: wait for empty buffer
    cute.arch.mbarrier_wait(ab_mbar_empty, phase)

    # TMA loads (single-stage: GMEM[k_tile_idx] -> SMEM[0])
    cute.copy(tma_atom_a, tAgA[(None, k_tile_idx)], tAsA[(None, 0)], tma_bar_ptr=ab_mbar_full)
    cute.copy(tma_atom_b, ...)

    # TMA arrives on full: elect_one + expect_tx
    with cute.arch.elect_one():
        cute.arch.mbarrier_arrive_and_expect_tx(ab_mbar_full, num_tma_copy_bytes)

    # MMA: wait for full buffer
    cute.arch.mbarrier_wait(ab_mbar_full, 1 - phase)

    # MMA compute
    cute.gemm(...)

    # MMA arrives on empty: elect_one + tcgen05.commit (NO expect_tx)
    with cute.arch.elect_one():
        cute.nvgpu.tcgen05.commit(ab_mbar_empty)

    # Toggle phase
    phase = 1 - phase
```

1. Always use `elect_one()` for arrive operations:
   - `mbarrier_arrive_and_expect_tx` needs elect_one
   - `tcgen05.commit` needs elect_one

2. Single-stage indexing (`ab_stages=1`):
   - SMEM index always 0 (only one buffer)
   - GMEM index uses `k_tile_idx`
   - Phase toggles for mbarrier coordination only

3. Accurate mbarrier wait:
   - Wait for the full mbarrier with `1 - phase` (the opposite phase)
   - This ensures proper synchronization between TMA and MMA
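The phase protocol in the loop above can be simulated in plain Python to sanity-check which phase each side waits on per iteration (a sketch independent of any GPU API):

```python
def phase_schedule(num_k_tiles):
    """Mirror the loop above: TMA waits on empty with `phase`,
    MMA waits on full with `1 - phase`, then the phase toggles."""
    phase = 1
    schedule = []
    for _ in range(num_k_tiles):
        schedule.append({"tma_wait_empty": phase, "mma_wait_full": 1 - phase})
        phase = 1 - phase
    return schedule

for step in phase_schedule(4):
    print(step)
# tma_wait_empty alternates 1,0,1,0 while mma_wait_full alternates 0,1,0,1,
# so the two sides always wait on opposite parities of the same iteration
```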
```python
# Signal MMA done
with cute.arch.elect_one():
    cute.nvgpu.tcgen05.commit(acc_mbar_ptr)

# Wait for MMA to complete
cute.arch.mbarrier_wait(acc_mbar_ptr, phase=0)
```