diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 68ded356b99..c3b7c058ee6 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -145,8 +145,8 @@ jobs: # Run CUDA backend Python tests python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts=" - # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache) - python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts=" + # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler) + python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts=" export-model-cuda-artifact: name: export-model-cuda-artifact diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index b3e82e2ec51..157cc05a54f 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -107,10 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) -# Only build int4mm shim when CUDA language/toolchain is available. +# Only build CUDA shims when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu - runtime/shims/sort.cu + runtime/shims/sort.cu runtime/shims/rand.cu ) endif() @@ -152,7 +152,8 @@ endif() # retention. if(_cuda_is_msvc_toolchain) target_link_libraries( - aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS} + aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand + ${CMAKE_DL_LIBS} ) # Link object library directly so symbols are pulled exactly once while # avoiding duplicate static/object inclusion and interface leakage. 
@@ -162,7 +163,7 @@ else() aoti_cuda_shims PRIVATE cuda_platform PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive - CUDA::cudart ${CMAKE_DL_LIBS} + CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} ) endif() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 061b0d6a29a..5c6395c8b5b 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -146,6 +146,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, "at::_ops::sort_stable::call": None, + "aoti_torch_cuda_randint_low_out": None, } @classmethod diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 726f89c8125..f13f41ab8b7 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -33,6 +33,7 @@ runtime.cxx_library( "shims/cuda_guard.cpp", "shims/int4mm.cu", "shims/memory.cpp", + "shims/rand.cu", "shims/sort.cu", "shims/tensor_attribute.cpp", ], @@ -41,6 +42,7 @@ runtime.cxx_library( "shims/int4mm.cuh", "shims/int4mm.h", "shims/memory.h", + "shims/rand.h", "shims/sort.h", "shims/tensor_attribute.h", "utils.h", diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index f98c862f69a..94b992430c4 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -693,6 +693,7 @@ class ET_EXPERIMENTAL CudaBackend final gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata( static_ptr, cpu_tensor); + continue; } @@ -805,6 +806,7 @@ class ET_EXPERIMENTAL CudaBackend final // End capture → instantiate graph cudaError_t gerr = cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph); + ET_CHECK_OR_RETURN_ERROR( gerr == cudaSuccess, Internal, @@ -814,6 +816,7 @@ class ET_EXPERIMENTAL CudaBackend final gerr = cudaGraphInstantiate( &handle->cuda_graph_state.graph_exec, handle->cuda_graph_state.graph, + cudaGraphInstantiateFlagAutoFreeOnLaunch); ET_CHECK_OR_RETURN_ERROR( gerr == cudaSuccess, diff --git a/backends/cuda/runtime/cuda_delegate_handle.h b/backends/cuda/runtime/cuda_delegate_handle.h index 87845fbc312..bf00769de5d 100644 --- a/backends/cuda/runtime/cuda_delegate_handle.h +++ b/backends/cuda/runtime/cuda_delegate_handle.h @@ -149,6 +149,44 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle { // CUDA graph state (warmup, capture, replay, static buffers) CudaGraphState cuda_graph_state; + // --- CUDA graph state --- + // Phase: 0=disabled, 1=warmup, 2=captured (replay mode) + int cuda_graph_phase = 0; + int cuda_graph_warmup_remaining = 0; + + // Captured graph and executable instance + cudaGraph_t cuda_graph = nullptr; + cudaGraphExec_t cuda_graph_exec = nullptr; + + // Static input/output GPU buffers pinned during capture. + // These hold the tensor metadata; the underlying data pointers are fixed + // addresses that CUDA graph replay will write to / read from. + // SlimTensor pointers — owned by this handle. 
+ std::vector static_input_ptrs; // raw GPU data pointers for inputs + std::vector static_output_ptrs; // raw GPU data pointers for outputs + std::vector> static_input_sizes; + std::vector> static_input_strides; + std::vector> static_output_sizes; + std::vector> static_output_strides; + std::vector static_input_scalar_types; + std::vector static_output_scalar_types; + std::vector static_input_nbytes; + std::vector static_output_nbytes; + + ~CudaDelegateHandle() { + if (cuda_graph_exec) { + cudaGraphExecDestroy(cuda_graph_exec); + } + if (cuda_graph) { + cudaGraphDestroy(cuda_graph); + } + // Only free input buffers — output buffers are owned by the AOTI runtime + // (allocated during graph capture via the caching allocator). + for (auto* ptr : static_input_ptrs) { + if (ptr) + cudaFree(ptr); + } + } }; } // namespace cuda diff --git a/backends/cuda/runtime/shims/rand.cu b/backends/cuda/runtime/shims/rand.cu new file mode 100644 index 00000000000..375cbc9255c --- /dev/null +++ b/backends/cuda/runtime/shims/rand.cu @@ -0,0 +1,273 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +namespace c10 = executorch::backends::aoti::slim::c10; +using c10::Device; +using c10::DeviceIndex; +using c10::DeviceType; +using c10::ScalarType; +using executorch::backends::aoti::slim::empty_strided; +using executorch::backends::aoti::slim::IntArrayRef; +using executorch::backends::aoti::slim::makeArrayRef; + +namespace { + +// ---- GPU-resident RNG state ---- +// Seed and counter live in device memory allocated during the first call +// (warmup phase, before CUDA graph capture). The counter is atomically +// advanced by each kernel invocation on-device, so it automatically +// produces different random sequences on every CUDA graph replay. + +struct RngState { + unsigned long long seed; + unsigned long long counter; + // Per-launch scratch — written by advance_counter_kernel and read by + // the main RNG kernels. Single-threaded host driver is assumed + // (typical inference / CUDA-graph replay use case). + unsigned long long base_scratch; +}; + +static RngState* d_rng = nullptr; +// std::call_once guarantees one-shot initialization even when shims are +// invoked from multiple host threads (e.g. concurrent models / streams). +static std::once_flag g_rng_init_flag; + +// Initialize RNG state on the given stream. +// Must be called during warmup (before graph capture). Subsequent calls +// from any thread are no-ops thanks to std::call_once. +void ensure_rng_init(cudaStream_t stream) { + std::call_once(g_rng_init_flag, [&]() { + cudaMallocAsync(&d_rng, sizeof(RngState), stream); + RngState h; + h.seed = static_cast(time(nullptr)); + h.counter = 0; + h.base_scratch = 0; + cudaMemcpyAsync( + d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream); + // Synchronize to ensure the copy completes before we return + // (the host-side RngState `h` is on the stack). + cudaStreamSynchronize(stream); + }); +} + +// Philox-based randint kernel. Reads its base offset from `rng->base_scratch` +// (populated by `advance_counter_kernel` immediately before this launch). +// This replaces the previous per-element atomicAdd contention with a single +// atomic per kernel launch. 
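Before the kernels themselves, a hedged conceptual sketch of the offset bookkeeping (plain Python, not part of the diff; `launch` and the tuple layout are illustrative only, not the real Philox math): each launch reserves `numel` counter slots with a single atomic add, and thread `i` then seeds its generator from `(seed, subsequence=i, offset=base)`, so later launches automatically see a larger offset and draw a fresh sequence on every CUDA graph replay.

```python
# Conceptual model of RngState.counter / advance_counter_kernel (illustrative only).
counter = 0  # device-resident in the real shim

def launch(numel, seed=1234):
    """Reserve `numel` offsets with one 'atomic' add and return the
    (seed, subsequence, offset) triple each thread would pass to curand_init."""
    global counter
    base = counter          # advance_counter_kernel: base_scratch = atomicAdd(&counter, numel)
    counter += numel
    return [(seed, i, base) for i in range(numel)]

first = launch(64)
second = launch(64)
# Same subsequences, strictly larger offset -> different random streams per call / per replay.
assert all(a != b for a, b in zip(first, second))
```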
+__global__ void philox_randint_graph_kernel( + int64_t* __restrict__ out, + int64_t numel, + int64_t low, + int64_t range, + RngState* __restrict__ rng) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, rng->base_scratch, &state); + double val = curand_uniform_double(&state); + int64_t ival = static_cast(val * range); + out[idx] = low + (ival >= range ? range - 1 : ival); + } +} + +// Philox-based uniform float32 generator (graph-safe version). +__global__ void philox_rand_float_graph_kernel( + float* __restrict__ out, + int64_t numel, + RngState* __restrict__ rng) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, rng->base_scratch, &state); + out[idx] = curand_uniform(&state); + } +} + +// Philox-based uniform bfloat16 generator (graph-safe version). +__global__ void philox_rand_bf16_graph_kernel( + uint16_t* __restrict__ out, + int64_t numel, + RngState* __restrict__ rng) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx < numel) { + curandStatePhilox4_32_10_t state; + curand_init(rng->seed, idx, rng->base_scratch, &state); + float val = curand_uniform(&state); + uint32_t bits; + memcpy(&bits, &val, sizeof(uint32_t)); + uint32_t lsb = (bits >> 16) & 1; + bits += 0x7FFFu + lsb; + out[idx] = static_cast(bits >> 16); + } +} + +// Single-thread helper that grabs a contiguous range of `numel` offsets +// from the on-device counter and writes the base into `rng->base_scratch`. +// Replaces `numel` per-element atomics with a single atomic per launch +// while staying graph-capturable. +__global__ void advance_counter_kernel( + RngState* __restrict__ rng, + unsigned long long numel) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + rng->base_scratch = atomicAdd(&rng->counter, numel); + } +} + +} // anonymous namespace + +extern "C" { + +AOTITorchError aoti_torch_cuda_rand( + const int64_t* size, + int64_t size_len_, + int32_t* dtype, + int32_t* layout, + int32_t* device, + int32_t device_index_, + int32_t* pin_memory, + SlimTensor** ret0) { + (void)layout; + (void)device; + (void)pin_memory; + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_rand: ret0 is null"); + + // Default to float32 if dtype not specified. + ScalarType scalar_type = ScalarType::Float; + if (dtype != nullptr) { + scalar_type = static_cast(*dtype); + } + + // Compute contiguous strides and total elements. + std::vector strides(size_len_); + int64_t numel = 1; + for (int64_t i = size_len_ - 1; i >= 0; i--) { + strides[i] = numel; + numel *= size[i]; + } + + // Allocate output tensor. + IntArrayRef sizes_ref(size, static_cast(size_len_)); + *ret0 = new SlimTensor(empty_strided( + sizes_ref, + makeArrayRef(strides), + scalar_type, + Device(DeviceType::CUDA, static_cast(device_index_)))); + + if (numel == 0) { + return Error::Ok; + } + + // Get the current CUDA stream. 
+ auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_OR_RETURN_ERROR( + stream_result.ok(), + Internal, + "aoti_torch_cuda_rand: failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + ensure_rng_init(stream); + + constexpr int kThreads = 256; + int blocks = static_cast((numel + kThreads - 1) / kThreads); + + // Single atomicAdd per launch — grabs `numel` consecutive counter slots + // for the kernel below, eliminating per-element contention on the GPU + // counter. + advance_counter_kernel<<<1, 1, 0, stream>>>( + d_rng, static_cast(numel)); + + if (scalar_type == ScalarType::Float) { + philox_rand_float_graph_kernel<<>>( + static_cast((*ret0)->data_ptr()), numel, d_rng); + } else if (scalar_type == ScalarType::BFloat16) { + philox_rand_bf16_graph_kernel<<>>( + static_cast((*ret0)->data_ptr()), numel, d_rng); + } else { + ET_LOG( + Error, + "aoti_torch_cuda_rand: unsupported dtype %d", + static_cast(scalar_type)); + return Error::NotSupported; + } + + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_) { + ET_CHECK_OR_RETURN_ERROR( + out != nullptr, + InvalidArgument, + "aoti_torch_cuda_randint_low_out: out tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + high > low, + InvalidArgument, + "aoti_torch_cuda_randint_low_out: requires high > low"); + + int64_t numel = 1; + for (int64_t i = 0; i < size_len_; i++) { + numel *= size[i]; + } + if (numel == 0) { + return Error::Ok; + } + + // Get the current CUDA stream. + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_OR_RETURN_ERROR( + stream_result.ok(), + Internal, + "aoti_torch_cuda_randint_low_out: failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + ensure_rng_init(stream); + + int64_t range = high - low; + int64_t* out_data = static_cast(out->data_ptr()); + + constexpr int kThreads = 256; + int blocks = static_cast((numel + kThreads - 1) / kThreads); + // One atomicAdd per launch; subsequent kernel reads `rng->base_scratch`. + advance_counter_kernel<<<1, 1, 0, stream>>>( + d_rng, static_cast(numel)); + philox_randint_graph_kernel<<>>( + out_data, numel, low, range, d_rng); + + return Error::Ok; +} + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/rand.h b/backends/cuda/runtime/shims/rand.h new file mode 100644 index 00000000000..e0c63be75a1 --- /dev/null +++ b/backends/cuda/runtime/shims/rand.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using AOTITorchError = Error; + +using SlimTensor = executorch::backends::aoti::slim::SlimTensor; + +extern "C" { + +/** + * Generates a tensor filled with uniform random values in [0, 1). + * + * Implements the AOTI shim for aten::rand.default on CUDA. Uses cuRAND + * Philox counter-based RNG with GPU-resident state. The counter is + * atomically advanced by each kernel invocation on-device, making it + * fully compatible with CUDA graph capture and replay — each replay + * produces different values because the counter increments on the GPU. + * + * Supports float32 and bfloat16 output dtypes. 
+ */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_rand( + const int64_t* size, + int64_t size_len_, + int32_t* dtype, + int32_t* layout, + int32_t* device, + int32_t device_index_, + int32_t* pin_memory, + SlimTensor** ret0); + +/** + * Fills a pre-allocated int64 tensor with random integers in [low, high). + * + * Implements the AOTI shim for aten::randint.low_out on CUDA. Used by + * Inductor's Philox RNG to generate random seeds. Each thread atomically + * advances a GPU-resident counter for unique offsets, making this fully + * compatible with CUDA graph capture and replay. + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out( + SlimTensor* out, + int64_t low, + int64_t high, + const int64_t* size, + int64_t size_len_); + +} // extern "C" + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index c80db9f523d..aec5219d680 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -42,6 +42,7 @@ set(CUDA_SHIM_TESTS test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor test_aoti_torch_copy_ + test_aoti_torch_cuda_rand test_aoti_torch_new_tensor_handle test_aoti_torch_item_bool test_aoti_torch_assign_tensors_out diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index cf8bc5c93e5..b68043f7feb 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -38,6 +38,7 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") + cuda_shim_cpp_unittest("aoti_torch_cuda_rand") cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") cuda_shim_cpp_unittest("aoti_torch_item_bool") cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_rand.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_rand.cpp new file mode 100644 index 00000000000..b776bb39497 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_rand.cpp @@ -0,0 +1,382 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Use explicit types to avoid ambiguity between different Tensor typedefs +using executorch::backends::cuda::aoti_torch_cuda_rand; +using executorch::backends::cuda::aoti_torch_cuda_randint_low_out; +using executorch::backends::cuda::aoti_torch_empty_strided; +using executorch::backends::cuda::AOTITorchError; +using executorch::runtime::Error; +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +// Tensor type definition using SlimTensor +using Tensor = executorch::backends::aoti::slim::SlimTensor; + +namespace { + +// Helper: convert raw bfloat16 bits (uint16_t) to float for value checks. 
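The helper just below only widens bf16 bits back to float32 for range checks; the narrowing direction lives in `philox_rand_bf16_graph_kernel`, which uses the `bits + 0x7FFF + lsb` round-to-nearest-even trick. A hedged standalone Python cross-check (not part of the diff; it assumes PyTorch's own float32→bfloat16 cast also rounds to nearest-even, which holds for the in-range values generated here):

```python
import torch

def bf16_bits_rne(x: torch.Tensor) -> torch.Tensor:
    """float32 -> bfloat16 bit pattern via the same trick as the CUDA kernel."""
    bits = x.contiguous().view(torch.int32)   # reinterpret the float32 bits
    lsb = (bits >> 16) & 1
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF

x = torch.rand(4096)                          # [0, 1), same range as the kernel
ref = x.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
assert torch.equal(bf16_bits_rne(x), ref)     # bit-for-bit identical rounding
```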
+float bfloat16_bits_to_float(uint16_t bits) { + uint32_t expanded = static_cast(bits) << 16; + float result; + std::memcpy(&result, &expanded, sizeof(float)); + return result; +} + +} // namespace + +// Test fixture for aoti_torch_cuda_rand tests +class AOTITorchCudaRandTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + } + + // Helper: create a pre-allocated Int64 (Long) CUDA tensor used as the + // `out` argument for randint_low_out. + Tensor* create_int64_tensor(const std::vector& sizes) { + Tensor* tensor = nullptr; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, // default contiguous strides + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CUDA), + 0, // device index + &tensor); + return (error == Error::Ok) ? tensor : nullptr; + } + + // Helper: copy a CUDA tensor's raw bytes back to host. + template + std::vector copy_to_host(const Tensor* tensor, int64_t numel) { + std::vector host(static_cast(numel)); + cudaError_t err = cudaMemcpy( + host.data(), + tensor->data_ptr(), + static_cast(numel) * sizeof(T), + cudaMemcpyDeviceToHost); + EXPECT_EQ(err, cudaSuccess) << "cudaMemcpy D2H failed"; + return host; + } +}; + +// ---------------------------------------------------------------------------- +// aoti_torch_cuda_rand tests +// ---------------------------------------------------------------------------- + +// Basic float32 rand: produces a tensor in [0, 1). +TEST_F(AOTITorchCudaRandTest, RandFloat32Basic) { + std::vector sizes = {4, 8}; + int64_t numel = 4 * 8; + int32_t dtype = static_cast(slim_c10::ScalarType::Float); + + Tensor* out = nullptr; + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + &out); + + ASSERT_EQ(error, Error::Ok) << "aoti_torch_cuda_rand should succeed"; + ASSERT_NE(out, nullptr); + EXPECT_EQ(out->dim(), 2); + EXPECT_EQ(out->size(0), 4); + EXPECT_EQ(out->size(1), 8); + EXPECT_EQ(out->numel(), numel); + ASSERT_NE(out->data_ptr(), nullptr); + + cudaDeviceSynchronize(); + auto host = copy_to_host(out, numel); + for (int64_t i = 0; i < numel; ++i) { + EXPECT_GE(host[i], 0.0f) << "value at " << i << " < 0"; + EXPECT_LT(host[i], 1.0f) << "value at " << i << " >= 1"; + } +} + +// Default dtype path: when dtype is null the shim defaults to float32. +TEST_F(AOTITorchCudaRandTest, RandDefaultDtypeIsFloat) { + std::vector sizes = {16}; + Tensor* out = nullptr; + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + /*dtype=*/nullptr, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + &out); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(out, nullptr); + EXPECT_EQ(out->dtype(), slim_c10::ScalarType::Float); + EXPECT_EQ(out->numel(), 16); +} + +// BFloat16 rand: values must lie in [0, 1). 
+TEST_F(AOTITorchCudaRandTest, RandBFloat16Basic) { + std::vector sizes = {32}; + int64_t numel = 32; + int32_t dtype = static_cast(slim_c10::ScalarType::BFloat16); + + Tensor* out = nullptr; + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + &out); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(out, nullptr); + EXPECT_EQ(out->dtype(), slim_c10::ScalarType::BFloat16); + EXPECT_EQ(out->numel(), numel); + + cudaDeviceSynchronize(); + auto host = copy_to_host(out, numel); + for (int64_t i = 0; i < numel; ++i) { + float v = bfloat16_bits_to_float(host[i]); + EXPECT_GE(v, 0.0f) << "bf16 value at " << i << " < 0"; + EXPECT_LT(v, 1.0f) << "bf16 value at " << i << " >= 1"; + } +} + +// Empty tensor: numel == 0 should be a no-op success. +TEST_F(AOTITorchCudaRandTest, RandEmptyTensor) { + std::vector sizes = {0, 4}; + int32_t dtype = static_cast(slim_c10::ScalarType::Float); + + Tensor* out = nullptr; + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + &out); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(out, nullptr); + EXPECT_EQ(out->numel(), 0); +} + +// Unsupported dtype should return NotSupported. +TEST_F(AOTITorchCudaRandTest, RandUnsupportedDtypeFails) { + std::vector sizes = {8}; + // Long is not supported by aoti_torch_cuda_rand. + int32_t dtype = static_cast(slim_c10::ScalarType::Long); + + Tensor* out = nullptr; + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + &out); + + EXPECT_EQ(error, Error::NotSupported); +} + +// Null ret0 should fail with InvalidArgument. +TEST_F(AOTITorchCudaRandTest, RandNullRet0Fails) { + std::vector sizes = {4}; + int32_t dtype = static_cast(slim_c10::ScalarType::Float); + + AOTITorchError error = aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + /*layout=*/nullptr, + /*device=*/nullptr, + /*device_index_=*/0, + /*pin_memory=*/nullptr, + /*ret0=*/nullptr); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Two invocations should advance the GPU-resident counter and produce +// different sequences (extremely high probability for non-trivial sizes). +TEST_F(AOTITorchCudaRandTest, RandTwoCallsProduceDifferentValues) { + std::vector sizes = {64}; + int64_t numel = 64; + int32_t dtype = static_cast(slim_c10::ScalarType::Float); + + Tensor* out_a = nullptr; + Tensor* out_b = nullptr; + + ASSERT_EQ( + aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + nullptr, + nullptr, + 0, + nullptr, + &out_a), + Error::Ok); + ASSERT_EQ( + aoti_torch_cuda_rand( + sizes.data(), + static_cast(sizes.size()), + &dtype, + nullptr, + nullptr, + 0, + nullptr, + &out_b), + Error::Ok); + + cudaDeviceSynchronize(); + auto host_a = copy_to_host(out_a, numel); + auto host_b = copy_to_host(out_b, numel); + + // The two draws must not be bit-for-bit identical. 
+ bool any_diff = false; + for (int64_t i = 0; i < numel; ++i) { + if (host_a[i] != host_b[i]) { + any_diff = true; + break; + } + } + EXPECT_TRUE(any_diff) + << "two consecutive aoti_torch_cuda_rand calls produced identical values"; +} + +// ---------------------------------------------------------------------------- +// aoti_torch_cuda_randint_low_out tests +// ---------------------------------------------------------------------------- + +// Basic randint into a pre-allocated int64 tensor; values lie in [low, high). +TEST_F(AOTITorchCudaRandTest, RandintBasicRange) { + std::vector sizes = {32}; + int64_t numel = 32; + int64_t low = -5; + int64_t high = 17; + + Tensor* out = create_int64_tensor(sizes); + ASSERT_NE(out, nullptr); + + AOTITorchError error = aoti_torch_cuda_randint_low_out( + out, low, high, sizes.data(), static_cast(sizes.size())); + ASSERT_EQ(error, Error::Ok); + + cudaDeviceSynchronize(); + auto host = copy_to_host(out, numel); + for (int64_t i = 0; i < numel; ++i) { + EXPECT_GE(host[i], low); + EXPECT_LT(host[i], high); + } +} + +// Empty out tensor: numel == 0 should be a no-op success. +TEST_F(AOTITorchCudaRandTest, RandintEmptyTensor) { + std::vector sizes = {0}; + Tensor* out = create_int64_tensor(sizes); + ASSERT_NE(out, nullptr); + + AOTITorchError error = aoti_torch_cuda_randint_low_out( + out, /*low=*/0, /*high=*/10, sizes.data(), 1); + EXPECT_EQ(error, Error::Ok); +} + +// Null `out` tensor must return InvalidArgument. +TEST_F(AOTITorchCudaRandTest, RandintNullOutFails) { + std::vector sizes = {4}; + AOTITorchError error = aoti_torch_cuda_randint_low_out( + /*out=*/nullptr, 0, 10, sizes.data(), 1); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// `high <= low` must return InvalidArgument. +TEST_F(AOTITorchCudaRandTest, RandintInvalidRangeFails) { + std::vector sizes = {4}; + Tensor* out = create_int64_tensor(sizes); + ASSERT_NE(out, nullptr); + + // high == low + EXPECT_EQ( + aoti_torch_cuda_randint_low_out( + out, /*low=*/3, /*high=*/3, sizes.data(), 1), + Error::InvalidArgument); + // high < low + EXPECT_EQ( + aoti_torch_cuda_randint_low_out( + out, /*low=*/5, /*high=*/2, sizes.data(), 1), + Error::InvalidArgument); +} + +// Calling randint twice should produce different sequences via the on-device +// counter advance. With numel=1 (the typical Inductor seed-gen pattern) we +// run a few iterations to make collision extremely unlikely. +TEST_F(AOTITorchCudaRandTest, RandintAdvancesCounter) { + std::vector sizes = {1}; + Tensor* out = create_int64_tensor(sizes); + ASSERT_NE(out, nullptr); + + constexpr int kIters = 8; + std::vector draws; + draws.reserve(kIters); + for (int i = 0; i < kIters; ++i) { + AOTITorchError error = aoti_torch_cuda_randint_low_out( + out, + /*low=*/0, + /*high=*/std::numeric_limits::max(), + sizes.data(), + 1); + ASSERT_EQ(error, Error::Ok); + cudaDeviceSynchronize(); + auto host = copy_to_host(out, 1); + draws.push_back(host[0]); + } + + // Not all draws should be equal. 
+ bool any_diff = false; + for (int i = 1; i < kIters; ++i) { + if (draws[i] != draws[0]) { + any_diff = true; + break; + } + } + EXPECT_TRUE(any_diff) << "randint counter did not advance across calls"; +} diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index f4feb6fae10..780bc420878 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -769,10 +769,11 @@ def _export_cuda(model, config, args): print("Exporting decode method...") decode_tokens = torch.tensor([[0]], dtype=torch.long) decode_pos = torch.tensor([0], dtype=torch.long) + decode_temperature = torch.tensor([1.0], dtype=torch.float32) with torch.no_grad(): decode_ep = export( model, - (decode_tokens, decode_pos), + (decode_tokens, decode_pos, decode_temperature), strict=True, ) print("Decode export successful!") @@ -784,18 +785,21 @@ def _export_cuda(model, config, args): # that reject longer prompts at runtime. _set_batched_moe(model, True) print("Exporting prefill method...") + example_prefill_len = config.max_seq_len - 1 prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long) prefill_pos = torch.arange(example_prefill_len, dtype=torch.long) + prefill_temperature = torch.tensor([1.0], dtype=torch.float32) seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) prefill_dynamic_shapes = ( {1: seq_dim}, # tokens {0: seq_dim}, # input_pos + None, # temperature (static scalar tensor) ) with torch.no_grad(): prefill_ep = export( model, - (prefill_tokens, prefill_pos), + (prefill_tokens, prefill_pos, prefill_temperature), dynamic_shapes=prefill_dynamic_shapes, strict=True, ) diff --git a/examples/models/qwen3_5_moe/inference.py b/examples/models/qwen3_5_moe/inference.py index c824f6a6444..6f9aa933535 100644 --- a/examples/models/qwen3_5_moe/inference.py +++ b/examples/models/qwen3_5_moe/inference.py @@ -77,46 +77,45 @@ def generate( Prefills one token at a time (the recurrent path; chunked FLA via @triton_op is used for T>1 prefill in the exported PTE). + + The model performs Gumbel-max sampling on-device: forward() returns + a sampled token ID [B, 1] instead of logits [B, T, V]. 
""" if eos_token_ids is None: eos_token_ids = set() input_ids = tokenizer.encode(prompt).ids + # Temperature tensor (use small epsilon for greedy to avoid div-by-zero) + temp_val = max(temperature, 1e-6) + temp_tensor = torch.tensor([temp_val], dtype=torch.float32, device="cuda") + # Prefill: one token at a time with torch.no_grad(): for i, tok_id in enumerate(input_ids): tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") pos = torch.tensor([i], dtype=torch.long, device="cuda") - logits = model(tok, pos) + sampled = model(tok, pos, temp_tensor) - # Sample first generated token - next_token = _sample(logits[:, -1, :], temperature) - generated = [next_token.item()] + # First generated token (model returns [B, 1] float token ID) + next_token_id = int(sampled.item()) + generated = [next_token_id] # Decode: one token at a time seq_len = len(input_ids) with torch.no_grad(): for i in range(max_new_tokens - 1): - pos = torch.tensor([seq_len + i], device="cuda") - logits = model(next_token.unsqueeze(0), pos) - next_token = _sample(logits[:, -1, :], temperature) - tok_id = next_token.item() - generated.append(tok_id) - if tok_id in eos_token_ids: + tok = torch.tensor([[next_token_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([seq_len + i], dtype=torch.long, device="cuda") + sampled = model(tok, pos, temp_tensor) + next_token_id = int(sampled.item()) + generated.append(next_token_id) + if next_token_id in eos_token_ids: break return tokenizer.decode(generated) -def _sample(logits, temperature): - """Sample from logits with temperature.""" - if temperature <= 0: - return logits.argmax(dim=-1) - probs = torch.softmax(logits / temperature, dim=-1) - return torch.multinomial(probs, num_samples=1).squeeze(-1) - - def main(): parser = argparse.ArgumentParser( description="Run inference on prequantized Qwen3.5 MoE" diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 89ee58baa4b..434f6764983 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include @@ -47,6 +47,33 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; +// Read a sampled token from the model output tensor [B, 1]. +// The model performs Gumbel-max sampling on-device and returns a single +// float token ID. This function copies it from GPU and casts to uint64. +static uint64_t read_token(const executorch::aten::Tensor& output) { + const void* ptr = output.const_data_ptr(); + + cudaPointerAttributes attrs; + bool on_device = cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && + attrs.type == cudaMemoryTypeDevice; + + float val; + if (on_device) { + cudaError_t err = + cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + ET_LOG( + Error, + "read_token: cudaMemcpy D2H failed: %s", + cudaGetErrorString(err)); + return 0; + } + } else { + memcpy(&val, ptr, sizeof(float)); + } + return static_cast(val); +} + int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -80,7 +107,7 @@ int main(int argc, char** argv) { return 1; } - // Create Module with share_memory_arenas=true so prefill and forward + // Create Module with share_memory_arenas=true so prefill and decode // share mutable buffers (KV cache, conv_state, recurrent_state). 
std::vector data_files; if (!FLAGS_data_path.empty()) { @@ -160,10 +187,20 @@ int main(int argc, char** argv) { stats.inference_start_ms = llm::time_in_ms(); // --------------------------------------------------------------- - // Prefill or decode-only + // Sampling tensors (shared between prefill and decode) // --------------------------------------------------------------- auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + // Use a very small temperature for greedy to avoid division by zero + // while keeping the Gumbel noise negligible relative to logit differences. + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); + + // --------------------------------------------------------------- + // Prefill + // --------------------------------------------------------------- uint64_t cur_token = 0; // Use prefill method for T>=2, decode method for T=1 @@ -190,6 +227,7 @@ int main(int argc, char** argv) { std::vector prefill_inputs; prefill_inputs.push_back(tokens_tensor); prefill_inputs.push_back(pos_tensor); + prefill_inputs.push_back(temp_tensor); auto prefill_result = module->execute(run_method, prefill_inputs); if (prefill_result.error() != Error::Ok) { @@ -198,10 +236,7 @@ int main(int argc, char** argv) { } auto& prefill_outputs = prefill_result.get(); - auto logits_tensor = prefill_outputs[0].toTensor(); - auto logits_ptr = - std::make_shared(std::move(logits_tensor)); - cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature); + cur_token = read_token(prefill_outputs[0].toTensor()); stats.prompt_eval_end_ms = llm::time_in_ms(); @@ -240,6 +275,7 @@ int main(int argc, char** argv) { std::vector decode_inputs; decode_inputs.push_back(EValue(decode_tokens)); decode_inputs.push_back(EValue(decode_pos)); + decode_inputs.push_back(EValue(temp_tensor)); auto decode_result = module->execute("decode", decode_inputs); if (decode_result.error() != Error::Ok) { @@ -248,14 +284,8 @@ int main(int argc, char** argv) { } auto& decode_outputs = decode_result.get(); - auto step_logits = decode_outputs[0].toTensor(); - auto step_logits_ptr = - std::make_shared(std::move(step_logits)); - prev_token = cur_token; - stats.on_sampling_begin(); - cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature); - stats.on_sampling_end(); + cur_token = read_token(decode_outputs[0].toTensor()); if (step == 0) { stats.first_token_ms = llm::time_in_ms(); diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 261008b43e2..d8186e24d7a 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -17,10 +17,12 @@ import os import re from dataclasses import dataclass, field +from typing import Optional import torch import torch.nn as nn +from executorch.examples.models.qwen3_5_moe.sampler import sample from torch.nn import functional as F @@ -608,6 +610,10 @@ def forward(self, x, input_pos): return x +# --------------------------------------------------------------------------- +# Top-level model + + class Qwen35MoE(nn.Module): def __init__(self, config): @@ -621,13 +627,30 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def forward( - self, tokens: torch.LongTensor, input_pos: torch.LongTensor + self, + tokens: torch.LongTensor, + input_pos: torch.LongTensor, + temperature: Optional[torch.Tensor] = None, ) -> torch.Tensor: x = 
self.embed_tokens(tokens) for layer in self.layers: x = layer(x, input_pos) x = self.norm(x) - return self.lm_head(x) + # When no sampling is requested, return the full ``[B, T, V]`` + # logits so callers (eval, custom samplers) can inspect every + # position. Otherwise apply the prefill optimization and only + # materialize ``[B, V]`` for the last token. + if temperature is None: + return self.lm_head(x).float() # [B, T, V] float32 + logits = self.lm_head(x[:, -1, :]).float() # [B, V] float32 + # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is + # equivalent to drawing from softmax(logits/T) but stays entirely + # on-device. Algorithm reference: + # https://huggingface.co/blog/cxdu/fastsampling + # TODO(gasoonjia): once the on-device sampling stack lands, promote + # ``sample`` into a shared CUDA sampling utility reusable by other + # models, and add top-k / top-p filtering support. + return sample(logits, temperature) # [B, 1] @staticmethod def from_hf_checkpoint(model_dir, max_seq_len=4096): diff --git a/examples/models/qwen3_5_moe/sampler.py b/examples/models/qwen3_5_moe/sampler.py new file mode 100644 index 00000000000..dd1cec5bcd6 --- /dev/null +++ b/examples/models/qwen3_5_moe/sampler.py @@ -0,0 +1,60 @@ +""" +GPU-side Gumbel-max sampler. + +Self-contained sampling utility that can be imported by other models. Lives +in its own file so it can be reused without pulling in the heavy MoE module. + +``temperature`` is a runtime tensor so a single exported program can be +re-driven with different sampling configurations without re-export. + +""" + +from typing import Optional + +import torch + + +def sample( + logits: torch.Tensor, + temperature: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """GPU-side Gumbel-max sampler. + + When ``temperature`` is ``None`` (the eager / eval default) the function + is a no-op and returns ``logits`` unchanged — useful for callers that + just want to inspect raw logits. + + Otherwise it draws from ``softmax(logits / temperature)`` entirely + on-device using the Gumbel-max trick: + ``argmax(logits / T + gumbel_noise)`` + (reference: https://huggingface.co/blog/cxdu/fastsampling). + + NOTE: the ``1e-20`` epsilons used in the Gumbel transform assume + float32 logits. The contract is documented as ``[B, V]`` float32 and + callers are expected to ``.float()``-cast before invoking ``sample``. + + TODO(gasoonjia): add top-k / top-p filtering support in a follow-up PR. + + Args: + logits: ``[B, V]`` float32 logits. + temperature: 0-D or 1-D float tensor (clamped to >= 1e-6 to avoid + divide-by-zero). ``None`` skips temperature scaling and the + sampler returns the unmodified ``logits`` tensor. + + Returns: + ``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified + ``logits`` tensor when ``temperature`` is ``None``. + """ + # No sampling configured — return raw logits. + if temperature is None: + return logits + + logits = logits / temperature.clamp(min=1e-6) + + # Gumbel-max sampling — equivalent to sampling from softmax(logits) + # but fully on-device and CUDA-graph friendly. The 1e-20 epsilons are + # safe for float32 (smallest positive normal ~1.18e-38) — see the + # float32 note in the docstring. 
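+    # Concretely (illustration, assuming standard PyTorch dtype semantics):
+    #   torch.tensor(1e-20, dtype=torch.float16) == 0   # eps underflows to zero in fp16
+    #   torch.tensor(1e-20, dtype=torch.float32) > 0    # still representable in fp32
+    # so half-precision noise of exactly 0.0 would hit log(0) = -inf below.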
+ noise = torch.rand_like(logits) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + return (logits + gumbel).argmax(dim=-1, keepdim=True).float() diff --git a/examples/models/qwen3_5_moe/test_sampler.py b/examples/models/qwen3_5_moe/test_sampler.py new file mode 100644 index 00000000000..260c3029525 --- /dev/null +++ b/examples/models/qwen3_5_moe/test_sampler.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the standalone ``sample`` function in +``examples/models/qwen3_5_moe/sampler.py``. + +``temperature`` is a runtime scalar tensor so the same exported graph can +be re-driven with different sampling configurations without re-export. + +NOTE: top-k / top-p tests are intentionally omitted — that support is +deferred to a follow-up PR. + +Usage: + python -m pytest examples/models/qwen3_5_moe/test_sampler.py -v +""" + +import unittest + +import torch + +from executorch.examples.models.qwen3_5_moe.sampler import sample + + +def _temp(value: float = 1.0) -> torch.Tensor: + return torch.tensor([value], dtype=torch.float32) + + +class TestSampler(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + # ------------------------------------------------------------------ + # No-op path: when temperature is None the function returns the + # input logits unchanged. + # ------------------------------------------------------------------ + def test_temperature_none_returns_logits(self): + logits = torch.randn(2, 8) + out = sample(logits) + self.assertIs(out, logits) + self.assertEqual(out.shape, (2, 8)) + + # ------------------------------------------------------------------ + # Output shape / dtype contract when sampling is enabled. + # ------------------------------------------------------------------ + def test_output_shape_and_dtype(self): + logits = torch.randn(3, 17) + out = sample(logits, temperature=_temp(1.0)) + self.assertEqual(out.shape, (3, 1)) + self.assertEqual(out.dtype, torch.float32) + self.assertTrue(torch.all(out >= 0)) + self.assertTrue(torch.all(out < logits.size(-1))) + + # ------------------------------------------------------------------ + # Sampling with temperature matches the inline Gumbel-max sampler + # bit-for-bit. + # ------------------------------------------------------------------ + def test_temperature_matches_legacy_gumbel(self): + logits = torch.randn(2, 32) + temperature = _temp(0.8) + + torch.manual_seed(123) + scaled = logits / temperature.clamp(min=1e-6) + noise = torch.rand_like(scaled) + gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) + expected = (scaled + gumbel).argmax(dim=-1, keepdim=True).float() + + torch.manual_seed(123) + actual = sample(logits, temperature=temperature) + self.assertTrue(torch.equal(actual, expected)) + + # ------------------------------------------------------------------ + # Low temperature → near-greedy sampling. 
+ # ------------------------------------------------------------------ + def test_low_temperature_is_near_greedy(self): + torch.manual_seed(0) + logits = torch.randn(1, 32) * 2.0 + argmax_id = int(logits.argmax(dim=-1).item()) + + agree = 0 + trials = 100 + for trial in range(trials): + torch.manual_seed(trial + 500) + tok = int(sample(logits, temperature=_temp(1e-4)).item()) + if tok == argmax_id: + agree += 1 + self.assertGreaterEqual(agree, int(trials * 0.9)) + + # ------------------------------------------------------------------ + # Empirical distribution check: top-1 should be the most frequent. + # ------------------------------------------------------------------ + def test_distribution_peaks_at_argmax(self): + torch.manual_seed(0) + logits = torch.tensor([[3.0, 1.0, 0.5, -1.0, 0.0]]) + argmax_id = int(logits.argmax(dim=-1).item()) + + counts = torch.zeros(logits.size(-1), dtype=torch.long) + trials = 2000 + for trial in range(trials): + torch.manual_seed(trial) + tok = int(sample(logits, temperature=_temp(1.0)).item()) + counts[tok] += 1 + self.assertEqual(int(counts.argmax().item()), argmax_id) + + # ------------------------------------------------------------------ + # Runtime control: changing temperature between calls produces + # different draws without re-creating the graph. + # ------------------------------------------------------------------ + def test_temperature_runtime_controllable(self): + torch.manual_seed(0) + logits = torch.randn(1, 64) * 4.0 + argmax_id = int(logits.argmax(dim=-1).item()) + + torch.manual_seed(7) + cold = int(sample(logits, temperature=_temp(1e-4)).item()) + torch.manual_seed(7) + hot = int(sample(logits, temperature=_temp(5.0)).item()) + + # Cold sampling should hit argmax; hot sampling is unconstrained + # but is still a valid token id. + self.assertEqual(cold, argmax_id) + self.assertGreaterEqual(hot, 0) + self.assertLess(hot, logits.size(-1)) + + # ------------------------------------------------------------------ + # CUDA smoke test. + # ------------------------------------------------------------------ + def test_runs_on_cuda(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + logits = torch.randn(2, 64, device="cuda") + temperature = torch.tensor([0.9], dtype=torch.float32, device="cuda") + out = sample(logits, temperature=temperature) + self.assertEqual(out.device.type, "cuda") + self.assertEqual(out.shape, (2, 1)) + + +if __name__ == "__main__": + unittest.main()
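For completeness, a standalone sketch (not part of this PR) that checks the identity the sampler and these tests rely on: `argmax(logits / T + Gumbel(0, 1))` draws a token with probability `softmax(logits / T)`. The tolerance below is a Monte-Carlo bound, not an exact equality:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.0, -1.0, 0.5])
temperature = torch.tensor([0.7])
probs = torch.softmax(logits / temperature, dim=-1)

n = 200_000
noise = torch.rand(n, logits.numel())
gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
draws = (logits / temperature + gumbel).argmax(dim=-1)

freq = torch.bincount(draws, minlength=logits.numel()).float() / n
assert torch.allclose(freq, probs, atol=0.01)   # empirical frequencies ≈ softmax
```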