
Implement row-scaled NVFP4 fprop recipe #2931

Open

zianglih wants to merge 47 commits into NVIDIA:main from zianglih:fp4-per-token

Conversation

@zianglih zianglih (Contributor) commented Apr 27, 2026

Description


Implement a per-token (row-scaled) NVFP4 recipe, fprop only.
Currently, the row-scaled correction is applied by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.

The following tests passed on B200:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a row_scaled_activation field to the NVFP4 recipe; it can be turned on via the NVTE_NVFP4_ROW_SCALED_ACTIVATION environment variable
  • New per-token NVFP4 quantize kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation; the new quantization kernels are folded into the existing NVFP4 quantization kernels (a reference sketch of the scaling math follows this list)
  • Expand the dequant kernel transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle row-scaled NVFP4
  • In transformer_engine/pytorch/cpp_extensions/gemm.py, if row-scaled NVFP4 is enabled, apply the per-token scaling separately in PyTorch code after the cuBLAS GEMM
  • Broaden test coverage by expanding 7 Python and 2 C++ test files
  • Modify the 1D quant reference implementation in tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the PyTorch reference numerics
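
For intuition, here is a minimal PyTorch sketch of the row-scaled scaling math described above. It is an illustration only: the function name, the zero-guard, and the rounding details are assumptions, not the kernel code from this PR (which additionally emits FP4 data and FP8-encoded block scales).

```python
import torch

FP4_MAX, FP8_MAX, BLOCK = 6.0, 448.0, 16  # E2M1 max, E4M3 max, NVFP4 block size

def row_scaled_scales(x: torch.Tensor):
    """Hypothetical reference: one amax per activation row instead of one
    global amax per tensor; block scales are computed relative to the row.
    Assumes x is 2D with K divisible by BLOCK."""
    M, K = x.shape
    amax_row = x.float().abs().amax(dim=1)                       # [M]
    s_enc_row = (FP4_MAX * FP8_MAX) / amax_row.clamp(min=1e-30)  # per-row encode scale
    blocks = x.float().view(M, K // BLOCK, BLOCK)
    s_dec_blk = blocks.abs().amax(dim=2) / FP4_MAX               # [M, K/16] per-block decode scale
    # Fold the per-row encode scale into the (to-be-FP8) block scales,
    # with a zero-guard for all-zero blocks.
    s_blk = torch.where(s_dec_blk == 0, s_dec_blk, s_enc_row[:, None] * s_dec_blk)
    return amax_row, s_blk  # amax_row feeds the FP32 correction after the GEMM
```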

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zianglih zianglih marked this pull request as draft April 27, 2026 06:24
@greptile-apps greptile-apps bot (Contributor) commented Apr 27, 2026

Greptile Summary

This PR implements a row-scaled (per-activation-row) NVFP4 recipe for fprop, controlled via the new NVTE_NVFP4_ROW_SCALED_ACTIVATION env-var. Instead of a single global FP32 amax per tensor, each activation row gets its own amax; the block-level FP8 scales are recomputed relative to that per-row value, and the matching global correction is applied in FP32 after the cuBLAS GEMM.

  • New compute_rowwise_amax CUDA kernel computes one FP32 max-abs per row before the main quantize kernel; quantize_transpose_nvfp4.cuh gains a ROW_SCALED_NVFP4 template branch that uses those per-row values when computing FP8 block scales.
  • general_gemm gains a post-GEMM scaling path that replaces both operands' global amaxes with 1.0 before cuBLAS, then multiplies the FP32 output by per_row_amax_B × scalar_amax_A; the grouped-GEMM variant loops over individual GEMMs to achieve the same effect (a minimal sketch of this post-scale follows the list).
  • All tensor storage classes (NVFP4TensorStorage, GroupedTensorStorage, NVFP4Tensor) and their C++ counterparts propagate the new row_scaled_nvfp4 flag through allocation, copy, and serialisation paths.
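
A hedged sketch of that post-GEMM correction, inferring shapes from the helper quoted later in this thread (where the combined scale is (rhs_rowwise_amax * weight_amax).view(-1, 1)); names are illustrative, not the PR's API.

```python
import torch

def apply_row_scaled_correction(raw_out: torch.Tensor,
                                per_row_amax_B: torch.Tensor,
                                scalar_amax_A: torch.Tensor) -> torch.Tensor:
    # cuBLAS ran with both global amaxes replaced by 1.0, so the deferred
    # global correction is applied here: one scalar for the weight and one
    # value per activation row, broadcast over each output row.
    return raw_out * (per_row_amax_B * scalar_amax_A).view(-1, 1)

# Toy shapes: output [M, N], per-row amaxes [M], scalar weight amax [1].
out = apply_row_scaled_correction(torch.randn(4, 8), torch.rand(4), torch.rand(1))
```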

Confidence Score: 4/5

Safe to merge for fprop-only workloads on B200; the remaining assert-based guards in the GEMM path can produce silently wrong results under optimised Python, but backward is already blocked by a RuntimeError.

The core quantisation math is verified bitwise-exact against the reference implementation and an extensive test suite passes on B200. Two contract checks in the row-scaled GEMM path use Python assert rather than RuntimeError; under -O/-OO the layout check is silently dropped and a wrong-transpose GEMM produces numerically incorrect output with no error.

transformer_engine/pytorch/cpp_extensions/gemm.py — the row-scaled GEMM helper _nvfp4_row_scaled_gemm_inputs and the layout guard inside general_gemm's row-scaled branch.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/cpp_extensions/gemm.py | Adds the row-scaled NVFP4 path to general_gemm and general_grouped_gemm; the post-GEMM FP32 scale application logic is correct, but several API-contract guards use assert (disabled by -O) instead of RuntimeError, including the critical layout check. |
| transformer_engine/common/recipe/__init__.py | Adds the row_scaled_activation field and NVTE_NVFP4_ROW_SCALED_ACTIVATION env-var toggle to the NVFP4BlockScaling recipe; the change is straightforward and correct. |
| transformer_engine/pytorch/quantization.py | Propagates the row_scaled_nvfp4 flag to forward quantizers via an idx % 3 != 1 heuristic; backward quantizers correctly hardcode row_scaled_nvfp4=False. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | NVFP4Quantizer gains a row_scaled_nvfp4 attribute and allocates a per-row amax buffer; the is_quantizable override returns False with a misleading docstring that obscures the intentional distributed all-gather fallback. |
| transformer_engine/pytorch/csrc/quantizer.cpp | The C++ NVFP4Quantizer correctly reads and propagates row_scaled_nvfp4; create_tensor, convert_and_update_tensor, and quantize_impl all validate constraints and allocate per-row amax buffers appropriately. |
| transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh | Adds the compute_rowwise_amax kernel and a ROW_SCALED_NVFP4 template branch to the main quantize kernel; the per-row encode/decode scale logic appears correct. |
| transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | Passes the row_scaled_nvfp4 flag through to the kernel and selects tensor_amax[y] vs tensor_amax[0] for row indexing; a boundary check on the amax size is added correctly. |
| transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py | Correctly computes total_amax_elements as the sum of flat first dims per tensor for the row-scaled case; amax offset slicing in split_into_quantized_tensors is consistent with the allocation. |
| transformer_engine/common/cast/dispatch/quantize.cuh | Calls compute_rowwise_amax before the main quantize kernel on the row-scaled path; the backward quantize path hard-codes row_scaled_nvfp4=false with an additional NVTE_CHECK guard. |
| transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py | The reference quantizer gains row_scaled_nvfp4 support with per-row amax computation and zero-guard handling; gemm_ref correctly broadcasts partial_alpha when per-row amaxes are used. |

Sequence Diagram

sequenceDiagram
    participant FWD as Forward Pass
    participant Q as NVFP4Quantizer (row_scaled)
    participant K as compute_rowwise_amax kernel
    participant QK as quantize_transpose kernel
    participant G as general_gemm
    participant C as cuBLAS GEMM
    participant S as FP32 post-scale

    FWD->>Q: quantize(activation)
    Q->>K: compute per-row amax
    K-->>Q: amax_rowwise[M]
    Q->>QK: quantize with per-row encode scales
    QK-->>Q: FP4 data + FP8 block scales
    FWD->>G: general_gemm(weight_A, activation_B)
    G->>G: "set amax_A=1 amax_B=1, save rowwise_scales"
    G->>C: "GEMM with global scale=1"
    C-->>G: raw_out (FP32)
    G->>S: "raw_out *= rowwise_scales"
    S-->>FWD: correctly scaled output

Reviews (12). Last reviewed commit: "Update tests/pytorch/utils.py"

Comment thread: transformer_engine/pytorch/cpp_extensions/gemm.py (outdated)

// Compute "correct" per-block encoding scaling factor
- const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+ const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+     fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
zianglih (Contributor, author):
We have to change this to stay aligned with the PyTorch reference.
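
For context, the two expressions are algebraically equal but not bitwise equal in FP32, which is why the kernel must use the exact form the PyTorch reference uses. A small NumPy probe (arbitrary values; the maxNorm clamp is omitted) that hunts for a diverging pair:

```python
import numpy as np

S_enc = np.float32(0.123456)  # arbitrary global encode scale for the demo
rng = np.random.default_rng(0)
for _ in range(1000):
    S_dec = np.float32(rng.uniform(1e-3, 1e3))
    direct = S_enc / S_dec                                            # old expression
    refolded = np.float32(1.0) / (S_dec * (np.float32(1.0) / S_enc))  # reference expression
    if direct != refolded:
        print(S_dec, direct, refolded)  # typically differ by one ULP
        break
```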

@zianglih zianglih marked this pull request as ready for review April 27, 2026 09:14
@zianglih zianglih marked this pull request as draft May 2, 2026 18:22
zianglih and others added 14 commits May 2, 2026 11:27
All signed off by Ziang Li <ziangli@umich.edu>; one co-authored by Yigong Qin <qqqyyy1233@outlook.com>.
@zianglih zianglih force-pushed the fp4-per-token branch 2 times, most recently from 6998f64 to 5b2f606 Compare May 2, 2026 19:10
zianglih added 5 commits May 2, 2026 16:33
All signed off by Ziang Li <ziangli@umich.edu>.
@zianglih zianglih (Contributor, author) commented May 2, 2026

The following extended tests all passed:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

cd /root/TransformerEngine/tests/cpp
cmake --build build -j200
TEST_BIN="$(find build -type f -name test_operator -perm -u+x | head -n 1)"
"$TEST_BIN" --gtest_filter='*FusedCastTransposeNVFP4*:*DequantizeNVFP4*'

zianglih added 11 commits May 5, 2026 17:52
All signed off by Ziang Li <ziangli@umich.edu>.
@zianglih zianglih marked this pull request as ready for review May 6, 2026 06:08
@zianglih zianglih (Contributor, author) commented May 6, 2026

Hi @timmoon10, thank you for your comments! I have refactored accordingly and renamed this to row_scaled_nvfp4 (no more "per-token"). The recipe still has row_scaled_activation, which indicates a 1D×2D GEMM where only the activation is switched to 1D scaling. The control flow is now based on boolean flags instead of being implicitly inferred from shapes.
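
For readers following along, enabling the recipe flag presumably looks like the following: a usage sketch based on the field and env-var named in this PR, assuming NVFP4BlockScaling is importable like other TE recipes.

```python
import os
os.environ["NVTE_NVFP4_ROW_SCALED_ACTIVATION"] = "1"  # env-var toggle added by this PR

from transformer_engine.common.recipe import NVFP4BlockScaling

recipe = NVFP4BlockScaling()
# Per this PR: only forward activations get row scaling (the "1d2d" GEMM);
# weights stay tensor-scaled and backward quantizers keep the flag off.
print(recipe.row_scaled_activation)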

@zianglih zianglih changed the title Implement per-token NVFP4 fprop recipe Implement row scaled NVFP4 fprop recipe May 6, 2026
@zianglih zianglih changed the title Implement row scaled NVFP4 fprop recipe Implement row-scaled NVFP4 fprop recipe May 6, 2026
@timmoon10 timmoon10 (Collaborator) commented:

/te-ci L1

@timmoon10 timmoon10 (Collaborator) left a comment

This is almost ready. The only important suggestion is to add checks in the core lib GEMM functions so we error out instead of misinterpreting the amax:

void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,

My other suggestions are nits.

Comment thread transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh Outdated
Comment thread transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
zianglih added 2 commits May 6, 2026 15:37
Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih (Contributor, author) commented May 6, 2026

Hi @timmoon10, I have cleaned up accordingly and resolved a previous CI lint error. Thank you!

Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
@timmoon10 timmoon10 (Collaborator) commented:

We see a test failure when running on A100: https://github.com/NVIDIA/TransformerEngine/actions/runs/25475147141/job/74746937329
The distributed test failures also show up in the main branch, so they are not blocking.

@zianglih zianglih (Contributor, author) commented May 7, 2026

The previous L0_pytorch_unittest--A100_1GPU failure was due to a minor test-file issue. L1_pytorch_distributed_unittest--H100_4GPU failed due to numerics I did not touch. L1_pytorch_distributed_unittest--B200_8GPU hit the time limit.

Signed-off-by: Ziang Li <ziangli@umich.edu>
A comment from @greptile-apps was marked as off-topic.

Comment thread tests/pytorch/utils.py Outdated
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@timmoon10 timmoon10 (Collaborator) commented:

/te-ci L1

else:
    if _is_nvfp4_row_scaled_tensor(A):
        raise NotImplementedError("Row-scaled NVFP4 GEMM does not support row-scaled A.")
    assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only."

P1 Layout constraint silenced by -O produces wrong numerical results

assert layout[1] == "N" is the single most dangerous assertion in this path: if skipped under -O/-OO, the code continues with layout[1] == "T", meaning B is transposed. In that case B._amax_rowwise still contains per-row amaxes from the pre-transposition orientation, so rowwise_global_scales broadcasts the wrong scale over each output row — the GEMM produces silently incorrect values rather than crashing. Unlike the grad check that was already hardened to raise RuntimeError, this constraint should receive the same treatment since a wrong-layout GEMM is numerically undetectable without extra validation.
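
A minimal hardening sketch (a suggestion, not code from the PR): lifting the check into an unconditional raise so `python -O` cannot strip it.

```python
def _check_row_scaled_layout(layout: str) -> None:
    # RuntimeError survives `python -O`; `assert` does not, and a T-layout B
    # would then be silently mis-scaled rather than rejected.
    if layout[1] != "N":
        raise RuntimeError("Row-scaled NVFP4 GEMM currently supports N-layout B only.")
```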

Comment on lines +86 to +103
A_metadata = A.get_metadata()
weight_amax = A._amax_rowwise if transa else A._amax_columnwise
assert weight_amax is not None and weight_amax.numel() == 1
A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1)
A_metadata["row_scaled_nvfp4"] = False

B_metadata = B.get_metadata()
rhs_rowwise_amax = B._amax_rowwise
assert rhs_rowwise_amax is not None
B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1)
B_metadata["row_scaled_nvfp4"] = False

assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32
return (
    NVFP4TensorStorage(**A_metadata),
    NVFP4TensorStorage(**B_metadata),
    (rhs_rowwise_amax * weight_amax).view(-1, 1),
)

P1 Contract assertions in helper are disabled by -O, turning clean errors into cryptic AttributeErrors

_nvfp4_row_scaled_gemm_inputs contains three assert statements that enforce API contracts:

  • assert weight_amax is not None and weight_amax.numel() == 1 — if weight_amax is None (e.g., a columnwise-only weight tensor with transa=False) and the assert is stripped, weight_amax.new_ones(1) raises AttributeError: 'NoneType' object has no attribute 'new_ones' deep inside the function.
  • assert rhs_rowwise_amax is not None — same failure mode for B.
  • assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 — wrong dtype silently produces incorrect scaling arithmetic.

These three checks guard the function's entire scaling logic and should be RuntimeError raises just like the grad check already was.
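
A sketch of what that hardening could look like, mirroring the snippet above (names reused from the quoted code; this is a suggestion, not the PR's implementation):

```python
import torch

def _check_row_scaled_amaxes(weight_amax, rhs_rowwise_amax) -> None:
    # Unconditional versions of the three flagged asserts.
    if weight_amax is None or weight_amax.numel() != 1:
        raise RuntimeError("Row-scaled NVFP4 GEMM expects a scalar global amax for A.")
    if rhs_rowwise_amax is None:
        raise RuntimeError("Row-scaled NVFP4 GEMM expects per-row amaxes for B.")
    if weight_amax.dtype != torch.float32 or rhs_rowwise_amax.dtype != torch.float32:
        raise RuntimeError("Row-scaled NVFP4 amax buffers must be FP32.")
```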


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


3 participants