Skip to content

Commit ca69871

Browse files
author
gasoonjia
committed
Improve CUDA backend error handling and add dual-method runner fallback
- cuda_backend.cpp: Replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error handling, add cudaDeviceSynchronize after weight transfer, and add logging for a missing weights_blob.
- main.cpp: Support a single "forward" method fallback when prefill/decode are not available, use the prefill_method variable, and remove debug printf.
1 parent a0a62f1 commit ca69871

2 files changed

Lines changed: 35 additions & 10 deletions

File tree

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -372,9 +372,18 @@ class ET_EXPERIMENTAL CudaBackend final
372372
const void* weights_blob = buffer_res->data();
373373
// Feed the weights blob into the container. Under the hood it's copying
374374
// weights, so we should free the buffer immediately.
375-
ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
376-
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
375+
auto update_err = handle->update_constants_from_blob(
376+
handle->container_handle, static_cast<const uint8_t*>(weights_blob));
377+
if (update_err != Error::Ok) {
378+
ET_LOG(Error, "update_constants_from_blob failed");
379+
return update_err;
380+
}
381+
// Ensure all weight transfers are complete before execution
382+
cudaDeviceSynchronize();
377383
buffer_res->Free();
384+
} else {
385+
ET_LOG(Info, "weights_blob '%s' not found or update fn is null",
386+
weights_blob_key.c_str());
378387
}
379388

380389
// Use shared CUDA stream if enabled via options, otherwise create one.

examples/models/qwen3_5_moe/main.cpp

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -86,16 +86,27 @@ int main(int argc, char** argv) {
8686

8787
printf("Loading methods...\n");
8888

89-
// Load both methods
89+
// Try loading both methods; fall back to single "forward" method
90+
bool dual_method = true;
91+
std::string prefill_method = "prefill";
9092
auto err = module->load_method("prefill");
9193
if (err != Error::Ok) {
92-
ET_LOG(Error, "Failed to load prefill method");
93-
return 1;
94+
// Try "forward" for single-method export
95+
err = module->load_method("forward");
96+
if (err != Error::Ok) {
97+
ET_LOG(Error, "Failed to load prefill/forward method");
98+
return 1;
99+
}
100+
prefill_method = "forward";
101+
dual_method = false;
102+
printf("Using single-method mode (forward)\n");
94103
}
95-
err = module->load_method("decode");
96-
if (err != Error::Ok) {
97-
ET_LOG(Error, "Failed to load decode method");
98-
return 1;
104+
if (dual_method) {
105+
err = module->load_method("decode");
106+
if (err != Error::Ok) {
107+
ET_LOG(Error, "Failed to load decode method");
108+
return 1;
109+
}
99110
}
100111

101112
// Get EOS ids
@@ -138,7 +149,7 @@ int main(int argc, char** argv) {
138149
prefill_inputs.push_back(tokens_tensor);
139150
prefill_inputs.push_back(pos_tensor);
140151

141-
auto prefill_result = module->execute("prefill", prefill_inputs);
152+
auto prefill_result = module->execute(prefill_method, prefill_inputs);
142153
if (prefill_result.error() != Error::Ok) {
143154
ET_LOG(Error, "Prefill failed");
144155
return 1;
@@ -165,6 +176,11 @@ int main(int argc, char** argv) {
165176
// decode method, which may run on a different CUDA stream.
166177
cudaDeviceSynchronize();
167178

179+
if (!dual_method) {
180+
printf("Single-method mode: skipping decode\n");
181+
return 0;
182+
}
183+
168184
// ---------------------------------------------------------------
169185
// Decode — generate tokens one at a time
170186
// ---------------------------------------------------------------

0 commit comments

Comments
 (0)