Add CUDA sort shim for AOTI export (thrust-based sort_stable fallback)#18829
Add CUDA sort shim for AOTI export (thrust-based sort_stable fallback)#18829digantdesai wants to merge 5 commits intomainfrom
Conversation
Inductor emits aten::sort.stable for ops like argsort, but lacks a native c-shim for it. This adds a thrust-based implementation (aoti_torch_cuda_sort_stable) that handles int64, int32, and float32 dtypes on contiguous innermost-dim tensors. Registered as a supported fallback kernel in CudaBackend so AOTI-compiled models can use sort.
Sweeps prompt lengths [1..4095] with Qwen3.5-35B-A3B shapes (256 experts, top-8, INT4 W4A16). Validates correctness against loop-based eager reference at small M, benchmarks vectorized eager, torch.compile, and Triton fused_moe. Handles OOM gracefully at large M where eager/compile dequantize all experts.
When the Triton tile size fits within a single quantization group, load one scale per N-element instead of per (K, N) element. Reduces scale memory traffic in both GEMM1 and GEMM2 vec-mat kernels.
Adds a batched (M>1) Triton fused MoE kernel using tensor-core mma instructions for prefill workloads. Includes moe_align_block_size for token-expert sorting and scale broadcast optimization in the batched GEMM inner loops. Weight layout: [E, N, K//2] (packed INT4).
Add use_batched_moe flag on FusedMoEExperts, toggled by _set_batched_moe in export.py before each method's torch.export call. Decode (T=1) uses the vec-mat fused_moe kernel; prefill (T>=2) uses fused_moe_batched_gemm.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18829
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 2 New Failures, 4 Unrelated FailuresAs of commit a0d199a with merge base 266ff2d ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
WIP