Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,9 @@ EOF
fi
;;
qwen3_5_moe)
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
# CUDA graph capture requires cudaMallocAsync backend for stream-ordered allocations
export PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync
;;
voxtral_realtime)
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
Expand Down
233 changes: 230 additions & 3 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -80,6 +80,9 @@
constexpr char kSkipCopyOutputToCpuForMethod[] =
"skip_copy_output_to_cpu_for_method";
constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
constexpr char kEnableCudaGraphForMethod[] =
"enable_cuda_graph_for_method";
constexpr int kCudaGraphWarmupSteps = 3;
} // anonymous namespace

class ET_EXPERIMENTAL CudaBackend final
Expand Down Expand Up @@ -146,6 +149,20 @@
return method_in_csv(method_name, skip_copy_method_);
}

void set_cuda_graph_method(
const std::array<char, kMaxOptionValueLength>& raw) {
std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
cuda_graph_method_ = std::string(raw.data());
}

// Returns true when CUDA graph capture/replay has been enabled for the
// given method via the enable_cuda_graph_for_method backend option.
// Thread-safe: reads the configured CSV list under its mutex.
bool should_use_cuda_graph_for_method(const std::string& method_name) const {
  bool enabled = false;
  // An empty method name can never match a configured entry, so skip the
  // lock entirely in that case.
  if (!method_name.empty()) {
    std::lock_guard<std::mutex> lock(cuda_graph_method_mutex_);
    enabled = method_in_csv(method_name, cuda_graph_method_);
  }
  return enabled;
}

// Create the shared CUDA stream. Called when use_shared_cuda_stream option
// is set to true. The presence of shared_cuda_stream_ indicates shared mode.
void create_shared_cuda_stream() {
Expand Down Expand Up @@ -264,6 +281,17 @@
ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
return Error::InvalidArgument;
}
} else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) {
if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
&option.value)) {
set_cuda_graph_method(*val);
} else {
ET_LOG(
Error,
"Option %s must be a method name string.",
kEnableCudaGraphForMethod);
return Error::InvalidArgument;
}
}
}
return Error::Ok;
Expand Down Expand Up @@ -382,9 +410,7 @@
cudaDeviceSynchronize();
buffer_res->Free();
} else {
ET_LOG(
Info,
"weights_blob '%s' not found or update fn is null",
ET_LOG(Info, "weights_blob '%s' not found or update fn is null",
weights_blob_key.c_str());
}

Expand Down Expand Up @@ -512,6 +538,17 @@
method_name.c_str());
}

// Initialize CUDA graph state if enabled for this method.
if (should_use_cuda_graph_for_method(method_name)) {
handle->cuda_graph_phase = 1; // warmup
handle->cuda_graph_warmup_remaining = kCudaGraphWarmupSteps;
ET_LOG(
Info,
"CUDA graph enabled for method '%s' (warmup=%d)",
method_name.c_str(),
kCudaGraphWarmupSteps);
}

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand All @@ -538,6 +575,59 @@
n_outputs,
args.size())

// ---------------------------------------------------------------
// CUDA graph REPLAY path — skip all tensor setup and just replay
// ---------------------------------------------------------------
if (handle->cuda_graph_phase == 2) {
Result<cudaStream_t> csr = getCurrentCUDAStream(0);
cudaStream_t cs = csr.get();
ET_CHECK_OK_OR_RETURN_ERROR(csr.error());

// Copy new input data into static input buffers
for (size_t i = 0; i < n_inputs; i++) {
auto* cpu_tensor = &(args[i]->toTensor());
cudaMemcpyAsync(
handle->static_input_ptrs[i],
cpu_tensor->const_data_ptr(),
handle->static_input_nbytes[i],
cudaMemcpyHostToDevice,
cs);
}

// Replay the captured graph
cudaError_t gerr = cudaGraphLaunch(handle->cuda_graph_exec, cs);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphLaunch failed: %s",
cudaGetErrorString(gerr));

// Copy outputs back to CPU
const bool copy_outputs =
!should_skip_copy_for_method(handle->method_name);
if (copy_outputs) {
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_out = &(args[i + n_inputs]->toTensor());
cudaMemcpyAsync(
cpu_out->mutable_data_ptr(),
handle->static_output_ptrs[i],
handle->static_output_nbytes[i],
cudaMemcpyDeviceToHost,
cs);
}
cudaStreamSynchronize(cs);
}

return Error::Ok;
}

// ---------------------------------------------------------------
// Normal path (also used for WARMUP and CAPTURE phases)
// ---------------------------------------------------------------
bool is_capture_step =
(handle->cuda_graph_phase == 1 &&
handle->cuda_graph_warmup_remaining == 0);

// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
// optimization. We need to create GPU copies for CUDA kernel execution
// using SlimTensor.
Expand All @@ -548,6 +638,41 @@
for (size_t i = 0; i < n_inputs; i++) {
auto* cpu_tensor = &(args[i]->toTensor());

// CAPTURE step: allocate persistent static GPU buffers
if (is_capture_step) {
auto sizes = cpu_tensor->sizes();
auto strides = cpu_tensor->strides();
std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
std::vector<int64_t> strides_vec(strides.begin(), strides.end());
size_t nbytes = cpu_tensor->nbytes();

void* static_ptr = nullptr;
cudaError_t merr = cudaMalloc(&static_ptr, nbytes);
ET_CHECK_OR_RETURN_ERROR(
merr == cudaSuccess, Internal,
"cudaMalloc for static input %zu failed: %s",
i, cudaGetErrorString(merr));

cudaMemcpy(
static_ptr, cpu_tensor->const_data_ptr(),
nbytes, cudaMemcpyHostToDevice);

handle->static_input_ptrs.push_back(static_ptr);
handle->static_input_sizes.push_back(sizes_vec);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean a method must have all static shapes for its intermediate tensors?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes the cuda graph can only support a static pipeline.

handle->static_input_strides.push_back(strides_vec);
handle->static_input_scalar_types.push_back(
static_cast<int>(cpu_tensor->scalar_type()));
handle->static_input_nbytes.push_back(nbytes);

gpu_inputs[i] = new SlimTensor(slim::from_blob(
static_ptr,
slim::makeArrayRef(sizes_vec),
slim::makeArrayRef(strides_vec),
static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
DEFAULT_CUDA_DEVICE, 0));
continue;
}

// Check if input data is already on GPU (skip-copy optimization for
// inputs) This can happen when the caller has pre-staged data on GPU
cudaPointerAttributes attributes{};
Expand Down Expand Up @@ -622,6 +747,23 @@
Result<cudaStream_t> cuda_stream_ret = getCurrentCUDAStream(0);
cudaStream_t cuda_stream = cuda_stream_ret.get();
ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error());

if (is_capture_step) {
// ----- CUDA graph CAPTURE -----
ET_LOG(
Info,
"CUDA graph: beginning stream capture for '%s'",
handle->method_name.c_str());

cudaError_t cerr = cudaStreamBeginCapture(
cuda_stream, cudaStreamCaptureModeRelaxed);
ET_CHECK_OR_RETURN_ERROR(
cerr == cudaSuccess,
Internal,
"cudaStreamBeginCapture failed: %s",
cudaGetErrorString(cerr));
}

AOTIRuntimeError error = handle->run(
handle->container_handle,
reinterpret_cast<Tensor**>(gpu_inputs.data()),
Expand All @@ -647,6 +789,88 @@
"AOTInductorModelContainerRun failed with error code %d",
error);

if (is_capture_step) {
// End capture → instantiate graph
cudaError_t gerr =
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaStreamEndCapture failed: %s",
cudaGetErrorString(gerr));

gerr = cudaGraphInstantiate(
&handle->cuda_graph_exec, handle->cuda_graph,
cudaGraphInstantiateFlagAutoFreeOnLaunch);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphInstantiate failed: %s",
cudaGetErrorString(gerr));

// Record static output pointers (stable under graph replay)
for (size_t i = 0; i < n_outputs; i++) {
SlimTensor* out = gpu_outputs[i];
handle->static_output_ptrs.push_back(out->data_ptr());

auto out_sizes = out->sizes();
auto out_strides = out->strides();
handle->static_output_sizes.push_back(
std::vector<int64_t>(out_sizes.begin(), out_sizes.end()));
handle->static_output_strides.push_back(
std::vector<int64_t>(out_strides.begin(), out_strides.end()));
handle->static_output_scalar_types.push_back(
static_cast<int>(out->dtype()));
handle->static_output_nbytes.push_back(out->nbytes());
}

handle->cuda_graph_phase = 2; // switch to replay mode
ET_LOG(
Info,
"CUDA graph: captured and instantiated for '%s'",
handle->method_name.c_str());

// Replay once to actually produce output (capture doesn't execute)
gerr = cudaGraphLaunch(handle->cuda_graph_exec, cuda_stream);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphLaunch (first replay) failed: %s",
cudaGetErrorString(gerr));

// Copy capture-step outputs to CPU
const bool copy_outputs =
!should_skip_copy_for_method(handle->method_name);
if (copy_outputs) {
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_out = &(args[i + n_inputs]->toTensor());
cudaMemcpyAsync(
cpu_out->mutable_data_ptr(),
handle->static_output_ptrs[i],
handle->static_output_nbytes[i],
cudaMemcpyDeviceToHost,
cuda_stream);
// Don't delete — static buffers are owned by the handle
gpu_outputs[i] = nullptr;
}
}

return Error::Ok;
}

// ----- Normal / WARMUP execution continues here -----

// Decrement warmup counter if in warmup phase
if (handle->cuda_graph_phase == 1 &&
handle->cuda_graph_warmup_remaining > 0) {
handle->cuda_graph_warmup_remaining--;
ET_LOG(
Info,
"CUDA graph warmup: %d steps remaining for '%s'",
handle->cuda_graph_warmup_remaining,
handle->method_name.c_str());
}

const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

if (copy_outputs) {
Expand Down Expand Up @@ -741,6 +965,9 @@
mutable std::mutex skip_copy_method_mutex_;
std::string skip_copy_method_;

mutable std::mutex cuda_graph_method_mutex_;
std::string cuda_graph_method_;

// Shared CUDA stream for all methods. When set (non-null), all methods use
// the same stream to ensure proper ordering (critical for skip-copy
// optimization). Created when use_shared_cuda_stream option is set to true.
Expand Down
40 changes: 40 additions & 0 deletions backends/cuda/runtime/cuda_delegate_handle.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -11,6 +11,7 @@
#include <cuda_runtime.h>
#include <executorch/backends/aoti/aoti_delegate_handle.h>
#include <memory>
#include <vector>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -58,6 +59,45 @@
bool has_cuda_stream() const {
return cuda_stream != nullptr && *cuda_stream != nullptr;
}

// --- CUDA graph state ---
// Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
int cuda_graph_phase = 0;
int cuda_graph_warmup_remaining = 0;

// Captured graph and executable instance
cudaGraph_t cuda_graph = nullptr;
cudaGraphExec_t cuda_graph_exec = nullptr;

// Static input/output GPU buffers pinned during capture.
// These hold the tensor metadata; the underlying data pointers are fixed
// addresses that CUDA graph replay will write to / read from.
// SlimTensor pointers — owned by this handle.
std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
std::vector<std::vector<int64_t>> static_input_sizes;
std::vector<std::vector<int64_t>> static_input_strides;
std::vector<std::vector<int64_t>> static_output_sizes;
std::vector<std::vector<int64_t>> static_output_strides;
std::vector<int> static_input_scalar_types;
std::vector<int> static_output_scalar_types;
std::vector<size_t> static_input_nbytes;
std::vector<size_t> static_output_nbytes;

// Releases all CUDA graph resources owned by this handle.
~CudaDelegateHandle() {
  // Destroy the executable instance before the captured graph it was
  // instantiated from.
  if (cuda_graph_exec != nullptr) {
    cudaGraphExecDestroy(cuda_graph_exec);
  }
  if (cuda_graph != nullptr) {
    cudaGraphDestroy(cuda_graph);
  }
  // Free only the static input buffers. Output buffers are intentionally
  // left alone: they are owned by the AOTI runtime (allocated during graph
  // capture via the caching allocator), so freeing them here would be a
  // double free.
  for (void* input_ptr : static_input_ptrs) {
    if (input_ptr != nullptr) {
      cudaFree(input_ptr);
    }
  }
}
};

} // namespace cuda
Expand Down
11 changes: 11 additions & 0 deletions examples/models/qwen3_5_moe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/platform/log.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

Expand All @@ -28,6 +30,7 @@ DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
DEFINE_string(prompt, "Hello", "Prompt text.");
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");

namespace llm = ::executorch::extension::llm;
using ::executorch::extension::from_blob;
Expand Down Expand Up @@ -86,6 +89,14 @@ int main(int argc, char** argv) {

printf("Loading methods...\n");

// Set CUDA graph option if requested (must be before load_method)
if (FLAGS_cuda_graph) {
executorch::runtime::BackendOptions<2> cuda_opts;
cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
executorch::runtime::set_option("CudaBackend", cuda_opts.view());
printf("CUDA graph enabled for decode method\n");
}

// Try loading both methods; fall back to single "forward" method
bool dual_method = true;
std::string prefill_method = "prefill";
Expand Down
Loading