Add CUDA graph capture/replay for Qwen 3.5 MoE decode method #18809
Conversation
The chunked FLA pipeline (6 Triton kernels) is overkill for T=1 decode. Replace it with plain PyTorch einsum ops that Inductor can fuse:
- FLA GPU time: 1.085 ms → 0.344 ms/step (-68%)
- Total GPU time: 12.0 ms → 9.0 ms/step (-25%)
- Export changed to static T=1 with enable_dynamic_shape=False
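For context, a minimal sketch of what a single-token (T=1) gated delta rule step looks like as plain einsum ops that Inductor can fuse; the shapes, names, and exact recurrence below are illustrative assumptions, not the code in this PR:

```python
import torch

def recurrent_step(q, k, v, g, beta, S):
    # q, k: [B, H, Dk], v: [B, H, Dv], g (log-decay), beta: [B, H], S: [B, H, Dk, Dv]
    S = S * g.exp()[..., None, None]                    # decay the running state
    v_new = v - torch.einsum("bhk,bhkv->bhv", k, S)     # delta-rule correction
    v_new = v_new * beta[..., None]
    S = S + torch.einsum("bhk,bhv->bhkv", k, v_new)     # rank-1 state update
    o = torch.einsum("bhk,bhkv->bhv", q, S)             # read-out for this token
    return o, S

# One decode step with toy shapes
B, H, Dk, Dv = 1, 4, 64, 128
q, k = torch.randn(B, H, Dk), torch.randn(B, H, Dk)
v, S = torch.randn(B, H, Dv), torch.zeros(B, H, Dk, Dv)
g, beta = -torch.rand(B, H), torch.rand(B, H)
o, S = recurrent_step(q, k, v, g, beta, S)
```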
Move the decode/prefill dispatch inside the chunk_gated_delta_rule triton_op instead of using torch.cond at the model level. This follows the same pattern as the SDPA triton_op (pow2/non-pow2 dispatch) and avoids torch.cond's incompatibility with AOTI's FunctionalTensor pipeline.

Changes:
- chunk_gated_delta_rule.py: add a fused recurrent Triton kernel for T=1, refactor the chunked pipeline into _launch_chunked(), and dispatch via a Python if inside the @triton_op wrapper
- model.py: remove torch.cond from GatedDeltaNet.forward() and call the triton_op directly (dispatch is internal)
- export.py: single-method export with a dynamic seq_len dim
- main.cpp: fix the create_text_llm_runner API signature
Only chunk_gated_delta_rule.py needs modification — dispatch logic is internal to the triton_op, no model/export/runner changes needed.
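A structural sketch of that internal dispatch, assuming a torch.library.triton_op wrapper; the op name, signature, and helper names are placeholders (only _launch_chunked is named in this PR), and the helper bodies are stubbed out:

```python
import torch
from torch.library import triton_op


def _launch_recurrent_t1(q, k, v, g, beta):
    raise NotImplementedError  # placeholder: would launch the fused recurrent kernel

def _launch_chunked(q, k, v, g, beta):
    raise NotImplementedError  # placeholder: would launch the chunked FLA pipeline

@triton_op("fla::chunk_gated_delta_rule", mutates_args={})
def chunk_gated_delta_rule(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    g: torch.Tensor, beta: torch.Tensor,
) -> torch.Tensor:
    # Seq-len branching lives inside the op body, so the exported graph sees a
    # single custom op and no torch.cond.
    if q.shape[2] == 1:
        return _launch_recurrent_t1(q, k, v, g, beta)   # T == 1: decode
    return _launch_chunked(q, k, v, g, beta)            # T >= 2: prefill
```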
- test_recurrent_t1: verify the T=1 recurrent kernel against the FLA naive reference across all FLA test configs
- test_dispatch_multiple_seq_lengths: verify correctness for T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch paths and chunk boundary edge cases (a sketch of this comparison appears below)
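A hedged sketch of what such a dispatch test could look like; the import path, op signature, tolerances, and the naive step-by-step reference below are all assumptions for illustration, not the PR's actual test code:

```python
import pytest
import torch


def naive_reference(q, k, v, g, beta):
    # Step-by-step recurrence in fp32, used here only as an illustrative oracle.
    B, H, T, Dk = q.shape
    S = torch.zeros(B, H, Dk, v.shape[-1], dtype=torch.float32, device=q.device)
    outs = []
    for t in range(T):
        qt, kt, vt = q[:, :, t].float(), k[:, :, t].float(), v[:, :, t].float()
        S = S * g[:, :, t].float().exp()[..., None, None]
        vt = (vt - torch.einsum("bhk,bhkv->bhv", kt, S)) * beta[:, :, t].float()[..., None]
        S = S + torch.einsum("bhk,bhv->bhkv", kt, vt)
        outs.append(torch.einsum("bhk,bhkv->bhv", qt, S))
    return torch.stack(outs, dim=2).to(v.dtype)


@pytest.mark.parametrize("T", [1, 2, 32, 63, 64, 65, 128, 256])
def test_dispatch_multiple_seq_lengths(T):
    from chunk_gated_delta_rule import chunk_gated_delta_rule  # assumed import path
    B, H, Dk, Dv = 1, 4, 64, 128
    q = torch.randn(B, H, T, Dk, device="cuda")
    k = torch.nn.functional.normalize(torch.randn(B, H, T, Dk, device="cuda"), dim=-1)
    v = torch.randn(B, H, T, Dv, device="cuda")
    g = -torch.rand(B, H, T, device="cuda")      # log-decay <= 0
    beta = torch.rand(B, H, T, device="cuda")
    out = chunk_gated_delta_rule(q, k, v, g, beta)  # assumed signature
    torch.testing.assert_close(out, naive_reference(q, k, v, g, beta), rtol=2e-2, atol=2e-2)
```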
- Grid changed from (B*H,) to (V//BV, B*H): 4x more blocks and better SM occupancy (128 blocks vs 32 on an A100)
- BV reduced from 128 to 32: lower register pressure, no spilling
- Removed unnecessary .contiguous() copies on squeezed inputs
- Removed a debug print from the triton_op dispatch
- GPU kernel time: 6 µs (3.47x faster than the Inductor-fused native ops)
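To make the grid change concrete, here is a toy Triton kernel (not the PR's fused recurrent kernel) launched over a (V//BV, B*H) grid, so the value dimension is tiled across the first grid axis:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_v_kernel(v_ptr, out_ptr, scale, V: tl.constexpr, BV: tl.constexpr):
    pid_v = tl.program_id(0)             # which BV-wide slice of the value dim
    pid_bh = tl.program_id(1)            # which (batch * head) row
    cols = pid_v * BV + tl.arange(0, BV)
    offs = pid_bh * V + cols
    mask = cols < V
    x = tl.load(v_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)


B, H, V, BV = 1, 32, 128, 32
v = torch.randn(B * H, V, device="cuda")
out = torch.empty_like(v)
grid = (triton.cdiv(V, BV), B * H)       # (V//BV, B*H) = (4, 32) -> 128 blocks
scale_v_kernel[grid](v, out, 2.0, V=V, BV=BV)
```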
- Split the model into prefill (chunked FLA triton_op) and decode (native PyTorch recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to the CPU for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over the naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for the dual-method PTE
Revert from explicit state passing back to registered buffers with in-place updates (KVCache, conv_state, recurrent_state). Export with share_mutable_buffers=True so that both the prefill and forward methods share mutable state via mem_id=2. The C++ runner uses share_memory_arenas=true and only passes (tokens, input_pos), with no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in the profile, 65 D2H memcpys (logits only).
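A minimal sketch of the registered-buffer pattern with toy names and shapes; the real model registers KVCache, conv_state, and recurrent_state buffers, and the export-time share_mutable_buffers / mem_id plumbing is not shown here:

```python
import torch

class TinyKVCache(torch.nn.Module):
    def __init__(self, n_heads=4, max_seq_len=1024, head_dim=64):
        super().__init__()
        # State lives in registered buffers rather than being passed explicitly.
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))

    def forward(self, k, v, input_pos):
        # In-place writes at input_pos; after export, prefill and decode mutate
        # the same buffers instead of threading state tensors through method I/O.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache
```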
Add runtime buffer sharing between AOTI containers so that the prefill and decode methods operate on the same GPU tensors (KV cache, conv_state, etc.) without unnecessary H2D/D2H copies or getter/setter overhead. The first container to initialize extracts its constants (keyed by original FQN); subsequent containers with matching FQNs are updated via AOTInductorModelContainerUpdateUserManagedConstantBufferPairs to point to the same GPU memory (user_managed = true, no copy).

Also switch the main.cpp prefill path to token-by-token decode while the chunked FLA triton_op numerical issue is being resolved.

Tested E2E: "What is the capital of France?" → "Paris", with 966 constants shared between the prefill and decode containers on an A100.
- cuda_backend.cpp: use the codegen name (from GetConstantName) instead of the original FQN when calling UpdateUserManagedConstantBufferPairs. The AOTI API matches against internal codegen names, not FQNs; using FQNs caused silent no-op sharing, breaking the KV cache flow between prefill and decode.
- main.cpp: add a chunked prefill path using the "prefill" method (T>=2), with cudaDeviceSynchronize between prefill and decode for cross-stream safety. Add a --decode_only flag to fall back to token-by-token decode for all tokens.
- inference.py: update the docstring to reflect that chunked FLA is used in PTE mode (not eager).

Verified E2E: "What is the capital of France?" → "The capital of France is Paris." Prefill: 105 tok/s (chunked FLA), decode: 87 tok/s (recurrent delta rule).
- cuda_backend.cpp: replace debug printf with ET_LOG for errors/info only
- main.cpp: remove the --decode_only flag, keep only the chunked prefill path
- cuda_backend.cpp: replace ET_CHECK_OK_OR_RETURN_ERROR with explicit error handling plus cudaDeviceSynchronize after the weight transfer; add logging for a missing weights_blob
- main.cpp: support a single "forward" method fallback when prefill/decode are not available, use a prefill_method variable, remove debug printf
Implements CUDA graph support in the CUDA backend to reduce CPU kernel launch overhead during autoregressive decoding:
- cuda_backend.cpp: 3-phase execution (warmup → capture → replay) with static input/output GPU buffers, cudaMemcpyAsync for I/O, and cudaGraphInstantiateFlagAutoFreeOnLaunch for cudaMallocAsync compatibility
- cuda_delegate_handle.h: CUDA graph state (phase, graph objects, static buffer metadata) with RAII cleanup in the destructor
- main.cpp: --cuda_graph flag that sets BackendOptions before load_method
- test_model_e2e.sh: enable --cuda_graph for the Qwen3.5 MoE CI job, set PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync

Benchmark (A100, Qwen3.5-35B-A3B HQQ-INT4): 82 → 98 tok/s (1.20x)
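The backend does this in C++ against the CUDA runtime API (stream capture, cudaGraphInstantiateWithFlags, graph launch). Purely as an illustration of the same warmup → capture → replay pattern with static buffers, here is the equivalent flow using PyTorch's torch.cuda.CUDAGraph, with a toy model standing in for the decode step:

```python
import torch

model = torch.nn.Linear(64, 64).cuda()            # stand-in for one decode step
static_input = torch.zeros(1, 64, device="cuda")  # fixed GPU buffer the graph reads

# Phase 1: warmup on a side stream so lazy initialization happens outside capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Phase 2: capture. Kernels are recorded into a graph, not executed.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Phase 3: replay. Copy new inputs into the static buffer, launch the whole
# step as a single graph, and read results out of the static output buffer.
for step in range(4):
    static_input.copy_(torch.randn(1, 64, device="cuda"))
    graph.replay()
    logits = static_output.clone()                 # e.g. copy out for sampling
```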
Review context (static input buffer setup):

```cpp
    nbytes, cudaMemcpyHostToDevice);
// ...
handle->static_input_ptrs.push_back(static_ptr);
handle->static_input_sizes.push_back(sizes_vec);
```
Does this mean a method must have all static shapes for its intermediate tensors?
Yes, the CUDA graph can only support a static pipeline.