
Commit 7c148f7

Author: gasoonjia
Add CUDA graph capture/replay for decode method
Implements CUDA graph support in the CUDA backend to reduce CPU kernel launch overhead during autoregressive decoding:

- cuda_backend.cpp: 3-phase execution (warmup → capture → replay) with static input/output GPU buffers, cudaMemcpyAsync for I/O, and cudaGraphInstantiateFlagAutoFreeOnLaunch for cudaMallocAsync compatibility
- cuda_delegate_handle.h: CUDA graph state (phase, graph objects, static buffer metadata) with RAII cleanup in the destructor
- main.cpp: --cuda_graph flag that sets BackendOptions before load_method
- test_model_e2e.sh: enable --cuda_graph for the Qwen3.5 MoE CI job, set PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync

Benchmark (A100, Qwen3.5-35B-A3B HQQ-INT4): 82→98 tok/s (1.20x)
1 parent ca69871 commit 7c148f7
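Background for reviewers: the capture/replay flow added here follows the stock CUDA graph recipe. One steady-state decode step is recorded into a graph (capture does not execute the kernels), the graph is instantiated once, and every later step is re-submitted with a single cudaGraphLaunch against fixed device addresses. Below is a minimal standalone CUDA C++ sketch of that recipe, not the backend's actual code: the decode_step kernel and the d_in/d_out buffer names are hypothetical stand-ins for the AOTI container run and the "static" buffers, and it assumes CUDA 12's three-argument cudaGraphInstantiate, the same variant the diff calls.

#include <cuda_runtime.h>

// Stand-in for the AOTI container run; in the backend this is an
// entire decode step's worth of kernel launches.
__global__ void decode_step(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}

int main() {
  const int n = 1024;
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));   // "static" buffers: their
  cudaMalloc(&d_out, n * sizeof(float));  // addresses must stay fixed
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture: launches are recorded into the graph, not executed.
  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
  decode_step<<<(n + 255) / 256, 256, 0, stream>>>(d_in, d_out, n);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);

  // Replay: one launch per decode step instead of one per kernel.
  // Since capture itself produced no output, the first real result
  // also comes from a replay (the diff does the same after capture).
  for (int step = 0; step < 3; ++step) {
    // refresh d_in here via cudaMemcpyAsync(..., stream), as the
    // backend does for its static input buffers
    cudaGraphLaunch(exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}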

4 files changed

Lines changed: 283 additions & 1 deletion

File tree:
  .ci/scripts/test_model_e2e.sh
  backends/cuda/runtime/cuda_backend.cpp
  backends/cuda/runtime/cuda_delegate_handle.h
  examples/models/qwen3_5_moe/main.cpp

.ci/scripts/test_model_e2e.sh

Lines changed: 3 additions & 1 deletion
@@ -354,7 +354,9 @@ EOF
         fi
         ;;
     qwen3_5_moe)
-        RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
+        RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
+        # CUDA graph capture requires cudaMallocAsync backend for stream-ordered allocations
+        export PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync
         ;;
     voxtral_realtime)
         RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 229 additions & 0 deletions
@@ -80,6 +80,9 @@ namespace {
 constexpr char kSkipCopyOutputToCpuForMethod[] =
     "skip_copy_output_to_cpu_for_method";
 constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
+constexpr char kEnableCudaGraphForMethod[] =
+    "enable_cuda_graph_for_method";
+constexpr int kCudaGraphWarmupSteps = 3;
 } // anonymous namespace

 class ET_EXPERIMENTAL CudaBackend final
@@ -146,6 +149,20 @@
     return method_in_csv(method_name, skip_copy_method_);
   }

+  void set_cuda_graph_method(
+      const std::array<char, kMaxOptionValueLength>& raw) {
+    std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
+    cuda_graph_method_ = std::string(raw.data());
+  }
+
+  bool should_use_cuda_graph_for_method(const std::string& method_name) const {
+    if (method_name.empty()) {
+      return false;
+    }
+    std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
+    return method_in_csv(method_name, cuda_graph_method_);
+  }
+
   // Create the shared CUDA stream. Called when use_shared_cuda_stream option
   // is set to true. The presence of shared_cuda_stream_ indicates shared mode.
   void create_shared_cuda_stream() {
@@ -264,6 +281,17 @@
         ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
         return Error::InvalidArgument;
       }
+    } else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) {
+      if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
+              &option.value)) {
+        set_cuda_graph_method(*val);
+      } else {
+        ET_LOG(
+            Error,
+            "Option %s must be a method name string.",
+            kEnableCudaGraphForMethod);
+        return Error::InvalidArgument;
+      }
     }
   }
   return Error::Ok;
@@ -510,6 +538,17 @@
           method_name.c_str());
     }

+    // Initialize CUDA graph state if enabled for this method.
+    if (should_use_cuda_graph_for_method(method_name)) {
+      handle->cuda_graph_phase = 1; // warmup
+      handle->cuda_graph_warmup_remaining = kCudaGraphWarmupSteps;
+      ET_LOG(
+          Info,
+          "CUDA graph enabled for method '%s' (warmup=%d)",
+          method_name.c_str(),
+          kCudaGraphWarmupSteps);
+    }
+
     return (DelegateHandle*)handle; // Return the handle post-processing
   }

@@ -536,6 +575,59 @@
         n_outputs,
         args.size())

+    // ---------------------------------------------------------------
+    // CUDA graph REPLAY path — skip all tensor setup and just replay
+    // ---------------------------------------------------------------
+    if (handle->cuda_graph_phase == 2) {
+      Result<cudaStream_t> csr = getCurrentCUDAStream(0);
+      cudaStream_t cs = csr.get();
+      ET_CHECK_OK_OR_RETURN_ERROR(csr.error());
+
+      // Copy new input data into static input buffers
+      for (size_t i = 0; i < n_inputs; i++) {
+        auto* cpu_tensor = &(args[i]->toTensor());
+        cudaMemcpyAsync(
+            handle->static_input_ptrs[i],
+            cpu_tensor->const_data_ptr(),
+            handle->static_input_nbytes[i],
+            cudaMemcpyHostToDevice,
+            cs);
+      }
+
+      // Replay the captured graph
+      cudaError_t gerr = cudaGraphLaunch(handle->cuda_graph_exec, cs);
+      ET_CHECK_OR_RETURN_ERROR(
+          gerr == cudaSuccess,
+          Internal,
+          "cudaGraphLaunch failed: %s",
+          cudaGetErrorString(gerr));
+
+      // Copy outputs back to CPU
+      const bool copy_outputs =
+          !should_skip_copy_for_method(handle->method_name);
+      if (copy_outputs) {
+        for (size_t i = 0; i < n_outputs; i++) {
+          auto* cpu_out = &(args[i + n_inputs]->toTensor());
+          cudaMemcpyAsync(
+              cpu_out->mutable_data_ptr(),
+              handle->static_output_ptrs[i],
+              handle->static_output_nbytes[i],
+              cudaMemcpyDeviceToHost,
+              cs);
+        }
+        cudaStreamSynchronize(cs);
+      }
+
+      return Error::Ok;
+    }
+
+    // ---------------------------------------------------------------
+    // Normal path (also used for WARMUP and CAPTURE phases)
+    // ---------------------------------------------------------------
+    bool is_capture_step =
+        (handle->cuda_graph_phase == 1 &&
+         handle->cuda_graph_warmup_remaining == 0);
+
     // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
     // optimization. We need to create GPU copies for CUDA kernel execution
     // using SlimTensor.
@@ -546,6 +638,41 @@
     for (size_t i = 0; i < n_inputs; i++) {
       auto* cpu_tensor = &(args[i]->toTensor());

+      // CAPTURE step: allocate persistent static GPU buffers
+      if (is_capture_step) {
+        auto sizes = cpu_tensor->sizes();
+        auto strides = cpu_tensor->strides();
+        std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
+        std::vector<int64_t> strides_vec(strides.begin(), strides.end());
+        size_t nbytes = cpu_tensor->nbytes();
+
+        void* static_ptr = nullptr;
+        cudaError_t merr = cudaMalloc(&static_ptr, nbytes);
+        ET_CHECK_OR_RETURN_ERROR(
+            merr == cudaSuccess, Internal,
+            "cudaMalloc for static input %zu failed: %s",
+            i, cudaGetErrorString(merr));
+
+        cudaMemcpy(
+            static_ptr, cpu_tensor->const_data_ptr(),
+            nbytes, cudaMemcpyHostToDevice);
+
+        handle->static_input_ptrs.push_back(static_ptr);
+        handle->static_input_sizes.push_back(sizes_vec);
+        handle->static_input_strides.push_back(strides_vec);
+        handle->static_input_scalar_types.push_back(
+            static_cast<int>(cpu_tensor->scalar_type()));
+        handle->static_input_nbytes.push_back(nbytes);
+
+        gpu_inputs[i] = new SlimTensor(slim::from_blob(
+            static_ptr,
+            slim::makeArrayRef(sizes_vec),
+            slim::makeArrayRef(strides_vec),
+            static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
+            DEFAULT_CUDA_DEVICE, 0));
+        continue;
+      }
+
       // Check if input data is already on GPU (skip-copy optimization for
       // inputs) This can happen when the caller has pre-staged data on GPU
       cudaPointerAttributes attributes{};
@@ -620,6 +747,23 @@
     Result<cudaStream_t> cuda_stream_ret = getCurrentCUDAStream(0);
     cudaStream_t cuda_stream = cuda_stream_ret.get();
     ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error());
+
+    if (is_capture_step) {
+      // ----- CUDA graph CAPTURE -----
+      ET_LOG(
+          Info,
+          "CUDA graph: beginning stream capture for '%s'",
+          handle->method_name.c_str());
+
+      cudaError_t cerr = cudaStreamBeginCapture(
+          cuda_stream, cudaStreamCaptureModeRelaxed);
+      ET_CHECK_OR_RETURN_ERROR(
+          cerr == cudaSuccess,
+          Internal,
+          "cudaStreamBeginCapture failed: %s",
+          cudaGetErrorString(cerr));
+    }
+
     AOTIRuntimeError error = handle->run(
         handle->container_handle,
         reinterpret_cast<Tensor**>(gpu_inputs.data()),
@@ -645,6 +789,88 @@
         "AOTInductorModelContainerRun failed with error code %d",
         error);

+    if (is_capture_step) {
+      // End capture → instantiate graph
+      cudaError_t gerr =
+          cudaStreamEndCapture(cuda_stream, &handle->cuda_graph);
+      ET_CHECK_OR_RETURN_ERROR(
+          gerr == cudaSuccess,
+          Internal,
+          "cudaStreamEndCapture failed: %s",
+          cudaGetErrorString(gerr));
+
+      gerr = cudaGraphInstantiate(
+          &handle->cuda_graph_exec, handle->cuda_graph,
+          cudaGraphInstantiateFlagAutoFreeOnLaunch);
+      ET_CHECK_OR_RETURN_ERROR(
+          gerr == cudaSuccess,
+          Internal,
+          "cudaGraphInstantiate failed: %s",
+          cudaGetErrorString(gerr));
+
+      // Record static output pointers (stable under graph replay)
+      for (size_t i = 0; i < n_outputs; i++) {
+        SlimTensor* out = gpu_outputs[i];
+        handle->static_output_ptrs.push_back(out->data_ptr());
+
+        auto out_sizes = out->sizes();
+        auto out_strides = out->strides();
+        handle->static_output_sizes.push_back(
+            std::vector<int64_t>(out_sizes.begin(), out_sizes.end()));
+        handle->static_output_strides.push_back(
+            std::vector<int64_t>(out_strides.begin(), out_strides.end()));
+        handle->static_output_scalar_types.push_back(
+            static_cast<int>(out->dtype()));
+        handle->static_output_nbytes.push_back(out->nbytes());
+      }
+
+      handle->cuda_graph_phase = 2; // switch to replay mode
+      ET_LOG(
+          Info,
+          "CUDA graph: captured and instantiated for '%s'",
+          handle->method_name.c_str());
+
+      // Replay once to actually produce output (capture doesn't execute)
+      gerr = cudaGraphLaunch(handle->cuda_graph_exec, cuda_stream);
+      ET_CHECK_OR_RETURN_ERROR(
+          gerr == cudaSuccess,
+          Internal,
+          "cudaGraphLaunch (first replay) failed: %s",
+          cudaGetErrorString(gerr));
+
+      // Copy capture-step outputs to CPU
+      const bool copy_outputs =
+          !should_skip_copy_for_method(handle->method_name);
+      if (copy_outputs) {
+        for (size_t i = 0; i < n_outputs; i++) {
+          auto* cpu_out = &(args[i + n_inputs]->toTensor());
+          cudaMemcpyAsync(
+              cpu_out->mutable_data_ptr(),
+              handle->static_output_ptrs[i],
+              handle->static_output_nbytes[i],
+              cudaMemcpyDeviceToHost,
+              cuda_stream);
+          // Don't delete — static buffers are owned by the handle
+          gpu_outputs[i] = nullptr;
+        }
+      }
+
+      return Error::Ok;
+    }
+
+    // ----- Normal / WARMUP execution continues here -----
+
+    // Decrement warmup counter if in warmup phase
+    if (handle->cuda_graph_phase == 1 &&
+        handle->cuda_graph_warmup_remaining > 0) {
+      handle->cuda_graph_warmup_remaining--;
+      ET_LOG(
+          Info,
+          "CUDA graph warmup: %d steps remaining for '%s'",
+          handle->cuda_graph_warmup_remaining,
+          handle->method_name.c_str());
+    }
+
     const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

     if (copy_outputs) {
@@ -739,6 +965,9 @@
   mutable std::mutex skip_copy_method_mutex_;
   std::string skip_copy_method_;

+  mutable std::mutex cuda_graph_method_mutex_;
+  std::string cuda_graph_method_;
+
   // Shared CUDA stream for all methods. When set (non-null), all methods use
   // the same stream to ensure proper ordering (critical for skip-copy
   // optimization). Created when use_shared_cuda_stream option is set to true.
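Design note on the three phases: the warmup runs (kCudaGraphWarmupSteps = 3) give one-time lazy work in the AOTI container, such as autotuning and workspace or handle creation, a chance to finish before anything is recorded; the capture step then records one steady-state decode step and pins its input/output addresses; every subsequent call is a single cudaGraphLaunch plus input/output copies. Collapsing the per-kernel launches into one graph launch is what produces the 82→98 tok/s improvement quoted in the commit message.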

backends/cuda/runtime/cuda_delegate_handle.h

Lines changed: 40 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include <cuda_runtime.h>
 #include <executorch/backends/aoti/aoti_delegate_handle.h>
 #include <memory>
+#include <vector>

 namespace executorch {
 namespace backends {
@@ -58,6 +59,45 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
   bool has_cuda_stream() const {
     return cuda_stream != nullptr && *cuda_stream != nullptr;
   }
+
+  // --- CUDA graph state ---
+  // Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
+  int cuda_graph_phase = 0;
+  int cuda_graph_warmup_remaining = 0;
+
+  // Captured graph and executable instance
+  cudaGraph_t cuda_graph = nullptr;
+  cudaGraphExec_t cuda_graph_exec = nullptr;
+
+  // Static input/output GPU buffers pinned during capture.
+  // These hold the tensor metadata; the underlying data pointers are fixed
+  // addresses that CUDA graph replay will write to / read from.
+  // SlimTensor pointers — owned by this handle.
+  std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
+  std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
+  std::vector<std::vector<int64_t>> static_input_sizes;
+  std::vector<std::vector<int64_t>> static_input_strides;
+  std::vector<std::vector<int64_t>> static_output_sizes;
+  std::vector<std::vector<int64_t>> static_output_strides;
+  std::vector<int> static_input_scalar_types;
+  std::vector<int> static_output_scalar_types;
+  std::vector<size_t> static_input_nbytes;
+  std::vector<size_t> static_output_nbytes;
+
+  ~CudaDelegateHandle() {
+    if (cuda_graph_exec) {
+      cudaGraphExecDestroy(cuda_graph_exec);
+    }
+    if (cuda_graph) {
+      cudaGraphDestroy(cuda_graph);
+    }
+    // Only free input buffers — output buffers are owned by the AOTI runtime
+    // (allocated during graph capture via the caching allocator).
+    for (auto* ptr : static_input_ptrs) {
+      if (ptr)
+        cudaFree(ptr);
+    }
+  }
 };

 } // namespace cuda
examples/models/qwen3_5_moe/main.cpp

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,8 @@
 #include <executorch/extension/llm/sampler/util.h>
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/backend/options.h>
 #include <executorch/runtime/platform/log.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>

@@ -28,6 +30,7 @@ DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
 DEFINE_string(prompt, "Hello", "Prompt text.");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
+DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");

 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -86,6 +89,14 @@ int main(int argc, char** argv) {

   printf("Loading methods...\n");

+  // Set CUDA graph option if requested (must be before load_method)
+  if (FLAGS_cuda_graph) {
+    executorch::runtime::BackendOptions<2> cuda_opts;
+    cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
+    executorch::runtime::set_option("CudaBackend", cuda_opts.view());
+    printf("CUDA graph enabled for decode method\n");
+  }
+
   // Try loading both methods; fall back to single "forward" method
   bool dual_method = true;
   std::string prefill_method = "prefill";