[All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls#2964

Open
cyanguwa wants to merge 14 commits into NVIDIA:main from cyanguwa:fe_check_support
Conversation

@cyanguwa
Collaborator

@cyanguwa cyanguwa commented May 6, 2026

Description

This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks.

  • It builds the same graph that the runtime uses at execution, caches the graph if the build succeeds, and so provides a warmed-up cache for the next nvte_get_fused_attn_backend call.
  • It gives cleaner dispatch logic, avoids accidental regressions, and helps TE stay in sync with cudnn-frontend's support surface.
  • It also provides an appropriate error message when no backend is available, so users are notified and can adjust their config, architecture, or cuDNN version.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa and others added 4 commits May 5, 2026 18:55
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa changed the title [Common] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls [All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls May 8, 2026
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa marked this pull request as ready for review May 8, 2026 00:10
@greptile-apps
Contributor

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR replaces the hand-maintained backend selection table in nvte_get_fused_attn_backend with actual cudnn-frontend probe calls (is_supported_f16_fwd/bwd, is_supported_fp8_fwd/bwd), building and caching the cuDNN execution graph during the availability check to warm up the cache and eliminate accidental regressions against cuDNN's real support surface.

  • Core dispatch (fused_attn.cpp): pre-filter checks (THD format, 64-bit ragged offsets, cuDNN ≤ 9.15 CUDA graph bug) now gate actual cudnn-frontend probe calls per dtype path; a thread-local diagnostic string is populated and returned via an output message parameter.
  • New probe functions: is_supported_f16_fwd/bwd and is_supported_fp8_fwd/bwd invoke the real impl functions with null device pointers to build and cache the cuDNN graph; both probes hardcode o_format = q_format and dqkv_layout = qkv_layout, which may diverge from actual forward/backward call parameters.
  • Public API and Python bindings: nvte_get_fused_attn_backend gains several required parameters; PyTorch and JAX C++ wrappers propagate these and return (backend, message) tuples; Python-side FusedAttnHelper gains batch_size and bottom_right_diagonal fields but is_fused_attn_kernel_available discards the diagnostic message before it reaches the Flax transformer warning.

Confidence Score: 3/5

The forward and backward probes build cuDNN graphs against hardcoded format assumptions that can diverge from the actual execution call, creating a risk of false-positive backend acceptance followed by a runtime failure in the kernel.

Both is_supported_f16_fwd and is_supported_f16_bwd (and their FP8 counterparts) fix o_format = q_format and dqkv_layout = qkv_layout regardless of what the actual forward/backward invocation uses. When the real layout differs — which is supported by the nvte_fused_attn_fwd/bwd signatures — the cached graph was built for a different configuration. If cuDNN accepts the probed config but rejects the real one, the backend check gives a green light and the actual kernel launch fails. These are load-bearing probe assumptions that affect correctness of backend selection across all mixed-layout configurations.

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu and transformer_engine/common/fused_attn/fused_attn_fp8.cu — the o_format and dqkv_layout hardcoding in the probe functions needs to be resolved to match actual call-site parameters.

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn.cpp: Core dispatch logic refactored to delegate backend selection to cudnn-frontend probe calls; CUDA graph guard now scoped to F16/BF16 path only; FP8 path still hardcodes o_format
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: New is_supported_f16_fwd/bwd probe functions introduced; both hardcode o_format=q_format and dqkv_layout=qkv_layout, diverging from actual forward/backward call parameters
transformer_engine/common/fused_attn/fused_attn_fp8.cu: New is_supported_fp8_fwd/bwd probe functions added; bwd probe hardcodes dqkv_layout=qkv_layout, same assumption mismatch as in f16 probes
transformer_engine/common/include/transformer_engine/fused_attn.h: Public C API updated with new required parameters (batch_size, o_dtype, bottom_right_diagonal, return_max_logit, cuda_graph, deterministic, message); no cudnnHandle_t in public signature; breaking change for downstream consumers
transformer_engine/jax/cpp_extensions/attention.py: FusedAttnHelper gains batch_size and bottom_right_diagonal fields; get_fused_attn_backend now returns (backend, message) tuple; is_fused_attn_kernel_available discards the message
transformer_engine/jax/attention.py: is_fused_attn_kernel_available gains batch_size parameter; diagnostic message from backend query is discarded and not propagated to callers
transformer_engine/jax/flax/transformer.py: Passes batch_size to is_fused_attn_kernel_available; fallback warning still generic, does not include the rejection reason from the backend query
transformer_engine/pytorch/csrc/extensions/attention.cpp: PyTorch wrapper updated to forward all new parameters to nvte_get_fused_attn_backend; returns (backend, message) tuple correctly
transformer_engine/jax/csrc/extensions/attention.cpp: JAX C++ wrapper hardcodes return_max_logit=false and cuda_graph=false for the backend query; returns (backend, message) tuple correctly
tests/jax/test_fused_attn.py: Tests updated to pass batch_size and bottom_right_diagonal to FusedAttnHelper; now unpacks (backend, message) and uses message in pytest.skip

Sequence Diagram

sequenceDiagram
    participant Caller as Python Caller
    participant Wrapper as C++ Wrapper
    participant Core as nvte_get_fused_attn_backend
    participant Probe as is_supported probes
    participant cuDNN as cudnn-frontend

    Caller->>Wrapper: is_fused_attn_kernel_available(batch, dtype, layout)
    Wrapper->>Core: nvte_get_fused_attn_backend with message param
    Core->>Core: GetHandle - acquire cuDNN handle
    Core->>Core: Pre-filter checks THD format and 64-bit offsets

    alt FP8 path
        Core->>Probe: "is_supported_fp8_fwd with hardcoded o_format=qkv_format"
        Probe->>cuDNN: fp8_fwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
        Core->>Probe: "is_supported_fp8_bwd with hardcoded dqkv_layout=qkv_layout"
        Probe->>cuDNN: fp8_bwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
    else F16/BF16 path
        Core->>Core: "CUDA graph guard for cuDNN <= 9.15"
        Core->>Probe: "is_supported_f16_fwd with hardcoded o_format=q_format"
        Probe->>cuDNN: f16_fwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
        Core->>Probe: "is_supported_f16_bwd with hardcoded dqkv_layout=qkv_layout"
        Probe->>cuDNN: f16_bwd_impl with null ptrs to build and cache graph
        cuDNN-->>Probe: success or exception
        Probe-->>Core: empty or error string
    end

    Core-->>Wrapper: backend enum and message string
    Wrapper-->>Caller: tuple of backend and message

Reviews (4): Last reviewed commit: "fix jax binding"

const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD &&
qkv_format != NVTE_QKV_Format::NVTE_BHSD) {
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
P0 Compilation error: invalid pointer arithmetic in string construction

NVTE_QKV_Format is an unscoped C enum, so "..." + qkv_format performs pointer arithmetic (advancing the const char* literal by the integer value of the enum). The result is a const char*, and then const char* + "." tries to add two pointers, which is ill-formed in C++. This expression will not compile. The same bug exists on line 1380 (is_supported_fp8_bwd). The fix is to use std::to_string(static_cast<int>(qkv_format)) or construct a std::string explicitly.

Suggested change
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
std::to_string(static_cast<int>(qkv_format)) + ".";
}
size_t workspace_size = 0;

Comment on lines +1380 to +1382
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
P0 Same pointer-arithmetic compilation error as in is_supported_fp8_fwd — adding qkv_format (an integer-valued enum) to a string literal produces a const char*, and then adding "." yields an ill-formed pointer+pointer expression. Use std::to_string to produce a proper string.

Suggested change
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
std::to_string(static_cast<int>(qkv_format)) + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);

Comment on lines +292 to +293
constexpr size_t probe_batch = 1;
constexpr bool probe_bottom_right_diagonal = false;
P1 Hardcoded probe parameters may silently misclassify configurations

probe_batch = 1 and probe_bottom_right_diagonal = false are used for all cudnn-frontend probe calls regardless of the actual values passed by the caller. When attn_mask_type is NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, the probe uses false for bottom_right_diagonal even though the real call will use true, which can cause the probe to greenlight configurations that fail at actual execution time.

Comment on lines +232 to +238
thread_local std::string fused_attn_backend_message_buffer;

void set_message(const char **message, const std::string &reason) {
if (message == nullptr) return;
fused_attn_backend_message_buffer = reason;
*message = fused_attn_backend_message_buffer.c_str();
}
P2 Thread-local buffer invalidated on next call on the same thread

*message is set to the .c_str() of the thread-local fused_attn_backend_message_buffer. Any subsequent call to nvte_get_fused_attn_backend on the same thread will clear and overwrite this buffer, invalidating the raw pointer before the caller has a chance to copy it. The internal probe calls currently pass nullptr for message, so the buffer won't be clobbered by them, but this invariant is fragile and easy to break in a future change.

Comment on lines +232 to +233
bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle,
const char **message);
P1 Breaking public API change: cudnnHandle_t added as a required parameter

nvte_get_fused_attn_backend is a public C API (no namespace, exported symbol). Adding cudnnHandle_t handle as a required parameter, and including <cudnn.h> in the public header, is a breaking change that forces every downstream consumer to hold and pass a cuDNN handle for what was previously a pure-metadata query.

cyanguwa and others added 2 commits May 7, 2026 17:22
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Comment on lines +281 to +290
// avoid CUDA graph issue with cuDNN <= 9.15
if (cudnn_runtime_version <= 91500 && is_training &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
(max_seqlen_kv % 128 != 0) && cuda_graph &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) {
set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN.");
return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
P1 CUDA graph guard now incorrectly rejects FP8 configurations

The CUDA graph check at lines 281–290 was previously scoped exclusively to the F16/BF16 branch; moving it to the top-level pre-filter means it now also rejects FP8 + CUDA graph configurations on cuDNN ≤ 9.15.0 that were previously accepted. Users running FP8 training with CUDA graph capture, BSHD/SBHD layout, non-padding masks, and max_seqlen_kv % 128 != 0 on cuDNN ≤ 9.15.0 will see the backend silently downgrade to NVTE_No_Backend where it used to return NVTE_FP8. If the cuDNN bug does not affect FP8, this is a regression; if it does, the guard should be narrowed with a comment explaining why it applies to FP8.

cyanguwa and others added 3 commits May 7, 2026 18:30
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa
Collaborator Author

cyanguwa commented May 8, 2026

/te-ci L1

Comment on lines +1427 to +1428
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
P1 Hardcoded dqkv_layout and do_format in backward probe may build incorrect cached graph

is_supported_f16_bwd hardcodes dqkv_layout = qkv_layout and do_format = q_format. nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters and passes them through to the actual backward kernel — but these are never forwarded into nvte_get_fused_attn_backend. When the activation and gradient layouts differ, the probe builds and caches a cuDNN graph for a configuration that won't be used. More critically, if the config with qkv_layout is unsupported but the config with the true dqkv_layout would be supported, backend selection will falsely return NVTE_No_Backend. The same assumption exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.

const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
P1 o_format = q_format assumption can mismatch the actual forward call

is_supported_f16_fwd derives o_format from q_format, but nvte_fused_attn_fwd accepts a separate o_format parameter that is not forwarded into nvte_get_fused_attn_backend. When the caller uses a different output format from the query format — such as returning BSHD output from an SBHD_BSHD_BSHD layout — the probe builds a cuDNN graph for the wrong o_format. If the graph with o_format=q_format is accepted but the config with the actual o_format is not (or vice versa), nvte_get_fused_attn_backend produces an incorrect backend decision, causing an error when the actual kernel is invoked.


// For ragged offsets we only support 32-bit prior to cuDNN 9.5
// Only used when THD format is requested.
cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle();
P1 Backend query now builds cuDNN execution graphs as a side effect, altering call semantics

cudnnExecutionPlanManager::Instance().GetHandle() is obtained here, and the subsequent probe calls invoke fused_attn_*_impl with null device pointers, building and caching cuDNN execution plans. This turns a previously cheap metadata query into a graph-build operation that can take hundreds of milliseconds on first call. Code paths that call nvte_get_fused_attn_backend as an availability check — notably FusedAttnHelper.is_fused_attn_kernel_available() from Python/JAX — will pay this cost unexpectedly during model initialisation. The new semantic should be documented in the public API header.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
P1 o_format probe hardcoded as q_format, mismatching callers that use a different output format

is_supported_f16_fwd derives o_format = q_format (line 1372) and uses it to build and cache the cuDNN graph. However, nvte_fused_attn_fwd accepts an independent o_format parameter that is never forwarded into nvte_get_fused_attn_backend. When a caller uses an output format different from the query format (e.g. BSHD output from an SBHD_BSHD_BSHD layout), the cached graph was built for q_format, not the actual o_format. If cuDNN accepts the wrong graph but rejects the real one — or vice versa — the backend check produces an incorrect decision, causing an unexpected error at actual kernel invocation.

Comment on lines +1426 to +1428
const NVTE_QKV_Format o_format = q_format;
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
P1 Backward probe hardcodes dqkv_layout = qkv_layout and do_format = o_format, diverging from the actual backward call

is_supported_f16_bwd sets dqkv_layout = qkv_layout and do_format = o_format (where o_format is already fixed to q_format). The actual backward call nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters that are never threaded through nvte_get_fused_attn_backend. When activation and gradient layouts differ, the probe builds a cuDNN graph for a configuration that is never used at runtime. More critically, if the real dqkv_layout is unsupported but the probe's assumed qkv_layout is accepted, the backend check returns NVTE_F16_arbitrary_seqlen and the actual backward pass silently fails. The same issue exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.
