[fix](Qwen3-30b-A3B): fix nonsense output by correcting RoPE write-back #31
yyj6666667 wants to merge 1 commit into kvcache-ai:main
Conversation
- Restore in-place storage for q/k and ensure k_cache is updated with the rotated values.
Code Review
This pull request modifies the fused_rope_store_kernel to ensure that rotated values are consistently written back to the input tensor for both query and key inputs, while also updating the k_cache. The review feedback identifies a critical missing update in the Python wrapper's mutates_args for the k tensor, which is necessary for correct dependency tracking and graph capture. Additionally, there is a suggestion to reuse existing pointers for the RoPE dimension offsets to improve code readability and maintain consistency with other kernel implementations.
```cpp
store_as<InputStorage>(input, input_vec_x, lane_id);
const auto input_y_out = pointer::offset(input, (kRopeDim / 2) * sizeof(DType));
store_as<InputStorage>(input_y_out, input_vec_y, lane_id);
```
You can reuse the existing `input_x` and `input_y` pointers (defined at lines 185-186) instead of recalculating the offset for the second half of the RoPE dimension. This improves readability and keeps the code consistent with the `fused_rope_kernel` implementation in the same file.
```cpp
store_as<InputStorage>(input_x, input_vec_x, lane_id);
store_as<InputStorage>(input_y, input_vec_y, lane_id);
```
Pull request overview
Fixes incorrect RoPE application in the fused RoPE+KV-cache store path for Qwen3-30B-A3B by ensuring rotated keys are written back to the caller’s k tensor (in addition to k_cache), preventing prefill attention from mixing rotated q with unrotated k.
Changes:
- Update `fused_rope_store_kernel` to always write rotated outputs back in-place to the `input` tensor for both `q` and `k`.
- Preserve the existing behavior of also writing rotated `k` into `k_cache` for decode.
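The bug matters because RoPE encodes relative position through the attention dot product: when both q and k are rotated, the score between them depends only on their position difference, and mixing a rotated q with an unrotated k destroys that property. A minimal NumPy sketch (split-half pairing is assumed here for illustration; the real kernel's pairing lives in the CUDA source):

```python
import numpy as np

def rope(x, pos, theta=10000.0):
    # Split-half RoPE: pairs x[i] with x[i + d/2] (illustrative reference only).
    d = x.shape[-1]
    ang = pos * theta ** (-np.arange(d // 2) / (d // 2))
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[: d // 2], x[d // 2:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])

q = np.random.default_rng(0).standard_normal(8)
k = np.random.default_rng(1).standard_normal(8)

# The rotated-q / rotated-k score depends only on the offset m - n:
s1 = rope(q, 7) @ rope(k, 4)    # offset 3
s2 = rope(q, 10) @ rope(k, 7)   # offset 3, shifted
assert np.isclose(s1, s2)
# Pairing rope(q, m) with an unrotated k gives a different, position-dependent score.
```

At position 0 the rotation is the identity, which is why decode-only paths that read k from `k_cache` can mask the missing in-place write until prefill is exercised.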
```cpp
// Always write rotated values back to `input` (q itself when
// load_q==true, k itself when load_q==false). When processing k, also
// write to k_cache for future decode.
store_as<InputStorage>(input, input_vec_x, lane_id);
const auto input_y_out = pointer::offset(input, (kRopeDim / 2) * sizeof(DType));
store_as<InputStorage>(input_y_out, input_vec_y, lane_id);
if (!load_q) {
  const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes);
  store_as<InputStorage>(k_out, input_vec_x, lane_id);
```
This kernel now writes rotated values back into the k input buffer when load_q == false. The Python custom-op wrapper for apply_rope_inplace_with_kvcache currently does not declare k in mutates_args, which can lead to incorrect behavior under torch.compile/graph captures because the op will be treated as not mutating k. Please update the wrapper registration to include k as a mutated argument to match this new in-place writeback.
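The fix the reviewer asks for is a one-line change to the custom-op registration. A hedged sketch of the pattern (the op name, namespace, and signature below are illustrative stand-ins, not the repo's actual wrapper; the kernel body is faked with in-place tensor ops):

```python
import torch

# Illustrative registration: every tensor the kernel writes in place
# (q, k, and k_cache after this PR) must be listed in mutates_args so
# torch.compile / graph capture tracks the mutation correctly.
@torch.library.custom_op(
    "myops::apply_rope_inplace_with_kvcache",   # hypothetical name
    mutates_args=("q", "k", "k_cache"),         # "k" is the new entry
)
def apply_rope_inplace_with_kvcache(
    q: torch.Tensor, k: torch.Tensor, k_cache: torch.Tensor
) -> None:
    # Stand-in for the CUDA kernel: mutate all declared tensors in place.
    q.mul_(1.0)
    k.mul_(1.0)
    k_cache.copy_(k)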
```cpp
// Always write rotated value back to `input`; also to k_cache when processing k.
store_as<InputStorage>(input, input_vec, lane_id);
if (!load_q) {
  const auto k_out = pointer::offset(k_cache, loc * cache_stride_bytes, head_id * head_stride_bytes);
  store_as<InputStorage>(k_out, input_vec, lane_id);
}
```
The existing fused RoPE+KV-cache tests validate q and k_cache, but they don’t assert that the in-place k tensor is rotated. Since this change makes k writeback part of the contract, please extend the test to also compare the post-call k against the reference rotated k to prevent regressions.
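A sketch of the extra assertion, under the assumption of split-half pairing (the real test would call the fused op and pass the mutated `k` buffer as `k_after`; `rope_ref` and `assert_k_rotated_inplace` are hypothetical helper names):

```python
import numpy as np

def rope_ref(x, pos, theta=10000.0):
    # Split-half reference rotation over the last axis (illustrative;
    # should mirror whatever pairing the CUDA kernel actually uses).
    d = x.shape[-1]
    ang = pos * theta ** (-np.arange(d // 2) / (d // 2))
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    return np.concatenate([x1 * np.cos(ang) - x2 * np.sin(ang),
                           x1 * np.sin(ang) + x2 * np.cos(ang)], axis=-1)

def assert_k_rotated_inplace(k_before, k_after, pos):
    # New contract from this PR: after the fused op, the caller's k buffer
    # itself holds rotated values, not only k_cache.
    np.testing.assert_allclose(k_after, rope_ref(k_before, pos),
                               rtol=1e-5, atol=1e-6)
```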
Motivation
Qwen3-30B-A3B produced nonsense output because the fused RoPE+KV-cache store path left the caller's k tensor unrotated, so prefill attention mixed rotated q with unrotated k.
Modifications
Update fused_rope_store_kernel to write rotated values back in-place for both q and k, while preserving the existing write of rotated k into k_cache.
Tests