4 changes: 2 additions & 2 deletions .github/workflows/cuda.yml
@@ -145,8 +145,8 @@ jobs:
# Run CUDA backend Python tests
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="

- # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
- python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
+ # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
+ python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

export-model-cuda-artifact:
name: export-model-cuda-artifact
9 changes: 5 additions & 4 deletions backends/cuda/CMakeLists.txt
@@ -107,10 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
runtime/shims/cuda_guard.cpp
)

- # Only build int4mm shim when CUDA language/toolchain is available.
+ # Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
- runtime/shims/sort.cu
+ runtime/shims/sort.cu runtime/shims/rand.cu
)
endif()

@@ -152,7 +152,8 @@ endif()
# retention.
if(_cuda_is_msvc_toolchain)
target_link_libraries(
- aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
+ aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
+ ${CMAKE_DL_LIBS}
)
# Link object library directly so symbols are pulled exactly once while
# avoiding duplicate static/object inclusion and interface leakage.
@@ -162,7 +163,7 @@ else()
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
- CUDA::cudart ${CMAKE_DL_LIBS}
+ CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
)
endif()

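The new runtime/shims/rand.cu is what motivates the CUDA::curand link dependency added in both toolchain branches above. As a rough, standalone sketch of the kind of cuRAND host-API usage that requires linking libcurand (this is not the actual shim; the helper kernel and the [low, high) mapping, modulo bias and all, are assumptions for illustration):

```cuda
// Illustrative sketch only, not the actual rand.cu shim. It shows the kind
// of cuRAND host-API calls that require linking CUDA::curand; the helper
// kernel and the [low, high) mapping are assumptions made for this example.
#include <cuda_runtime.h>
#include <curand.h>
#include <cstdint>

// Map raw 32-bit random words into [low, high). The modulo introduces a
// small bias; fine for a sketch, not for a faithful randint implementation.
__global__ void map_to_range(
    const unsigned int* bits, int64_t* out, int64_t low, int64_t high, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = low + static_cast<int64_t>(bits[i] % static_cast<uint64_t>(high - low));
  }
}

int main() {
  const size_t n = 1024;
  unsigned int* bits = nullptr;
  int64_t* out = nullptr;
  cudaMalloc(&bits, n * sizeof(unsigned int));
  cudaMalloc(&out, n * sizeof(int64_t));

  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);  // host API -> libcurand
  curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);
  curandGenerate(gen, bits, n);                            // raw uniform 32-bit words

  map_to_range<<<(n + 255) / 256, 256>>>(bits, out, /*low=*/0, /*high=*/100, n);
  cudaDeviceSynchronize();

  curandDestroyGenerator(gen);
  cudaFree(bits);
  cudaFree(out);
  return 0;
}
```

Since the generator calls above resolve to symbols in libcurand rather than the CUDA runtime, adding CUDA::curand next to CUDA::cudart in both the MSVC and non-MSVC branches keeps the shim linking on either toolchain.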
1 change: 1 addition & 0 deletions backends/cuda/cuda_backend.py
@@ -146,6 +146,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
return {
"at::_ops::_weight_int4pack_mm::call": None,
"at::_ops::sort_stable::call": None,
"aoti_torch_cuda_randint_low_out": None,
}

@classmethod
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -33,6 +33,7 @@ runtime.cxx_library(
"shims/cuda_guard.cpp",
"shims/int4mm.cu",
"shims/memory.cpp",
"shims/rand.cu",
"shims/sort.cu",
"shims/tensor_attribute.cpp",
],
@@ -41,6 +42,7 @@ runtime.cxx_library(
"shims/int4mm.cuh",
"shims/int4mm.h",
"shims/memory.h",
"shims/rand.h",
"shims/sort.h",
"shims/tensor_attribute.h",
"utils.h",
3 changes: 3 additions & 0 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -693,6 +693,7 @@ class ET_EXPERIMENTAL CudaBackend final

gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
static_ptr, cpu_tensor);
+
continue;
}

@@ -805,6 +806,7 @@ class ET_EXPERIMENTAL CudaBackend final
// End capture → instantiate graph
cudaError_t gerr =
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph);
+
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
@@ -814,6 +816,7 @@ class ET_EXPERIMENTAL CudaBackend final
gerr = cudaGraphInstantiate(
&handle->cuda_graph_state.graph_exec,
handle->cuda_graph_state.graph,
+
cudaGraphInstantiateFlagAutoFreeOnLaunch);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
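For orientation, the calls visible in these hunks (cudaStreamEndCapture, cudaGraphInstantiate with cudaGraphInstantiateFlagAutoFreeOnLaunch) are the tail end of a capture-and-replay flow: the stream's work is recorded into a CUDA graph, instantiated once, and relaunched on later inferences. Below is a minimal standalone sketch of that sequence, assuming a CUDA 12-style toolkit where cudaGraphInstantiate takes a flags argument (as the diff itself does); the stub kernel stands in for the AOTI-compiled model and plain error checks replace ET_CHECK_OR_RETURN_ERROR.

```cuda
// Minimal sketch of the stream-capture -> instantiate -> replay sequence.
// model_stub is a stand-in for the AOTI-compiled model invocation.
#include <cuda_runtime.h>
#include <cstdio>

#define CUDA_CHECK(expr)                                             \
  do {                                                               \
    cudaError_t _e = (expr);                                         \
    if (_e != cudaSuccess) {                                         \
      std::printf("%s failed: %s\n", #expr, cudaGetErrorString(_e)); \
      return 1;                                                      \
    }                                                                \
  } while (0)

__global__ void model_stub(float* out) { out[threadIdx.x] += 1.0f; }

int main() {
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  float* buf = nullptr;
  CUDA_CHECK(cudaMalloc(&buf, 32 * sizeof(float)));

  // Begin capture: work launched on this stream is recorded, not executed.
  CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  model_stub<<<1, 32, 0, stream>>>(buf);

  // End capture and instantiate (mirrors the calls in the hunks above).
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;
  CUDA_CHECK(cudaStreamEndCapture(stream, &graph));
  CUDA_CHECK(cudaGraphInstantiate(
      &graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));

  // Replay: later inferences just relaunch the captured work.
  CUDA_CHECK(cudaGraphLaunch(graph_exec, stream));
  CUDA_CHECK(cudaStreamSynchronize(stream));

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}
```

Passing cudaGraphInstantiateFlagAutoFreeOnLaunch lets a graph that owns stream-ordered allocations be relaunched without manually freeing the previous launch's allocations first, which fits a replay-every-inference pattern.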
38 changes: 38 additions & 0 deletions backends/cuda/runtime/cuda_delegate_handle.h
@@ -149,6 +149,44 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {

// CUDA graph state (warmup, capture, replay, static buffers)
CudaGraphState cuda_graph_state;
+   // --- CUDA graph state ---
+   // Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
+   int cuda_graph_phase = 0;
+   int cuda_graph_warmup_remaining = 0;
+
+   // Captured graph and executable instance
+   cudaGraph_t cuda_graph = nullptr;
+   cudaGraphExec_t cuda_graph_exec = nullptr;
+
+   // Static input/output GPU buffers pinned during capture.
+   // These hold the tensor metadata; the underlying data pointers are fixed
+   // addresses that CUDA graph replay will write to / read from.
+   // SlimTensor pointers — owned by this handle.
+   std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
+   std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
+   std::vector<std::vector<int64_t>> static_input_sizes;
+   std::vector<std::vector<int64_t>> static_input_strides;
+   std::vector<std::vector<int64_t>> static_output_sizes;
+   std::vector<std::vector<int64_t>> static_output_strides;
+   std::vector<int> static_input_scalar_types;
+   std::vector<int> static_output_scalar_types;
+   std::vector<size_t> static_input_nbytes;
+   std::vector<size_t> static_output_nbytes;
+
+   ~CudaDelegateHandle() {
+     if (cuda_graph_exec) {
+       cudaGraphExecDestroy(cuda_graph_exec);
+     }
+     if (cuda_graph) {
+       cudaGraphDestroy(cuda_graph);
+     }
+     // Only free input buffers — output buffers are owned by the AOTI runtime
+     // (allocated during graph capture via the caching allocator).
+     for (auto* ptr : static_input_ptrs) {
+       if (ptr)
+         cudaFree(ptr);
+     }
+   }
};

} // namespace cuda
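Taken together, the new fields form a small state machine: a few eager warmup runs (phase 1), a one-time capture that pins the static buffers, then replay (phase 2) that only copies fresh inputs to those fixed addresses and relaunches the executable graph. A rough sketch of how an execute path might drive these fields follows; run_model_eagerly and capture_model_into_graph are hypothetical placeholders for the backend's real AOTI invocation and capture code in cuda_backend.cpp.

```cuda
// Rough sketch of a phase-driven execute() over the fields added above.
// run_model_eagerly() and capture_model_into_graph() are empty placeholders;
// in the real backend they would run the AOTI model and fill cuda_graph_exec
// plus the static buffers during capture.
#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

struct HandleSketch {
  int cuda_graph_phase = 1;              // 0=disabled, 1=warmup, 2=captured
  int cuda_graph_warmup_remaining = 2;
  cudaGraphExec_t cuda_graph_exec = nullptr;
  std::vector<void*> static_input_ptrs;  // fixed GPU addresses the graph reads
  std::vector<size_t> static_input_nbytes;
};

void run_model_eagerly(HandleSketch&, const std::vector<void*>&, cudaStream_t) {}
void capture_model_into_graph(HandleSketch&, cudaStream_t) {}

void execute(HandleSketch& h, const std::vector<void*>& gpu_inputs, cudaStream_t stream) {
  if (h.cuda_graph_phase == 1) {
    // Warm up eagerly; capture once the countdown reaches zero.
    run_model_eagerly(h, gpu_inputs, stream);
    if (--h.cuda_graph_warmup_remaining == 0) {
      capture_model_into_graph(h, stream);
      h.cuda_graph_phase = 2;
    }
  } else if (h.cuda_graph_phase == 2) {
    // Replay: copy fresh inputs into the pinned static buffers, then relaunch.
    for (size_t i = 0; i < gpu_inputs.size(); ++i) {
      cudaMemcpyAsync(h.static_input_ptrs[i], gpu_inputs[i],
                      h.static_input_nbytes[i], cudaMemcpyDeviceToDevice, stream);
    }
    if (h.cuda_graph_exec != nullptr) {
      cudaGraphLaunch(h.cuda_graph_exec, stream);
    }
  }
}
```

The destructor in the diff covers the teardown side of this: the executable graph and the captured graph are destroyed, and only the static input buffers are freed, since the output buffers belong to the AOTI runtime's caching allocator.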