
Add CUDA graph capture/replay for qwen 3.5 moe decode method #18809

Open
Gasoonjia wants to merge 30 commits into main from cuda-graph

Conversation

Gasoonjia (Contributor) commented Apr 10, 2026

Problem: In the ExecuTorch CUDA backend, each decode step re-launches all GPU kernels individually, incurring significant CPU-side kernel launch overhead. For autoregressive LLM decoding (single-token steps), this launch overhead dominates end-to-end latency.

Solution: Added CUDA graph support to the CUDA backend. When enabled for a method (e.g., decode), the backend:

  • Uses the first three decode iterations as warmup to stabilize GPU state (lazy allocations, cuBLAS handles, kernel JIT)
  • Captures the 4th execution into a CUDA graph
  • Replays the captured graph on all subsequent decode calls, eliminating per-step kernel launch overhead

The feature is opt-in via a runtime backend option (the --cuda_graph flag) and is transparent to the model — no changes to export or model code are required. Decode performance improved from 88.1 tokens/s to 113.8 tokens/s.
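For intuition, a minimal sketch of the warmup → capture → replay progression described above. The type and field names (GraphRunner, Phase, calls) are illustrative, not the PR's actual identifiers; only the CUDA runtime calls are real.

#include <cuda_runtime.h>

enum class Phase { Warmup, Capture, Replay };

struct GraphRunner {
  Phase phase = Phase::Warmup;
  int calls = 0;
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;

  // run_kernels is the method's normal kernel-launch path.
  template <typename F>
  cudaError_t step(cudaStream_t stream, F&& run_kernels) {
    switch (phase) {
      case Phase::Warmup:
        run_kernels(stream); // calls 1-3: lazy allocations, cuBLAS handles, JIT settle
        if (++calls >= 3) phase = Phase::Capture;
        return cudaStreamSynchronize(stream);
      case Phase::Capture:
        // Call 4: kernel launches are recorded into a graph, not executed ...
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        run_kernels(stream);
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiateWithFlags(
            &graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
        phase = Phase::Replay;
        // ... so launch once to actually run this step's work.
        cudaGraphLaunch(graph_exec, stream);
        return cudaStreamSynchronize(stream);
      case Phase::Replay:
        cudaGraphLaunch(graph_exec, stream); // calls 5+: one launch per step
        return cudaStreamSynchronize(stream);
    }
    return cudaErrorUnknown;
  }
};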

Gasoonjia and others added 20 commits April 1, 2026 23:06
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode.
Replace with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085ms → 0.344ms/step (-68%)
- Total GPU time: 12.0ms → 9.0ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
Move decode/prefill dispatch inside the chunk_gated_delta_rule triton_op
instead of using torch.cond at model level. This follows the same pattern
as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond
incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: Add fused recurrent Triton kernel for T=1,
  refactor chunked pipeline into _launch_chunked(), dispatch via Python
  if inside the @triton_op wrapper
- model.py: Remove torch.cond from GatedDeltaNet.forward(), call
  triton_op directly (dispatch is internal)
- export.py: Single-method export with dynamic seq_len dim
- main.cpp: Fix create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic
is internal to the triton_op, no model/export/runner changes needed.
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive
  reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for
  T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch
  paths and chunk boundary edge cases
- Grid changed from (B*H,) to (V//BV, B*H) — 4x more blocks, better SM
  occupancy (128 blocks vs 32 on A100)
- BV reduced from 128 to 32 — lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed debug print from triton_op dispatch
- GPU kernel time: 6us (3.47x faster than Inductor-fused native ops)
- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch
  recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap
  options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU
  for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
Revert from explicit state passing back to registered buffers with
in-place updates (KVCache, conv_state, recurrent_state). Export with
share_mutable_buffers=True so both prefill and forward methods share
mutable state via mem_id=2. C++ runner uses share_memory_arenas=true
and only passes (tokens, input_pos) — no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in profile,
65 D2H memcpy (logits only).
Add runtime buffer sharing between AOTI containers so that prefill and
decode methods operate on the same GPU tensors (KV cache, conv_state,
etc.) without unnecessary H2D/D2H copies or getter/setter overhead.

The first container to initialize extracts its constants (keyed by
original FQN). Subsequent containers with matching FQNs are updated via
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point
to the same GPU memory (user_managed = true, no copy).

Also switch main.cpp prefill to token-by-token decode path while the
chunked FLA triton_op numerical issue is being resolved.

Tested E2E: "What is the capital of France?" → "Paris" with 966
constants shared between prefill and decode containers on A100.
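Conceptually, the sharing step is a first-wins registry keyed by FQN plus a pointer-rebinding call for each later container. A rough sketch follows; ConstantRegistry, bind_user_managed, and register_or_share are illustrative stand-ins, with the real binding going through the AOTInductorModelContainerUpdateUserManagedConstantBufferPairs API named above (and note the follow-up commit below switches the binding key from the FQN to the codegen name).

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct ConstantRegistry {
  // Original FQN -> device pointer owned by the first container to load it.
  std::unordered_map<std::string, void*> by_fqn;
};

// Stand-in for the AOTI constant-buffer update call: points `container`
// at already-resident GPU weights (user_managed = true, so no copy).
void bind_user_managed(
    void* container, const std::vector<std::pair<std::string, void*>>& pairs) {
  // Real code calls the AOTI C API here.
  (void)container;
  (void)pairs;
}

void register_or_share(
    ConstantRegistry& reg,
    void* container,
    const std::vector<std::pair<std::string, void*>>& extracted) {
  std::vector<std::pair<std::string, void*>> to_bind;
  for (const auto& [fqn, dev_ptr] : extracted) {
    auto [it, inserted] = reg.by_fqn.emplace(fqn, dev_ptr);
    if (!inserted) {
      // A previous container owns this constant: alias its GPU buffer.
      to_bind.emplace_back(fqn, it->second);
    }
  }
  if (!to_bind.empty()) {
    bind_user_managed(container, to_bind); // pointer rebinding only, no H2D/D2H copy
  }
}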
- cuda_backend.cpp: Use codegen name (from GetConstantName) instead of
  original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI
  API matches against internal codegen names, not FQNs — using FQNs caused
  silent no-op sharing, breaking KV cache flow between prefill and decode.

- main.cpp: Add chunked prefill path using the "prefill" method (T>=2) with
  cudaDeviceSynchronize between prefill and decode for cross-stream safety.
  Add --decode_only flag to fall back to token-by-token decode for all tokens.

- inference.py: Update docstring to reflect that chunked FLA is used in PTE
  mode (not eager).

Verified E2E: "What is the capital of France?" → "The capital of France is Paris."
Prefill: 105 tok/s (chunked FLA), Decode: 87 tok/s (recurrent delta rule).
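A sketch of the verified flow above (chunked prefill, device-wide sync, then token-by-token decode), assuming a stand-in execute callback in place of the actual ExecuTorch method APIs used in main.cpp:

#include <cuda_runtime.h>
#include <cstdint>
#include <functional>
#include <vector>

// Stand-in for method execution: method name, input tokens, position -> next token.
using ExecuteFn = std::function<int64_t(
    const char*, const std::vector<int64_t>&, int64_t)>;

std::vector<int64_t> generate(
    const ExecuteFn& execute,
    const std::vector<int64_t>& prompt,
    size_t max_new_tokens) {
  // Chunked prefill: run the whole prompt (T >= 2) through the "prefill" method.
  int64_t next = execute("prefill", prompt, /*input_pos=*/0);

  // Prefill and decode may use different CUDA streams; synchronize the
  // device so decode sees the state prefill just wrote.
  cudaDeviceSynchronize();

  std::vector<int64_t> out{next};
  int64_t pos = static_cast<int64_t>(prompt.size());
  while (out.size() < max_new_tokens) {
    // Token-by-token decode (T = 1) on the recurrent delta-rule path.
    next = execute("decode", {out.back()}, pos++);
    out.push_back(next);
  }
  return out;
}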
- cuda_backend.cpp: Replace debug printf with ET_LOG for errors/info only
- main.cpp: Remove --decode_only flag, keep only chunked prefill path
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error
  handling + cudaDeviceSynchronize after weight transfer, add logging for
  missing weights_blob
- main.cpp: Support single "forward" method fallback when prefill/decode
  not available, use prefill_method variable, remove debug printf
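A hedged sketch of the explicit error handling that the cuda_backend.cpp bullet above describes; the weights_blob / dev_weights / nbytes names are illustrative stand-ins, and only the CUDA calls and ExecuTorch logging macro are real.

// Fragment assumed to live in an init path that returns an ExecuTorch Error.
if (weights_blob == nullptr) {
  ET_LOG(Error, "weights_blob missing; cannot transfer weights to device");
  return Error::InvalidArgument;
}
cudaError_t err =
    cudaMemcpy(dev_weights, weights_blob, nbytes, cudaMemcpyHostToDevice);
if (err != cudaSuccess) {
  ET_LOG(Error, "Weight transfer failed: %s", cudaGetErrorString(err));
  return Error::Internal;
}
// Ensure the transfer has completed before any method can execute.
cudaDeviceSynchronize();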
Implements CUDA graph support in the CUDA backend to reduce CPU kernel
launch overhead during autoregressive decoding:

- cuda_backend.cpp: 3-phase execution (warmup → capture → replay) with
  static input/output GPU buffers, cudaMemcpyAsync for I/O, and
  cudaGraphInstantiateFlagAutoFreeOnLaunch for cudaMallocAsync compat
- cuda_delegate_handle.h: CUDA graph state (phase, graph objects, static
  buffer metadata) with RAII cleanup in destructor
- main.cpp: --cuda_graph flag that sets BackendOptions before load_method
- test_model_e2e.sh: Enable --cuda_graph for Qwen3.5 MoE CI, set
  PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync

Benchmark (A100, Qwen3.5-35B-A3B HQQ-INT4): 82→98 tok/s (1.20x)
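To make the static-buffer I/O pattern above concrete, here is a hedged sketch of a replay step. Because the captured graph always reads and writes the same device addresses, each call only refreshes inputs, launches once, and copies outputs back. The GraphIO fields stand in for the static buffer metadata kept in cuda_delegate_handle.h (names assumed); the CUDA runtime calls are real.

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

struct GraphIO {
  std::vector<void*> static_input_ptrs;    // device buffers baked into the graph
  std::vector<size_t> static_input_nbytes;
  std::vector<void*> static_output_ptrs;
  std::vector<size_t> static_output_nbytes;
};

cudaError_t replay_step(
    const GraphIO& io,
    cudaGraphExec_t graph_exec,
    cudaStream_t stream,
    const std::vector<const void*>& host_inputs,
    const std::vector<void*>& host_outputs) {
  // Refresh the fixed device input buffers the graph was captured against.
  for (size_t i = 0; i < io.static_input_ptrs.size(); ++i) {
    cudaError_t e = cudaMemcpyAsync(
        io.static_input_ptrs[i], host_inputs[i],
        io.static_input_nbytes[i], cudaMemcpyHostToDevice, stream);
    if (e != cudaSuccess) return e;
  }
  // One launch replays every kernel recorded at capture time.
  cudaError_t e = cudaGraphLaunch(graph_exec, stream);
  if (e != cudaSuccess) return e;
  // Copy outputs back and synchronize before the caller reads them.
  for (size_t i = 0; i < io.static_output_ptrs.size(); ++i) {
    e = cudaMemcpyAsync(
        host_outputs[i], io.static_output_ptrs[i],
        io.static_output_nbytes[i], cudaMemcpyDeviceToHost, stream);
    if (e != cudaSuccess) return e;
  }
  return cudaStreamSynchronize(stream);
}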
pytorch-bot bot commented Apr 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18809

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures, 2 Cancelled Jobs, 5 Unrelated Failures

As of commit 8fc7355 with merge base 2eff4f4:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Apr 10, 2026
Comment thread: backends/cuda/runtime/cuda_backend.cpp (Outdated)
nbytes, cudaMemcpyHostToDevice);

handle->static_input_ptrs.push_back(static_ptr);
handle->static_input_sizes.push_back(sizes_vec);
Contributor
Does this mean a method must have all static shapes for its intermediate tensors?

Contributor Author

Yes, a CUDA graph can only support a static pipeline, so every tensor in the method (including intermediates) must keep a fixed shape across calls.

Base automatically changed from recurrent-fla to main April 10, 2026 20:44
github-actions bot commented:

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia marked this pull request as ready for review April 13, 2026 16:23
@Gasoonjia Gasoonjia requested a review from lucylq as a code owner April 13, 2026 16:23
Gasoonjia (Contributor Author) commented:

@claude review this PR


claude bot commented Apr 15, 2026

Claude finished @Gasoonjia's task in 2m 39s


Review of PR #18809 — CUDA Graph Capture/Replay for Qwen 3.5 MoE Decode

  • Read changed files and gather context
  • Review cuda_delegate_handle.h changes
  • Review cuda_backend.cpp changes
  • Review main.cpp changes
  • Review test_model_e2e.sh changes
  • Post comprehensive review

Overall this is a well-structured addition. The warmup → capture → replay state machine is clean, the opt-in flag design is sensible, and the code integrates naturally into the existing backend. Below are the issues I found, ordered by severity.


Bugs

1. Missing cudaStreamSynchronize on capture-step output copy (cuda_backend.cpp:873-887)

The capture step copies outputs to CPU via cudaMemcpyAsync but returns Error::Ok without synchronizing the stream. The caller may read the CPU output buffer before the async copy completes, causing data races. Compare with the replay path at line 641 which correctly calls cudaStreamSynchronize(cs).

// Line 873-887: capture-step copy path — missing sync before return
if (copy_outputs) {
    for (size_t i = 0; i < n_outputs; i++) {
        // ...
        cudaMemcpyAsync(...);
        gpu_outputs[i] = nullptr;
    }
    // Missing: cudaStreamSynchronize(cuda_stream);
}
return Error::Ok;


2. CudaGraphState is implicitly copyable despite owning GPU resources (cuda_delegate_handle.h:50-86)

CudaGraphState has a custom destructor that calls cudaGraphExecDestroy, cudaGraphDestroy, and cudaFree, but no deleted or custom copy/move constructors. An accidental copy would lead to double-free of CUDA resources. The rule of five applies here.

struct CudaGraphState {
    // ... has custom ~CudaGraphState() that frees CUDA resources
    // Missing:
    // CudaGraphState(const CudaGraphState&) = delete;
    // CudaGraphState& operator=(const CudaGraphState&) = delete;
    // CudaGraphState(CudaGraphState&&) = default; // or custom
    // CudaGraphState& operator=(CudaGraphState&&) = default; // or custom
};


3. Result::get() called before error check (cuda_backend.cpp:604-606)

Result<cudaStream_t> csr = getCurrentCUDAStream(0);
cudaStream_t cs = csr.get();        // UB if error
ET_CHECK_OK_OR_RETURN_ERROR(csr.error()); // checked too late

Calling .get() on a Result that holds an error is undefined behavior. The error check should come first. Note: this same pattern exists in the pre-existing code at line 775-777, so it's a pre-existing issue being replicated here.
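A corrected ordering would check the Result before dereferencing it, e.g.:

Result<cudaStream_t> csr = getCurrentCUDAStream(0);
ET_CHECK_OK_OR_RETURN_ERROR(csr.error()); // check first
cudaStream_t cs = csr.get();              // now safe to dereference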


Potential Issues

4. Capture-step output tensors leak when !copy_outputs (cuda_backend.cpp:870-887)

When should_skip_copy_for_method returns true during the capture step, gpu_outputs[i] are never nulled out and the ScopeGuard will delete them. But these output SlimTensors' data pointers were just stored into static_output_ptrs — deleting the SlimTensors may (depending on SlimTensor semantics) free the underlying CUDA memory that the graph replay relies on. If SlimTensors don't free allocated memory on delete, there's a smaller issue: the static output pointers are stored but the skip-copy path doesn't wire them up to ETensors, so the caller gets no data.

5. No input size validation during replay (cuda_backend.cpp:609-617)

During replay, input data is copied using static_input_nbytes[i] recorded at capture time. If an input tensor's size changes between calls (e.g., different sequence length), this will silently copy the wrong amount of data (buffer overread/overwrite). While the PR description says this requires static shapes, a runtime assertion would catch misuse early:

// Suggested: add in replay loop
ET_CHECK_OR_RETURN_ERROR(
    cpu_tensor->nbytes() == handle->cuda_graph_state.static_input_nbytes[i],
    InvalidArgument,
    "CUDA graph replay: input %zu size mismatch (expected %zu, got %zu)",
    i, handle->cuda_graph_state.static_input_nbytes[i], cpu_tensor->nbytes());


6. cudaMemcpyAsync return values unchecked (multiple locations)

None of the cudaMemcpyAsync calls in the replay path (lines 611-616, 634-639) or capture output copy path (lines 876-881) check the return value. While failures are unlikely in practice, a failed async copy would produce silently corrupted data. The cudaGraphLaunch calls are correctly checked — the memcpy calls should be consistent.
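For example, a checked copy could look like this sketch (using Internal as the error code is an assumption):

cudaError_t err = cudaMemcpyAsync(
    dst, src, nbytes, cudaMemcpyHostToDevice, stream);
ET_CHECK_OR_RETURN_ERROR(
    err == cudaSuccess,
    Internal,
    "cudaMemcpyAsync failed: %s",
    cudaGetErrorString(err));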


Minor / Style

7. Unused metadata stored in CudaGraphState (cuda_delegate_handle.h:63-68)

static_input_sizes, static_input_strides, static_output_sizes, static_output_strides, static_input_scalar_types, and static_output_scalar_types are populated during capture but never read afterward. If they're intended for future validation or debugging, that's fine, but currently they add memory overhead with no benefit. Consider removing them or adding the validation that justifies their existence (e.g., the size check suggested in point 5 above).

8. BackendOptions<2> capacity may be insufficient (main.cpp:92)

executorch::runtime::BackendOptions<2> cuda_opts;
cuda_opts.set_option("enable_cuda_graph_for_method", "decode");

BackendOptions<2> allocates space for 2 options. Currently only 1 is set, so this works, but if future code adds more options in the same block, exceeding the capacity of 2 could fail silently. This is minor but worth noting.

9. CI test always enables --cuda_graph (.ci/scripts/test_model_e2e.sh:357)

The qwen3_5_moe case in the e2e test script always passes --cuda_graph. It would be useful to also have a non-cuda-graph test path to ensure the normal execution path doesn't regress for this model. Consider parameterizing this.


Summary

The core design is sound — the warmup/capture/replay state machine is a proven pattern for CUDA graphs. The main issues to address before merging are:

  1. Missing stream sync on the capture step output copy path (data race)
  2. Missing copy/move delete on CudaGraphState (double-free risk)
  3. Add runtime size validation during replay (safety against shape mismatches)

The rest are improvements worth considering.

