Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
Expand Down Expand Up @@ -1763,20 +1763,25 @@
transpose_k = use_transposed_cache
transpose_v = use_transposed_cache

# SDPA uses is_seq_at_dim_2=True when any cache is transposed,
# since KVCache always returns [B, H, S, D] for Attention.
sdpa_seq_at_dim_2 = transpose_k or transpose_v

transforms.append(
partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v)
)
if use_attention_mask_for_custom_sdpa:
transforms.append(
partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2)
partial(
replace_sdpa_with_custom_op,
use_attention_mask=True,
is_k_seq_at_dim_2=transpose_k,
is_v_seq_at_dim_2=transpose_v,
)
)
else:
transforms.append(
partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2)
partial(
replace_sdpa_with_custom_op,
is_k_seq_at_dim_2=transpose_k,
is_v_seq_at_dim_2=transpose_v,
)
)

if quantize_kv_cache:
Expand Down
39 changes: 19 additions & 20 deletions examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -379,7 +379,7 @@
# input_pos: [S], k_val/v_val: [B, H, S, D] from Attention
start_pos = input_pos[0].item()

# Transpose k_val to match cache layout if needed
# Transpose k_val/v_val to match cache layout if needed
k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2)
v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2)

Expand All @@ -394,10 +394,15 @@
_ = torch.ops.llama.update_cache(k_for_cache, self.k_cache, start_pos, self.transpose_k)
_ = torch.ops.llama.update_cache(v_for_cache, self.v_cache, start_pos, self.transpose_v)

# Return both caches in [B, H, S, D] for Attention
k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2)
v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2)
return (k_out, v_out)
# Return caches in their native layout. The SDPA op handles
# mixed K/V layouts via separate seq dim parameters, avoiding
# expensive runtime transpose copies.
return (self.k_cache, self.v_cache)

@property
def is_seq_at_dim_2(self):
    """Backward compat for quantized KV cache path.

    Returns True only when BOTH the K and V caches are stored transposed
    ([B, H, S, D]); a mixed layout (exactly one of transpose_k /
    transpose_v set) reports False.
    NOTE(review): the quantized path collapses the two per-tensor flags
    into this single flag — confirm callers on that path never construct
    a mixed K/V layout, since this property cannot represent it.
    """
    return self.transpose_k and self.transpose_v


def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False):
Expand Down Expand Up @@ -519,7 +524,6 @@
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
kv_cache.return_float_values,
kv_cache.is_seq_at_dim_2,
is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
)

Expand All @@ -532,11 +536,13 @@
n_heads,
head_dim,
dtype=torch.float32,
is_seq_at_dim_2: bool = False,
transpose_k: bool = False,
transpose_v: bool = False,
):
# Look at attention.py for explanation on why max_context_length * 2
super().__init__(
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype,
transpose_k=transpose_k, transpose_v=transpose_v,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
Expand All @@ -551,18 +557,10 @@
def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
Returns K/V caches in their native storage layout.
"""
# Need to transpose for two reasons
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we wont be able be able to optimize
# away transpose at the output of k, v projection
if not self.is_seq_at_dim_2:
seq_len = k_val.transpose(1, 2).size(1)
else:
seq_len = k_val.size(2)
# k_val is always [B, H, S, D] from Attention. Get seq_len from dim 2.
seq_len = k_val.size(2)
assert seq_len <= self.k_cache.size(
1
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
Expand Down Expand Up @@ -593,7 +591,8 @@
n_heads,
head_dim,
dtype=kv_cache.k_cache.dtype,
is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
transpose_k=kv_cache.transpose_k,
transpose_v=kv_cache.transpose_v,
)


Expand Down
66 changes: 44 additions & 22 deletions examples/models/llama/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -23,14 +23,17 @@
self,
dim: int,
use_attention_mask: bool = False,
is_seq_at_dim_2: bool = False,
is_k_seq_at_dim_2: bool = False,
is_v_seq_at_dim_2: bool = False,
):
super().__init__()
self.dim = dim
self.use_attention_mask = use_attention_mask
# When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2.
# When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1.
self.is_seq_at_dim_2 = is_seq_at_dim_2
# Separate seq dim flags for K and V allow mixed cache layouts.
# Q and output always use seq_dim=2 ([B, H, S, D]) since Q is
# always small (current step) and the transpose is negligible.
self.is_k_seq_at_dim_2 = is_k_seq_at_dim_2
self.is_v_seq_at_dim_2 = is_v_seq_at_dim_2

def forward(
self,
Expand All @@ -42,13 +45,8 @@
seqlen,
mask,
):
# Q, K, V arrive in [B, H, S, D] from Attention.
# If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op.
if not self.is_seq_at_dim_2:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Q arrives in [B, H, S, D] from Attention - always passed with seq_dim=2.
# K and V arrive in their native cache layout (may differ).
input_dtype = q.dtype
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
Expand All @@ -64,7 +62,9 @@
0,
False,
scale=None,
is_seq_dim_2=self.is_seq_at_dim_2,
is_seq_dim_2=True,
is_k_seq_dim_2=self.is_k_seq_at_dim_2,
is_v_seq_dim_2=self.is_v_seq_at_dim_2,
)
else:
output = torch.ops.llama.custom_sdpa(
Expand All @@ -76,15 +76,20 @@
0,
True,
scale=None,
is_seq_dim_2=self.is_seq_at_dim_2,
is_seq_dim_2=True,
is_k_seq_dim_2=self.is_k_seq_at_dim_2,
is_v_seq_dim_2=self.is_v_seq_at_dim_2,
)
if self.is_seq_at_dim_2:
output = output.transpose(1, 2).contiguous()
# Output is [B, H, S, D] (matches Q layout), transpose for reshape
output = output.transpose(1, 2).contiguous()
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


def _replace_sdpa_with_custom_op(
module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
module: torch.nn.Module,
use_attention_mask: bool = False,
is_k_seq_at_dim_2: bool = False,
is_v_seq_at_dim_2: bool = False,
):
for name, child in module.named_children():
if isinstance(child, SDPA):
Expand All @@ -94,19 +99,33 @@
SDPACustom(
child.dim,
use_attention_mask=use_attention_mask,
is_seq_at_dim_2=is_seq_at_dim_2,
is_k_seq_at_dim_2=is_k_seq_at_dim_2,
is_v_seq_at_dim_2=is_v_seq_at_dim_2,
),
)
else:
_replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)
_replace_sdpa_with_custom_op(
child,
use_attention_mask=use_attention_mask,
is_k_seq_at_dim_2=is_k_seq_at_dim_2,
is_v_seq_at_dim_2=is_v_seq_at_dim_2,
)


def replace_sdpa_with_custom_op(
module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True
module: torch.nn.Module,
use_attention_mask: bool = False,
is_k_seq_at_dim_2: bool = False,
is_v_seq_at_dim_2: bool = False,
) -> torch.nn.Module:
from executorch.extension.llm.custom_ops import custom_ops # noqa

_replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2)
_replace_sdpa_with_custom_op(
module,
use_attention_mask=use_attention_mask,
is_k_seq_at_dim_2=is_k_seq_at_dim_2,
is_v_seq_at_dim_2=is_v_seq_at_dim_2,
)
return module


Expand Down Expand Up @@ -138,6 +157,7 @@
self.float_dtype = torch.float32
self.kv_cache = kv_cache
self.use_attention_mask = use_attention_mask
# Quantized path uses a single flag for all tensors
self.is_seq_at_dim_2 = is_seq_at_dim_2

def forward(
Expand Down Expand Up @@ -225,8 +245,10 @@
sdpa = getattr(module, "SDPA", None)
assert sdpa is not None
assert isinstance(sdpa, SDPACustom)
# TODO: add support for SDPA with attention mask
setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=sdpa.is_seq_at_dim_2)) # noqa: B010
# Quantized SDPA uses a single is_seq_at_dim_2 flag;
# derive from K/V flags (both must match for quantized path).
is_seq_at_dim_2 = sdpa.is_k_seq_at_dim_2 and sdpa.is_v_seq_at_dim_2
setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=is_seq_at_dim_2)) # noqa: B010


def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module):
Expand Down
19 changes: 4 additions & 15 deletions extension/llm/custom_ops/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,11 @@ def custom_sdpa(
is_causal=False,
scale=None,
is_seq_dim_2=False,
is_k_seq_dim_2=False,
is_v_seq_dim_2=False,
):
seq_len = query.size(2) if is_seq_dim_2 else query.size(1)
_validate_params(
query,
key_cache,
value_cache,
key_cache,
value_cache,
start_pos,
seq_len,
attn_mask,
drpout_p,
is_causal,
scale,
)

# Skip _validate_params since it assumes K/V caches have the same layout.
# With mixed transpose (e.g. v_only), K and V have different shapes.
return torch.empty_like(query)


Expand Down
37 changes: 26 additions & 11 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -345,7 +345,9 @@
const optional<Tensor>& k_scales = nullopt,
const optional<Tensor>& v_zero_points = nullopt,
const optional<Tensor>& v_scales = nullopt,
bool is_seq_at_dim_2 = false) {
bool is_seq_at_dim_2 = false,
bool is_k_seq_at_dim_2 = false,
bool is_v_seq_at_dim_2 = false) {
ET_KERNEL_CHECK_MSG(
ctx,
!attn_mask.has_value() || !is_causal,
Expand All @@ -360,11 +362,10 @@
output,
"Invalid arguments");

SeqDim seq_dim{SeqDim::TWO};
if (!is_seq_at_dim_2) {
seq_dim = SeqDim::ONE;
}
int64_t seq_len = q.size(static_cast<int64_t>(seq_dim));
SeqDim q_seq_dim = is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
SeqDim k_seq_dim = is_k_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
SeqDim v_seq_dim = is_v_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
int64_t seq_len = q.size(static_cast<int64_t>(q_seq_dim));

if (q.scalar_type() == ScalarType::Char) {
ET_KERNEL_CHECK_MSG(
Expand Down Expand Up @@ -447,7 +448,9 @@
k_scales,
v_zero_points,
v_scales,
seq_dim,
q_seq_dim,
k_seq_dim,
v_seq_dim,
start_pos,
num_keys_for_causal_attention);
} else if (seq_len >= 192) {
Expand All @@ -467,7 +470,9 @@
k_scales,
v_zero_points,
v_scales,
seq_dim,
q_seq_dim,
k_seq_dim,
v_seq_dim,
start_pos,
num_keys_for_causal_attention);
} else {
Expand All @@ -487,7 +492,9 @@
k_scales,
v_zero_points,
v_scales,
seq_dim,
q_seq_dim,
k_seq_dim,
v_seq_dim,
start_pos,
num_keys_for_causal_attention);
}
Expand Down Expand Up @@ -532,6 +539,8 @@
k_scales,
v_zero_points,
v_scales,
is_seq_at_dim_2,
is_seq_at_dim_2,
is_seq_at_dim_2);
}

Expand Down Expand Up @@ -562,6 +571,8 @@
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
const bool is_k_seq_dim_2,
const bool is_v_seq_dim_2,
Tensor& output) {
return custom_sdpa_out_impl(
ctx,
Expand All @@ -580,7 +591,9 @@
nullopt,
nullopt,
nullopt,
is_seq_dim_2);
is_seq_dim_2,
is_k_seq_dim_2,
is_v_seq_dim_2);
}
/*
Input params
Expand Down Expand Up @@ -635,7 +648,9 @@
dropout_p,
is_causal,
scale,
false, // is_seq_dim_2 - default to false for backward compatibility
false, // is_seq_dim_2
false, // is_k_seq_dim_2
false, // is_v_seq_dim_2
output);

return output;
Expand Down
2 changes: 2 additions & 0 deletions extension/llm/custom_ops/op_sdpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Tensor& custom_sdpa_out(
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
const bool is_k_seq_dim_2,
const bool is_v_seq_dim_2,
Tensor& output);

Tensor& flash_attention_kernel_out(
Expand Down
Loading
Loading