[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461
VeeraRajasekhar wants to merge 6 commits into dev
Conversation
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).
- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
fused_attn_rocm/: declarations and implementation adapted from
varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
== 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
kernel expects sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
(workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
max_seqlen_kv; on real run call get_runtime_max_seqlen then
fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
get_runtime_max_seqlen, workspace size, and small-seq bwd.
- Reuse softmax LSE auxiliary buffer for attention weights in the small-seq
path (forward write, backward read).
- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
buffer matches C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
(parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
C++.
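The JAX-side aux-buffer convention described above can be sketched as a small helper. This is illustrative only: the function name `smallseq_softmax_aux` and its exact signature are hypothetical, not the actual code in attention.py.

```python
import jax.numpy as jnp

def smallseq_softmax_aux(batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen, q_dtype):
    """Pick the softmax aux-buffer shape/dtype for the CK path (illustrative sketch)."""
    if q_max_seqlen == 1 and 2 <= kv_max_seqlen <= 16:
        # Small-seq path: the aux buffer holds full attention weights in q_dtype,
        # matching the C++ attention-weights convention in this PR.
        shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
        dtype = q_dtype
    else:
        # Default CK path: per-row softmax statistics (LSE) in fp32.
        shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
        dtype = jnp.float32
    return shape, dtype
```

For example, with b=30720, h=16, s_q=1, s_kv=8 in BF16, the small-seq branch yields shape (30720, 16, 1, 8) with dtype bfloat16.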
Let's make this PR work for the JAX extension first. Later we can support PyTorch. One key difference between the JAX and PyTorch fused-attn dispatch is that PyTorch can calculate, request, and allocate softmax_aux and workspace at runtime with the actual cu_seqlen_q/kv data. In the JAX extension, however, the softmax_aux and workspace calculation is done in transformer_engine/jax/cpp_extensions/attention.py (lines 364 to 375 at b685686). General guideline:
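Because JAX must size these buffers at trace time (no cu_seqlens data is available yet), the small-seq workspace can only be an upper bound from static maxima. A minimal sketch of that idea, mirroring the `input_batch * attn_heads * 16 * 2` bound used in the diff; the function name and parameters are hypothetical:

```python
def smallseq_workspace_bytes(max_tokens_q, attn_heads, max_seqlen_kv, elt_size=2):
    """Upper-bound workspace for the small-seq path from static maxima only.

    One attention-weights row per Q token (s_q == 1), sized with the host-side
    max_seqlen_kv since real sequence lengths are unknown at trace time.
    elt_size=2 covers BF16/FP16.
    """
    return max_tokens_q * attn_heads * 1 * max_seqlen_kv * elt_size
```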
    NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

    float sqr_dk_scale = attn_scale;
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Probably no need for this cast. cudaStream_t will be hipified correctly to hipStream_t
…port to small-seq kernels
    << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?")
    << std::endl;
    }
    (void)h_kv;
What's the purpose of those (void) statements? To ensure they are not 0?
No, I added them because there were a lot of warnings when initially fixing build issues. I will remove them; they just silence the warnings.
    float sqr_dk_scale = attn_scale;
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

    if (qkv_dtype == DType::kBFloat16) {
Try this TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT?
    NVTE_ERROR("small-seq path supports only BF16 and FP16.");
    }

    if (workspace_size) {
Do we need to set this workspace size? Will it be used later?
Removed; it is no longer required for JAX.
    NVTE_ERROR("small-seq path supports only BF16 and FP16.");
    }

    if (workspace_size)
Same question here about whether we need to set the workspace size.
Removed; it is no longer required for JAX.
        softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
    else:
        batch_size = reduce(operator.mul, batch_shape)
        old_ck_softmax_size = batch_size * attn_heads * q_max_seqlen * 1
Later let's come up with a better name. "old" seems informal :-)
    raise ValueError(f"Unsupported {backend=}")

    if os.environ.get("NVTE_LOG_CK_CONFIG"):
        print(
Probably you need jax.debug.print. Otherwise it won't print with jax jit/cudagraph
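A minimal sketch of the suggestion: a Python `print` fires only once at trace time, while `jax.debug.print` is staged into the computation and runs at execution time, so it still prints under jit (and cudagraph capture). The function below is a toy example, not TE code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled(x):
    # print("...") here would run once while tracing, with abstract values.
    # jax.debug.print is staged out and prints the concrete runtime value.
    jax.debug.print("runtime sum inside jit: {}", x.sum())
    return x * 2.0

y = scaled(jnp.ones(4))
```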
    size_t workspace_elems = product(work_shape);
    size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
    size_t workspace_bytes = workspace_elems * elt_size;
    size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2;  // min for small-seq (bf16/fp16)
Actually our small cross-attn is unfused
    size_t workspace_bytes = workspace_elems * elt_size;
    size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2;  // min for small-seq (bf16/fp16)

    if (is_ragged && workspace_bytes < fused_small_seq_workspace) {
You can put is_ragged into line 514 if condition to filter earlier
tests/jax/test_fused_attn.py
    @pytest.mark.parametrize(
        "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
        [
            pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.bfloat16,
Let's also have a sequence packing in q testcase. For example:
1). b=2048, s_q=2, s_kv=4, num_seg_per_seq=2
2). b=2, s_q=4096, s_kv=8192, num_seg_per_seq=4096
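The two suggested packing cases could be parametrized roughly as follows. This is a sketch: the exact parameter list of the real test differs (it also carries h_q, h_kv, d_qk, d_v, dtype), and `num_seg_per_seq` is assumed to be threaded through to the THD segment generation:

```python
import pytest

# Reviewer-suggested sequence-packing cases (b, s_q, s_kv, num_seg_per_seq);
# ids are illustrative.
packing_cases = [
    pytest.param(2048, 2, 4, 2, id="b2048-sq2-skv4-seg2"),
    pytest.param(2, 4096, 8192, 4096, id="b2-sq4096-skv8192-seg4096"),
]
```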
| """ | ||
| Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. | ||
| Uses THD_THD_THD (Q,K,V all THD). | ||
| """ |
Set env NVTE_FUSED_ATTN_CK_SMALLSEQ=1 for this pytest and clear it after the test finishes.
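A set-and-restore pattern for that env var might look like the sketch below (the helper name is illustrative; the same body could live in a pytest fixture or use monkeypatch.setenv):

```python
import os
from contextlib import contextmanager

@contextmanager
def ck_smallseq_env():
    """Set NVTE_FUSED_ATTN_CK_SMALLSEQ=1, restoring the previous value on exit."""
    old = os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ")
    os.environ["NVTE_FUSED_ATTN_CK_SMALLSEQ"] = "1"
    try:
        yield
    finally:
        # Restore whatever was there before, including "unset".
        if old is None:
            os.environ.pop("NVTE_FUSED_ATTN_CK_SMALLSEQ", None)
        else:
            os.environ["NVTE_FUSED_ATTN_CK_SMALLSEQ"] = old
```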
    std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
    }

    if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
Also add the filter condition that qkv_layout==THD_THD_THD, same for the bwd pass
Doesn't the is_ragged check serve the same purpose here?
- tests/jax: CK small-seq tests use a fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when the env var is set, num_segments_per_seq = max_seqlen_q for THD, else 2.
- JAX attention.py: THD softmax shape/dtype uses the small-seq path only when env=1, else the original layout.
- JAX attention.cpp: added env guard.
- fused_attn_smallseq: use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).