Skip to content

[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461

Open
VeeraRajasekhar wants to merge 6 commits intodevfrom
veergopu/fused-varlen-ck-smallseq-integration
Open

[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461
VeeraRajasekhar wants to merge 6 commits intodevfrom
veergopu/fused-varlen-ck-smallseq-integration

Conversation

@VeeraRajasekhar
Copy link
Contributor

Integrate the CK team's unfused variable-length attention HIP kernels from varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized path for specialized cross-attention (Q length 1, KV length 2-16, large batch)..

  • Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under fused_attn_rocm/: declarations and implementation adapted from varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output; grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

  • Add fused_attn_smallseq.cpp to the ROCm fused-attn build in transformer_engine/common/CMakeLists.txt.

  • In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q, h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2) call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen, output_S shape, workspace size, and small-seq fwd so varlen kernel indexing matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen kernel expects sequence-level batch).

  • In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host max_seqlen_kv; on real run call get_runtime_max_seqlen then fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for get_runtime_max_seqlen, workspace size, and small-seq bwd.

  • Reuse softmax LSE auxiliary buffer for attention weights in the small-seq path (forward write, backward read);

  • JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux buffer matches C++ attention-weights convention.

  • Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD, SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in C++.

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for specialized cross-attention (Q length 1, KV length 2-16, large
batch)..

- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
  fused_attn_rocm/: declarations and implementation adapted from
  varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
  grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
  max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
  transformer_engine/common/CMakeLists.txt.

- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
  max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
  == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
  max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
  h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
  call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
  count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
  output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
  matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
  kernel expects sequence-level batch).

- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
  (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
  max_seqlen_kv; on real run call get_runtime_max_seqlen then
  fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
  get_runtime_max_seqlen, workspace size, and small-seq bwd.

- Reuse softmax LSE auxiliary buffer for attention weights in the small-seq
  path (forward write, backward read);

- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
  kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
  q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
  buffer matches C++ attention-weights convention.

- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
  (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
  SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
  C++.
@wangye805
Copy link
Collaborator

wangye805 commented Feb 25, 2026

Let's make this PR work for jax extension first. Later we can support pytorch.

One key difference btw jax and pytorch fused-attn dispatch is that pytorch can calculate, request, and allocate softmax_aux, workspace during runtime with actual cu_seqlen_q/kv data. However, in jax extension, softmax_aux and workspace calculation is done in

if backend == NVTE_Fused_Attn_Backend.NVTE_AOTriton:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
elif backend == NVTE_Fused_Attn_Backend.NVTE_CK:
if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
and
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
, without knowing actual runtime cu_seqlen_q/kv. Aux tensors are prepared in
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
, also without the knowledge of runtime cu_seqlen_q/kv

General guideline:
1). Pre-allocate large enough softmax_aux and workspace ahead of time. Do not modify the aux preparation function or the c++ level aux workspace calculation/preparation, since we know our softmax aux and workspace size will be large enough for both flow, and the special flow only need a valid start pointer address.
2). During actual kernel dispatch, we do a seqlen_q/kv check, if it satisfy the special cross-attn condition, we launch it here
3). Use an env to guard this new flow and disable it when CP is used

NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably no need for this cast. cudaStream_t will be hipified correctly to hipStream_t

<< (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?")
<< std::endl;
}
(void)h_kv;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of those (void) statement? To ensure they are not 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I removed them cause there are lot of warnings when initially fixing build issues. I will remove this, this just silences the warnings

float sqr_dk_scale = attn_scale;
hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);

if (qkv_dtype == DType::kBFloat16) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try this TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT?

#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

NVTE_ERROR("small-seq path supports only BF16 and FP16.");
}

if (workspace_size) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set this workspace size? Will it be used later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed, as no longer required for Jax

NVTE_ERROR("small-seq path supports only BF16 and FP16.");
}

if (workspace_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here about whether we need to set the workspace size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed, as no longer required for Jax

softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
batch_size = reduce(operator.mul, batch_shape)
old_ck_softmax_size = (batch_size * attn_heads * q_max_seqlen * 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later let's come up with a better name. "old" seems informal :-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

raise ValueError(f"Unsupported {backend=}")

if os.environ.get("NVTE_LOG_CK_CONFIG"):
print(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably you need jax.debug.print. Otherwise it won't print with jax jit/cudagraph

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

size_t workspace_elems = product(work_shape);
size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype());
size_t workspace_bytes = workspace_elems * elt_size;
size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually our small cross-attn is unfused

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

size_t workspace_bytes = workspace_elems * elt_size;
size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16)

if (is_ragged && workspace_bytes < fused_small_seq_workspace) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can put is_ragged into line 514 if condition to filter earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
[
pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.bfloat16,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also have a sequence packing in q testcase. For example:
1). b=2048, s_q=2, s_kv=4, num_seg_per_seq=2
2). b=2, s_q=4096, s_kv=8192, num_seg_per_seq=4096

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

"""
Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout.
Uses THD_THD_THD (Q,K,V all THD).
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set env NVTE_FUSED_ATTN_CK_SMALLSEQ=1 for this pytest and clear it after finished

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add the filter condition that qkv_layout==THD_THD_THD, same for the bwd pass

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {

this is_ragged serves the same purpose? right

- tests/jax: CK small-seq tests use fixture to set/restore
  NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing
  cases (2048-2-4, 2-4096-8192); when env set, num_segments_per_seq =
  max_seqlen_q for THD else 2.
- JAX attention.py: THD softmax shape/dtype uses small-seq path only when
  env=1, else original layout
- JAX attention.cpp: Added env guard
- fused_attn_smallseq: Use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for
  fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants