Implement row-scaled NVFP4 fprop recipe #2931
Conversation
Greptile Summary

This PR implements a row-scaled (per-activation-row) NVFP4 recipe for fprop, controlled via the new `row_scaled_activation` recipe field (togglable through `NVTE_NVFP4_ROW_SCALED_ACTIVATION`).
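For context, a minimal way to enable the new path (the env var is the one named in this PR's description; exactly when it is read relative to recipe construction is an assumption here):

```python
import os

# Toggle named in this PR's description; assumed to be read when the
# NVFP4 recipe is constructed, so set it before building the model/recipe.
os.environ["NVTE_NVFP4_ROW_SCALED_ACTIVATION"] = "1"

# Equivalently, assuming an NVFP4 recipe object exposing the new field:
#   recipe.row_scaled_activation = True
```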
Confidence Score: 4/5

Safe to merge for fprop-only workloads on B200; the remaining assert-based guards in the GEMM path can produce silent wrong results under optimised Python, but backward is already blocked by a RuntimeError. The core quantisation math is verified bitwise-exact against the reference implementation, and an extensive test suite passes on B200. Two contract checks in the row-scaled GEMM path use Python `assert` rather than `RuntimeError`; under `-O`/`-OO` the layout check is silently dropped, and a wrong-transpose GEMM produces numerically incorrect output with no error.

Important Files Changed

`transformer_engine/pytorch/cpp_extensions/gemm.py`: the row-scaled GEMM helper `_nvfp4_row_scaled_gemm_inputs` and the layout guard inside `general_gemm`'s row-scaled branch.
Sequence Diagram

```mermaid
sequenceDiagram
participant FWD as Forward Pass
participant Q as NVFP4Quantizer (row_scaled)
participant K as compute_rowwise_amax kernel
participant QK as quantize_transpose kernel
participant G as general_gemm
participant C as cuBLAS GEMM
participant S as FP32 post-scale
FWD->>Q: quantize(activation)
Q->>K: compute per-row amax
K-->>Q: amax_rowwise[M]
Q->>QK: quantize with per-row encode scales
QK-->>Q: FP4 data + FP8 block scales
FWD->>G: general_gemm(weight_A, activation_B)
G->>G: "set amax_A=1 amax_B=1, save rowwise_scales"
G->>C: "GEMM with global scale=1"
C-->>G: raw_out (FP32)
G->>S: "raw_out *= rowwise_scales"
S-->>FWD: correctly scaled output
```
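In code, the FP32 post-scale step above amounts to roughly the following (a minimal sketch; the names mirror the diagram and the diff snippets below rather than the actual implementation):

```python
import torch

def row_scaled_post_gemm(raw_out: torch.Tensor,
                         rhs_rowwise_amax: torch.Tensor,
                         weight_amax: torch.Tensor) -> torch.Tensor:
    # cuBLAS ran with global scale = 1 (amaxes overridden to ones), so the
    # true scale is recovered here: one factor per output row.
    rowwise_scales = (rhs_rowwise_amax * weight_amax).view(-1, 1)  # shape [M, 1]
    return raw_out * rowwise_scales  # broadcasts across each output row
```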
Reviews (12). Last reviewed commit: "Update tests/pytorch/utils.py"
```diff
 // Compute "correct" per-block encoding scaling factor
-const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+    fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
```
We have to change this here to stay aligned with the PyTorch reference.
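For intuition, here is a minimal PyTorch emulation of the updated per-block encode scale (illustrative function name; `FP32_MAX` stands in for `Numeric_Traits<float>::maxNorm`):

```python
import torch

FP32_MAX = torch.finfo(torch.float32).max  # Numeric_Traits<float>::maxNorm

def block_encode_scale(s_dec_b: torch.Tensor, s_enc: float) -> torch.Tensor:
    # 1 / (S_dec_b * (1 / S_enc)) instead of S_enc / S_dec_b: the two-step
    # form matches the PyTorch reference's rounding, and the clamp avoids
    # inf for very small decode scales.
    out = torch.clamp(1.0 / (s_dec_b * (1.0 / s_enc)), max=FP32_MAX)
    return torch.where(s_dec_b == 0.0, torch.zeros_like(out), out)
```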
Force-pushed from 6998f64 to 5b2f606.
The following extended tests all passed:
Hi @timmoon10, thank you for your comments! I have refactored accordingly and renamed this to "row-scaled".
/te-ci L1
timmoon10 left a comment
This is almost ready. The only important suggestion is to add checks in the core-lib GEMM functions so we error out instead of misinterpreting the amax. My other suggestions are nits.
Hi @timmoon10, I have cleaned up accordingly and resolved a previous CI lint error. Thank you!
We see a test failure when running on A100: https://github.com/NVIDIA/TransformerEngine/actions/runs/25475147141/job/74746937329
Previous L0_pytorch_unittest--A100_1GPU failed due to a minor test-file issue. L1_pytorch_distributed_unittest--H100_4GPU failed due to numerics I did not touch. L1_pytorch_distributed_unittest--B200_8GPU reached its time limit.
/te-ci L1
```python
else:
    if _is_nvfp4_row_scaled_tensor(A):
        raise NotImplementedError("Row-scaled NVFP4 GEMM does not support row-scaled A.")
    assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only."
```
Layout constraint silenced by -O produces wrong numerical results

`assert layout[1] == "N"` is the single most dangerous assertion in this path: if skipped under `-O`/`-OO`, the code continues with `layout[1] == "T"`, meaning B is transposed. In that case `B._amax_rowwise` still contains per-row amaxes from the pre-transposition orientation, so `rowwise_global_scales` broadcasts the wrong scale over each output row; the GEMM produces silently incorrect values rather than crashing. Unlike the grad check that was already hardened to raise `RuntimeError`, this constraint should receive the same treatment, since a wrong-layout GEMM is numerically undetectable without extra validation.
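A hedged sketch of the same hardening applied here (the message text is illustrative):

```python
if layout[1] != "N":
    raise RuntimeError(
        "Row-scaled NVFP4 GEMM currently supports N-layout B only; "
        f"got layout={layout!r}. B._amax_rowwise holds per-row amaxes in the "
        "non-transposed orientation, so a T-layout GEMM would broadcast the "
        "wrong scale over each output row."
    )
```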
```python
A_metadata = A.get_metadata()
weight_amax = A._amax_rowwise if transa else A._amax_columnwise
assert weight_amax is not None and weight_amax.numel() == 1
A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1)
A_metadata["row_scaled_nvfp4"] = False

B_metadata = B.get_metadata()
rhs_rowwise_amax = B._amax_rowwise
assert rhs_rowwise_amax is not None
B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1)
B_metadata["row_scaled_nvfp4"] = False

assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32
return (
    NVFP4TensorStorage(**A_metadata),
    NVFP4TensorStorage(**B_metadata),
    (rhs_rowwise_amax * weight_amax).view(-1, 1),
)
```
Contract assertions in helper are disabled by -O, turning clean errors into cryptic AttributeErrors

`_nvfp4_row_scaled_gemm_inputs` contains three assert statements that enforce API contracts:
- `assert weight_amax is not None and weight_amax.numel() == 1`: if `weight_amax` is `None` (e.g., a columnwise-only weight tensor with `transa=False`) and the assert is stripped, `weight_amax.new_ones(1)` raises `AttributeError: 'NoneType' object has no attribute 'new_ones'` deep inside the function.
- `assert rhs_rowwise_amax is not None`: same failure mode for `B`.
- `assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32`: wrong dtype silently produces incorrect scaling arithmetic.
These three checks guard the function's entire scaling logic and should raise RuntimeError, just as the grad check already does.
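A sketch of that hardening (the `_check_amax` helper is hypothetical; the checks mirror the three asserts above):

```python
import torch

def _check_amax(amax, name: str, expect_scalar: bool = False) -> None:
    # RuntimeError survives python -O/-OO; assert statements do not.
    if amax is None:
        raise RuntimeError(f"Row-scaled NVFP4 GEMM: {name} amax is None.")
    if expect_scalar and amax.numel() != 1:
        raise RuntimeError(f"Row-scaled NVFP4 GEMM: {name} amax must be a scalar.")
    if amax.dtype != torch.float32:
        raise RuntimeError(
            f"Row-scaled NVFP4 GEMM: {name} amax must be float32, got {amax.dtype}."
        )

# Usage inside the helper would then be:
#   _check_amax(weight_amax, "weight", expect_scalar=True)
#   _check_amax(rhs_rowwise_amax, "B rowwise")
```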
Description
Implement ~~per-token~~ row-scaled NVFP4 recipe with fprop only. Currently, the row-scaled scaling is handled by separate PyTorch code.
Quantization kernels are bitwise exact with the existing TE reference implementation.
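"Bitwise exact" can be verified along these lines (a sketch; it assumes both paths expose their packed uint8 FP4 payloads and FP8 block scales as plain tensors):

```python
import torch

def assert_bitwise_exact(kernel_out: torch.Tensor, ref_out: torch.Tensor) -> None:
    # No tolerance: FP4 data and FP8 scales are packed into uint8 storage,
    # so exact byte equality is the right notion of "bitwise exact".
    if not torch.equal(kernel_out, ref_out):
        n_bad = (kernel_out != ref_out).sum().item()
        raise AssertionError(f"{n_bad} packed bytes differ from the reference")
```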
The following tests passed on B200:
Type of change
Changes
Please list the changes introduced in this PR:

- `row_scaled_activation` field in the NVFP4 recipe; can be turned on via `NVTE_NVFP4_ROW_SCALED_ACTIVATION` (see the enablement sketch under the Greptile Summary above).
- ~~New per-token NVFP4 quantize kernels in `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh`~~ New quantization kernels folded into the existing NVFP4 quantization kernels, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation.
- Updated `transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh` to correctly handle row-scaled NVFP4.
- In `transformer_engine/pytorch/cpp_extensions/gemm.py`, if row-scaled NVFP4 is enabled, separate per-token scaling is applied in PyTorch code after the cuBLAS GEMM.
- Updated `tests/cpp/operator/test_cast_nvfp4_transpose.cu` to align with PyTorch reference numerics.

Checklist:
row_scaled_activationfield in nvfp4 recipe, can be turned on byNVTE_NVFP4_ROW_SCALED_ACTIVATIONNew per-token nvfp4 quantize kernels inNew quantization kernels folded into existing nvfp4 quantization kernels.transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with existing TE pytorch reference implementation and per-tesor nvfp4 emulated implmentation.transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuhto correctly handle this row-scaled nvfp4TransformerEngine/transformer_engine/pytorch/cpp_extensions/gemm.py, if row-scaled nvfp4 is enabled, it conducts separate per-token scaling using pytorch code, after cublas gemmtests/cpp/operator/test_cast_nvfp4_transpose.cuto align with pytorch reference numericsChecklist: