[All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls #2964
cyanguwa wants to merge 14 commits into NVIDIA:main
Conversation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Greptile Summary
This PR replaces the hand-maintained backend selection table in nvte_get_fused_attn_backend with cudnn-frontend support checks.
Confidence Score: 3/5

The forward and backward probes build cuDNN graphs against hardcoded format assumptions that can diverge from the actual execution call, creating a risk of false-positive backend acceptance followed by a runtime failure in the kernel. Both is_supported_f16_fwd and is_supported_f16_bwd (and their FP8 counterparts) fix o_format = q_format and dqkv_layout = qkv_layout regardless of what the actual forward/backward invocation uses. When the real layout differs (which the nvte_fused_attn_fwd/bwd signatures allow), the cached graph was built for a different configuration. If cuDNN accepts the probed config but rejects the real one, the backend check gives a green light and the actual kernel launch fails. These are load-bearing probe assumptions that affect the correctness of backend selection across all mixed-layout configurations.

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu and transformer_engine/common/fused_attn/fused_attn_fp8.cu: the o_format and dqkv_layout hardcoding in the probe functions needs to be resolved to match actual call-site parameters.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller as Python Caller
    participant Wrapper as C++ Wrapper
    participant Core as nvte_get_fused_attn_backend
    participant Probe as is_supported probes
    participant cuDNN as cudnn-frontend
    Caller->>Wrapper: is_fused_attn_kernel_available(batch, dtype, layout)
    Wrapper->>Core: nvte_get_fused_attn_backend with message param
    Core->>Core: GetHandle - acquire cuDNN handle
    Core->>Core: Pre-filter checks THD format and 64-bit offsets
    alt FP8 path
        Core->>Probe: "is_supported_fp8_fwd with hardcoded o_format=qkv_format"
        Probe->>cuDNN: fp8_fwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
        Core->>Probe: "is_supported_fp8_bwd with hardcoded dqkv_layout=qkv_layout"
        Probe->>cuDNN: fp8_bwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
    else F16/BF16 path
        Core->>Core: "CUDA graph guard for cuDNN <= 9.15"
        Core->>Probe: "is_supported_f16_fwd with hardcoded o_format=q_format"
        Probe->>cuDNN: f16_fwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
        Core->>Probe: "is_supported_f16_bwd with hardcoded dqkv_layout=qkv_layout"
        Probe->>cuDNN: f16_bwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
    end
    Core-->>Wrapper: backend enum and message string
    Wrapper-->>Caller: tuple of backend and message
```
Reviews (4): Last reviewed commit: "fix jax binding"
```cpp
const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD &&
    qkv_format != NVTE_QKV_Format::NVTE_BHSD) {
  return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
```
Compilation error: invalid pointer arithmetic in string construction
NVTE_QKV_Format is an unscoped C enum, so "..." + qkv_format performs pointer arithmetic (advancing the const char* literal by the integer value of the enum). The result is a const char*, and then const char* + "." tries to add two pointers, which is ill-formed in C++. This expression will not compile. The same bug exists on line 1380 (is_supported_fp8_bwd). The fix is to use std::to_string(static_cast<int>(qkv_format)) or construct a std::string explicitly.
```diff
- return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
+ return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
+        std::to_string(static_cast<int>(qkv_format)) + ".";
  }
  size_t workspace_size = 0;
```
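To see why the suggested fix compiles while the original does not, here is a minimal standalone sketch; the enum and its values are illustrative stand-ins for the real NVTE_QKV_Format definition, not copies of it. Converting the enum to a std::string first makes every `+` string concatenation rather than const char* pointer arithmetic.

```cpp
#include <string>

// Illustrative stand-in for the real NVTE_QKV_Format C enum; the values
// here are made up for this sketch.
enum NVTE_QKV_Format { NVTE_BSHD = 0, NVTE_SBHD = 1, NVTE_BHSD = 2, NVTE_THD = 3 };

// With std::to_string, the first `+` is (const char* + std::string) and the
// second is (std::string + const char*), both valid concatenations.
std::string unsupported_format_message(NVTE_QKV_Format qkv_format) {
  return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
         std::to_string(static_cast<int>(qkv_format)) + ".";
}
```

Note that this prints the numeric enum value, as in the suggestion; mapping it to a human-readable name would need a separate switch.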
```cpp
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
```
Same pointer-arithmetic compilation error as in is_supported_fp8_fwd: adding qkv_format (an integer-valued enum) to a string literal produces a const char*, and then adding "." yields an ill-formed pointer+pointer expression. Use std::to_string to produce a proper string.
```diff
- return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
+ return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
+        std::to_string(static_cast<int>(qkv_format)) + ".";
  }
  const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
```
```cpp
constexpr size_t probe_batch = 1;
constexpr bool probe_bottom_right_diagonal = false;
```
Hardcoded probe parameters may silently misclassify configurations
probe_batch = 1 and probe_bottom_right_diagonal = false are used for all cudnn-frontend probe calls regardless of the actual values passed by the caller. When attn_mask_type is NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, the probe uses false for bottom_right_diagonal even though the real call will use true, which can cause the probe to greenlight configurations that fail at actual execution time.
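A minimal sketch of the alternative, assuming the probe is handed the caller's mask type; the enum here is a shortened, invented stand-in for the NVTE mask-type enum, and the helper name is hypothetical. Deriving the flag from the mask keeps the probe aligned with the real call instead of pinning it to false.

```cpp
// Invented, shortened stand-in for the NVTE mask-type enum.
enum class MaskType {
  NO_MASK,
  CAUSAL,
  PADDING_CAUSAL,
  CAUSAL_BOTTOM_RIGHT,
  PADDING_CAUSAL_BOTTOM_RIGHT
};

// Hypothetical fix: compute bottom_right_diagonal from the caller's mask
// type so the probe builds the same graph the real call will execute.
bool probe_bottom_right_diagonal_for(MaskType mask) {
  return mask == MaskType::CAUSAL_BOTTOM_RIGHT ||
         mask == MaskType::PADDING_CAUSAL_BOTTOM_RIGHT;
}
```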
```cpp
thread_local std::string fused_attn_backend_message_buffer;

void set_message(const char **message, const std::string &reason) {
  if (message == nullptr) return;
  fused_attn_backend_message_buffer = reason;
  *message = fused_attn_backend_message_buffer.c_str();
}
```
Thread-local buffer invalidated on next call on the same thread
*message is set to the .c_str() of the thread-local fused_attn_backend_message_buffer. Any subsequent call to nvte_get_fused_attn_backend on the same thread will clear and overwrite this buffer, invalidating the raw pointer before the caller has a chance to copy it. The internal probe calls currently pass nullptr for message, so the buffer won't be clobbered by them, but this invariant is fragile and easy to break in a future change.
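The lifetime hazard can be reproduced in a few lines. This is a simplified model of the pattern, not the real implementation; it also shows the safe caller-side idiom of copying the message into an owning std::string before the next call on the thread can overwrite the buffer.

```cpp
#include <string>

// Simplified model of the thread-local message buffer described above.
thread_local std::string message_buffer;

void set_message(const char **message, const std::string &reason) {
  if (message == nullptr) return;
  message_buffer = reason;            // overwrites any previous message
  *message = message_buffer.c_str();  // raw pointer into the buffer
}

// Safe caller: copy into an owning std::string immediately; retaining the
// raw pointer across another call on this thread would leave it dangling
// into overwritten storage.
std::string query_and_copy(const std::string &reason) {
  const char *msg = nullptr;
  set_message(&msg, reason);
  return std::string(msg);
}
```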
```cpp
bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle,
const char **message);
```
Breaking public API change: cudnnHandle_t added as a required parameter
nvte_get_fused_attn_backend is a public C API (no namespace, exported symbol). Adding cudnnHandle_t handle as a required parameter, and including <cudnn.h> in the public header, is a breaking change that forces every downstream consumer to hold and pass a cuDNN handle for what was previously a pure-metadata query.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
```cpp
// avoid CUDA graph issue with cuDNN <= 9.15
if (cudnn_runtime_version <= 91500 && is_training &&
    (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
    (max_seqlen_kv % 128 != 0) && cuda_graph &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) {
  set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN.");
  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
```
CUDA graph guard now incorrectly rejects FP8 configurations
The CUDA graph check at lines 281–290 was previously scoped exclusively to the F16/BF16 branch; moving it to the top-level pre-filter means it now also rejects FP8 + CUDA graph configurations on cuDNN ≤ 9.15.0 that were previously accepted. Users running FP8 training with CUDA graph capture, BSHD/SBHD layout, non-padding masks, and max_seqlen_kv % 128 != 0 on cuDNN ≤ 9.15.0 will see the backend silently downgrade to NVTE_No_Backend where it used to return NVTE_FP8. If the cuDNN bug does not affect FP8, this is a regression; if it does, the guard should be narrowed with a comment explaining why it applies to FP8.
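If the cuDNN bug indeed does not affect FP8, the guard could be narrowed along these lines. This is a hypothetical, simplified sketch: plain bools and ints stand in for the real NVTE enums, the qkv_format condition is folded into the booleans, and the FP8 exemption is an assumption that would need verification against cuDNN.

```cpp
// Hypothetical narrowed guard: exempt FP8 from the cuDNN <= 9.15 CUDA graph
// workaround, restoring the previous scoping to the F16/BF16 path.
bool reject_for_cuda_graph_bug(int cudnn_version, bool is_training, bool cuda_graph,
                               bool is_fp8, bool has_padding_mask, int max_seqlen_kv) {
  if (is_fp8) return false;  // assumption: the bug only affects F16/BF16
  return cudnn_version <= 91500 && is_training && cuda_graph &&
         !has_padding_mask && (max_seqlen_kv % 128 != 0);
}
```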
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
/te-ci L1
```cpp
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
```
Hardcoded dqkv_layout and do_format in backward probe may build incorrect cached graph
is_supported_f16_bwd hardcodes dqkv_layout = qkv_layout and do_format = q_format. nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters and passes them through to the actual backward kernel — but these are never forwarded into nvte_get_fused_attn_backend. When the activation and gradient layouts differ, the probe builds and caches a cuDNN graph for a configuration that won't be used. More critically, if the config with qkv_layout is unsupported but the config with the true dqkv_layout would be supported, backend selection will falsely return NVTE_No_Backend. The same assumption exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.
```cpp
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
```
o_format = q_format assumption can mismatch the actual forward call
is_supported_f16_fwd derives o_format from q_format, but nvte_fused_attn_fwd accepts a separate o_format parameter that is not forwarded into nvte_get_fused_attn_backend. When the caller uses a different output format from the query format — such as returning BSHD output from an SBHD_BSHD_BSHD layout — the probe builds a cuDNN graph for the wrong o_format. If the graph with o_format=q_format is accepted but the config with the actual o_format is not (or vice versa), nvte_get_fused_attn_backend produces an incorrect backend decision, causing an error when the actual kernel is invoked.
```cpp
// For ragged offsets we only support 32-bit prior to cuDNN 9.5
// Only used when THD format is requested.
cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle();
```
Backend query now builds cuDNN execution graphs as a side effect, altering call semantics
cudnnExecutionPlanManager::Instance().GetHandle() is obtained here, and the subsequent probe calls invoke fused_attn_*_impl with null device pointers, building and caching cuDNN execution plans. This turns a previously cheap metadata query into a graph-build operation that can take hundreds of milliseconds on first call. Code paths that call nvte_get_fused_attn_backend as an availability check — notably FusedAttnHelper.is_fused_attn_kernel_available() from Python/JAX — will pay this cost unexpectedly during model initialisation. The new semantic should be documented in the public API header.
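One way callers can cope with the new cost is to memoize the query result per configuration. The sketch below is purely hypothetical and caller-side: the integer "backend" and the probe counter are stand-ins for the real backend enum and the cuDNN graph-build work, and the function names are invented.

```cpp
#include <map>
#include <utility>

// Stand-in for the now-expensive backend query; the counter models the
// graph-build cost we want to pay at most once per configuration.
int query_backend_expensive(int dtype, int layout, int &probe_calls) {
  ++probe_calls;
  return dtype * 100 + layout;  // stand-in for the backend enum
}

// Hypothetical caller-side memoization keyed on the query parameters, so
// repeated availability checks during model initialisation stay cheap.
int cached_backend(int dtype, int layout, int &probe_calls) {
  static std::map<std::pair<int, int>, int> cache;
  const auto key = std::make_pair(dtype, layout);
  const auto it = cache.find(key);
  if (it != cache.end()) return it->second;
  const int backend = query_backend_expensive(dtype, layout, probe_calls);
  cache.emplace(key, backend);
  return backend;
}
```

A real key would need to cover every parameter that influences backend selection (layouts, mask type, head dims, and so on), or the cache would conflate distinct configurations.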
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
```cpp
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
```
o_format probe hardcoded as q_format, mismatching callers that use a different output format
is_supported_f16_fwd derives o_format = q_format (line 1372) and uses it to build and cache the cuDNN graph. However, nvte_fused_attn_fwd accepts an independent o_format parameter that is never forwarded into nvte_get_fused_attn_backend. When a caller uses an output format different from the query format (e.g. BSHD output from an SBHD_BSHD_BSHD layout), the cached graph was built for q_format, not the actual o_format. If cuDNN accepts the wrong graph but rejects the real one — or vice versa — the backend check produces an incorrect decision, causing an unexpected error at actual kernel invocation.
```cpp
const NVTE_QKV_Format o_format = q_format;
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
```
Backward probe hardcodes dqkv_layout = qkv_layout and do_format = o_format, diverging from the actual backward call
is_supported_f16_bwd sets dqkv_layout = qkv_layout and do_format = o_format (where o_format is already fixed to q_format). The actual backward call nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters that are never threaded through nvte_get_fused_attn_backend. When activation and gradient layouts differ, the probe builds a cuDNN graph for a configuration that is never used at runtime. More critically, if the real dqkv_layout is unsupported but the probe's assumed qkv_layout is accepted, the backend check returns NVTE_F16_arbitrary_seqlen and the actual backward pass silently fails. The same issue exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.
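The mismatch described above can be made concrete with a small model. Everything below is a hypothetical sketch (the enums, struct, and helper are invented for illustration): a probe that derives its backward layouts from the forward inputs diverges from a call whose gradient layout genuinely differs.

```cpp
// Invented stand-ins for the NVTE layout/format enums.
enum class Layout { BSHD_BSHD_BSHD, SBHD_BSHD_BSHD };
enum class Format { BSHD, SBHD };

// Hypothetical probe configuration carrying the caller's actual backward
// layouts instead of deriving them from the forward inputs.
struct BwdProbeConfig {
  Layout qkv_layout;
  Layout dqkv_layout;  // must not be assumed equal to qkv_layout
  Format o_format;
  Format do_format;    // must not be assumed equal to o_format
};

// True only when the probe matches the configuration the backward kernel
// will actually run; a mismatch means the cached graph may be wrong.
bool probe_matches_call(const BwdProbeConfig &probe, const BwdProbeConfig &call) {
  return probe.qkv_layout == call.qkv_layout &&
         probe.dqkv_layout == call.dqkv_layout &&
         probe.o_format == call.o_format &&
         probe.do_format == call.do_format;
}
```

Threading the real dqkv_layout and do_format into nvte_get_fused_attn_backend would let the probe pass this check by construction.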
Description
This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks.

Type of change
Changes
See Description.
Checklist: