diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 5cee37b19cf..3a7fe02af0c 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -397,6 +397,27 @@ if [ -n "$EXPECTED_OUTPUT" ]; then else echo "SUCCESS: Runner completed successfully" fi + +# Validate GPU peak memory usage for models with known memory budgets. +# The runner prints "GPU peak memory usage: XXXX.X MiB" at the end. +case "$MODEL_NAME" in + qwen3_5_moe) + MAX_MEMORY_MIB=20480 # 20 GB — must fit on a single GPU (e.g. 4090) + PEAK_MEM=$(echo "$OUTPUT" | grep -oP 'GPU peak memory usage: \K[0-9.]+' || true) + if [ -n "$PEAK_MEM" ]; then + # Compare as integers (truncate decimals) + PEAK_MEM_INT=${PEAK_MEM%%.*} + if [ "$PEAK_MEM_INT" -gt "$MAX_MEMORY_MIB" ]; then + echo "FAIL: GPU peak memory ${PEAK_MEM} MiB exceeds budget ${MAX_MEMORY_MIB} MiB" + exit 1 + else + echo "Success: GPU peak memory ${PEAK_MEM} MiB within budget (max ${MAX_MEMORY_MIB} MiB)" + fi + else + echo "WARNING: GPU peak memory usage not found in output" + fi + ;; +esac echo "::endgroup::" popd diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index eb0a07b8d8f..0e9c5a32473 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -80,7 +80,6 @@ namespace { constexpr char kSkipCopyOutputToCpuForMethod[] = "skip_copy_output_to_cpu_for_method"; constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream"; -constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods"; } // anonymous namespace class ET_EXPERIMENTAL CudaBackend final @@ -288,17 +287,11 @@ class ET_EXPERIMENTAL CudaBackend final ArrayRef compile_specs // This will be my empty list ) const override { std::string method_name; - bool share_kv_cache = false; for (const CompileSpec& spec : compile_specs) { if (std::strcmp(spec.key, "method_name") == 0) { method_name.assign( static_cast(spec.value.buffer), spec.value.nbytes); 
// no nullptr guarantee, so pass size - } else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) { - if (spec.value.nbytes >= 1) { - share_kv_cache = - static_cast(spec.value.buffer)[0] != 0; - } } } @@ -369,30 +362,11 @@ class ET_EXPERIMENTAL CudaBackend final handle->container_handle = container_handle; - // Look into named data map for constant data - std::string weights_blob_key = - method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; - auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); - if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { - ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); - const void* weights_blob = buffer_res->data(); - // Feed the weights blob into the container. Under the hood it's copying - // weights, so we should free the buffer immediately. - auto update_err = handle->update_constants_from_blob( - handle->container_handle, static_cast(weights_blob)); - if (update_err != Error::Ok) { - ET_LOG(Error, "update_constants_from_blob failed"); - return update_err; - } - // Ensure all weight transfers are complete before execution - cudaDeviceSynchronize(); - buffer_res->Free(); - } else { - ET_LOG( - Info, - "weights_blob '%s' not found or update fn is null", - weights_blob_key.c_str()); - } + // Load constants with per-weight caching. + // This replaces the old update_constants_from_blob + cross-method sharing + // with a unified approach that avoids duplicate GPU allocations. + ET_CHECK_OK_OR_RETURN_ERROR( + load_constants_with_cache(handle, named_data_map, method_name)); // Use shared CUDA stream if enabled via options, otherwise create one. // A shared stream ensures proper ordering across multiple methods @@ -419,119 +393,6 @@ class ET_EXPERIMENTAL CudaBackend final method_name.c_str()); } - // --------------------------------------------------------------- - // Cross-method constant sharing (e.g., KV cache between prefill/decode). 
- // - // Only enabled when share_kv_cache_across_methods compile spec is set. - // The first container to initialize extracts its constants (keyed by - // original FQN) and stores the AtenTensorHandle's. Subsequent containers - // with matching FQNs are updated to point to the same GPU tensors via - // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy, - // the source container retains ownership). - // --------------------------------------------------------------- - if (share_kv_cache && handle->get_num_constants && - handle->get_constant_name && handle->get_constant_original_fqn && - handle->extract_constants_map && - handle->update_user_managed_constant_buffer_pairs) { - size_t num_constants = 0; - handle->get_num_constants(handle->container_handle, &num_constants); - - if (num_constants > 0) { - // Build FQN → internal_name mapping for this container. - std::unordered_map fqn_to_name; - for (size_t i = 0; i < num_constants; i++) { - const char* name = nullptr; - const char* fqn = nullptr; - handle->get_constant_name(handle->container_handle, i, &name); - handle->get_constant_original_fqn(handle->container_handle, i, &fqn); - if (name && fqn && fqn[0] != '\0') { - fqn_to_name[fqn] = name; - } - } - - std::lock_guard guard(shared_constants_mutex_); - - if (!constants_extracted_) { - // First container: extract its constants and store by FQN. 
- std::unordered_map extracted_map; - auto extract_err = handle->extract_constants_map( - handle->container_handle, - reinterpret_cast(&extracted_map), - /*use_inactive=*/false); - - if (extract_err == Error::Ok) { - for (const auto& [fqn, internal_name] : fqn_to_name) { - auto it = extracted_map.find(fqn); - if (it != extracted_map.end()) { - shared_constant_tensors_[fqn] = it->second; - } - } - constants_extracted_ = true; - ET_LOG( - Info, - "Extracted %zu shared constants from method '%s'", - shared_constant_tensors_.size(), - method_name.c_str()); - } else { - ET_LOG( - Error, - "Failed to extract constants from '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } - } else { - // Subsequent container: share matching constants from the first. - std::vector pairs; - for (const auto& [fqn, internal_name] : fqn_to_name) { - auto it = shared_constant_tensors_.find(fqn); - if (it != shared_constant_tensors_.end()) { - // UpdateUserManagedConstantBufferPairs matches against the - // codegen constant name (underscored), not the original FQN. 
- pairs.push_back({internal_name.c_str(), it->second}); - } - } - - if (!pairs.empty()) { - auto update_err = handle->update_user_managed_constant_buffer_pairs( - handle->container_handle, - pairs.data(), - pairs.size(), - /*use_inactive=*/false, - /*validate_full_update=*/false); - - if (update_err == Error::Ok) { - ET_LOG( - Info, - "Shared %zu constants into method '%s'", - pairs.size(), - method_name.c_str()); - } else { - ET_LOG( - Error, - "Failed to share constants into '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } - } - } - } - } else if (share_kv_cache) { - ET_LOG( - Error, - "share_kv_cache_across_methods requested but constant sharing APIs " - "not available for method '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } else { - ET_LOG( - Info, - "Constant sharing not requested for method '%s'", - method_name.c_str()); - } - return (DelegateHandle*)handle; // Return the handle post-processing } @@ -778,11 +639,220 @@ class ET_EXPERIMENTAL CudaBackend final unordered_map> cached_outputs_; - // Cross-method constant sharing state. - // When multiple AOTI containers share mutable buffers (e.g., KV cache), - // the first container's constants are extracted and stored here. Subsequent - // containers with matching FQNs share the same GPU tensors via - // UpdateUserManagedConstantBufferPairs. + // --------------------------------------------------------------- + // Per-weight constant cache. + // + // Maintains a singleton FQN → AtenTensorHandle cache across methods. + // When loading constants for a method, constants already in the cache + // are reused (zero-copy via update_user_managed_constant_buffer_pairs). + // Only constants not in the cache are loaded from the blob and added + // to the cache. This avoids duplicate GPU allocations when multiple + // methods (e.g., prefill/decode) share the same weights. 
+ // --------------------------------------------------------------- + + // Load constants for a method using per-weight caching. + // Returns Error::Ok on success. + // + // Flow: + // 1. Enumerate this method's constants and their FQNs. + // 2. For each constant: + // - If FQN is in shared_constant_tensors_ → reuse (cache hit). + // - Otherwise → mark as needing loading (cache miss). + // 3. If all constants are cached → skip blob loading entirely. + // Otherwise → call update_constants_from_blob to load all, then + // extract and cache the new constants. + // 4. For cached constants, call update_user_managed_constant_buffer_pairs + // to point the container to the shared GPU tensors. + Error load_constants_with_cache( + cuda::CudaDelegateHandle* handle, + const NamedDataMap* named_data_map, + const std::string& method_name) const { + // Check if the required APIs are available + if (!handle->get_num_constants || !handle->get_constant_name || + !handle->get_constant_original_fqn || !handle->extract_constants_map || + !handle->update_user_managed_constant_buffer_pairs) { + // Fall back to the legacy path + return load_constants_legacy(handle, named_data_map, method_name); + } + + // Step 1: Enumerate constants and partition into cached/uncached + size_t num_constants = 0; + handle->get_num_constants(handle->container_handle, &num_constants); + if (num_constants == 0) { + ET_LOG(Info, "No constants for method '%s'", method_name.c_str()); + return Error::Ok; + } + + // Build FQN → internal_name mapping and determine cache hits/misses + std::unordered_map fqn_to_name; + std::vector uncached_fqns; + + { + std::lock_guard guard(shared_constants_mutex_); + for (size_t i = 0; i < num_constants; i++) { + const char* name = nullptr; + const char* fqn = nullptr; + handle->get_constant_name(handle->container_handle, i, &name); + handle->get_constant_original_fqn(handle->container_handle, i, &fqn); + if (name && fqn && fqn[0] != '\0') { + fqn_to_name[fqn] = name; + if 
(shared_constant_tensors_.find(fqn) == + shared_constant_tensors_.end()) { + uncached_fqns.push_back(fqn); + } + } + } + } + + size_t num_cached = fqn_to_name.size() - uncached_fqns.size(); + ET_LOG( + Info, + "Method '%s': %zu constants, %zu cached, %zu uncached", + method_name.c_str(), + fqn_to_name.size(), + num_cached, + uncached_fqns.size()); + + // Step 2: Load uncached constants from blob (if any) + if (!uncached_fqns.empty()) { + // Need to load from blob — use update_constants_from_blob for all, + // then extract the new constants into the cache. NOTE(review): the cache-miss scan above released shared_constants_mutex_, so two methods initializing concurrently could both miss and both load their blobs (duplicate transient GPU allocations) — confirm backend init() calls are serialized. + std::string weights_blob_key = + method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + + if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { + ET_LOG( + Info, + "Loading constants from blob '%s' for method '%s'", + weights_blob_key.c_str(), + method_name.c_str()); + const void* weights_blob = buffer_res->data(); + auto update_err = handle->update_constants_from_blob( + handle->container_handle, + static_cast(weights_blob)); + if (update_err != Error::Ok) { + ET_LOG(Error, "update_constants_from_blob failed"); + return update_err; + } + cudaDeviceSynchronize(); + buffer_res->Free(); + } else { + ET_LOG( + Error, + "weights_blob '%s' not found or update fn is null", + weights_blob_key.c_str()); + return Error::NotFound; + } + + // Extract all constants and cache the newly loaded ones + std::unordered_map extracted_map; + auto extract_err = handle->extract_constants_map( + handle->container_handle, + reinterpret_cast(&extracted_map), + /*use_inactive=*/false); + + if (extract_err == Error::Ok) { + std::lock_guard guard(shared_constants_mutex_); + for (const auto& fqn : uncached_fqns) { + auto it_name = fqn_to_name.find(fqn); + if (it_name == fqn_to_name.end()) + continue; + // extract_constants_map returns entries keyed by FQN + auto it = extracted_map.find(fqn); + if (it != extracted_map.end()) { + 
shared_constant_tensors_[fqn] = it->second; + } + } + ET_LOG( + Info, + "Cached %zu new constants from method '%s' (total cache: %zu)", + uncached_fqns.size(), + method_name.c_str(), + shared_constant_tensors_.size()); + } else { + ET_LOG( + Error, + "Failed to extract constants from '%s'", + method_name.c_str()); + return Error::Internal; + } + } else { + // All constants are cached — skip blob loading entirely! + ET_LOG( + Info, + "All %zu constants cached — skipping blob load for method '%s'", + fqn_to_name.size(), + method_name.c_str()); + } + + // Step 3: Point the container to cached tensors via user_managed pairs + if (num_cached > 0 || uncached_fqns.empty()) { + std::vector pairs; + { + std::lock_guard guard(shared_constants_mutex_); + for (const auto& [fqn, internal_name] : fqn_to_name) { + auto it = shared_constant_tensors_.find(fqn); + if (it != shared_constant_tensors_.end()) { + pairs.push_back({internal_name.c_str(), it->second}); + } + } + } + + if (!pairs.empty()) { + auto update_err = handle->update_user_managed_constant_buffer_pairs( + handle->container_handle, + pairs.data(), + pairs.size(), + /*use_inactive=*/false, + /*validate_full_update=*/false); + + if (update_err != Error::Ok) { + ET_LOG( + Error, + "Failed to set cached constants for method '%s'", + method_name.c_str()); + return Error::Internal; + } + ET_LOG( + Info, + "Shared %zu cached constants into method '%s'", + pairs.size(), + method_name.c_str()); + } + } + + return Error::Ok; + } + + // Legacy constant loading: load the entire blob without caching. + // Used as fallback when constant management APIs are unavailable. + Error load_constants_legacy( + cuda::CudaDelegateHandle* handle, + const NamedDataMap* named_data_map, + const std::string& method_name) const { + std::string weights_blob_key = + method_name.empty() ? 
"weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { + ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); + const void* weights_blob = buffer_res->data(); + auto update_err = handle->update_constants_from_blob( + handle->container_handle, static_cast(weights_blob)); + if (update_err != Error::Ok) { + ET_LOG(Error, "update_constants_from_blob failed"); + return update_err; + } + cudaDeviceSynchronize(); + buffer_res->Free(); + } else { + ET_LOG( + Info, + "weights_blob '%s' not found or update fn is null", + weights_blob_key.c_str()); + } + return Error::Ok; + } mutable std::mutex shared_constants_mutex_; // FQN → AtenTensorHandle from the source (first) container. @@ -790,9 +860,6 @@ class ET_EXPERIMENTAL CudaBackend final // explicitly deleted — see destroy() comment). mutable std::unordered_map shared_constant_tensors_; - - // Whether we've already extracted constants from a source container. - mutable bool constants_extracted_ = false; }; } // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 7f4e60596be..968113d7a49 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -249,5 +249,15 @@ int main(int argc, char** argv) { num_generated * 1000.0 / decode_ms); printf("Prompt tokens: %ld\n", num_prompt_tokens); + // Print peak GPU memory usage + { + size_t free_bytes = 0, total_bytes = 0; + if (cudaMemGetInfo(&free_bytes, &total_bytes) == cudaSuccess) { + double used_mb = + static_cast(total_bytes - free_bytes) / (1024.0 * 1024.0); + printf("GPU peak memory usage: %.1f MiB\n", used_mb); + } + } + return 0; }