61 changes: 12 additions & 49 deletions examples/models/llama2/llama_transformer.py
@@ -212,31 +212,12 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_length)
-            seq_length = k_val.size(2)
-            # Replace the entry in the cache for this token
-            # The following lines are equivalent to:
-            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-            # We use .narrow() here to make the compiler happy
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_k = self.k_cache.narrow(2, start_pos, seq_length)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(2, start_pos, seq_length)
-
-            narrowed_k.copy_(k_val)
-            narrowed_v.copy_(v_val)
-            return self.k_cache, self.v_cache
-        else:
-            k_out = self.k_cache
-            v_out = self.v_cache
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val
+        v_out[:, :, input_pos] = v_val

-            return k_out, v_out
+        return k_out, v_out


 class SDPA(nn.Module):
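
For reference, the path this hunk keeps writes the new K/V entries into the preallocated cache with a single advanced-indexing assignment. A minimal standalone sketch of that pattern (shapes and variable names here are hypothetical, not taken from the module):

import torch

# Hypothetical cache dims: batch, heads, max sequence length, head dim.
B, H, MAX_SEQ, D = 1, 8, 128, 64
k_cache = torch.zeros(B, H, MAX_SEQ, D)
v_cache = torch.zeros(B, H, MAX_SEQ, D)

# One decode step: write the new token's K/V at position 5.
input_pos = torch.tensor([5])
k_val = torch.randn(B, H, 1, D)
v_val = torch.randn(B, H, 1, D)

# Same pattern as the retained branch: indexing the sequence
# dimension with a position tensor updates the cache in place.
k_cache[:, :, input_pos] = k_val
v_cache[:, :, input_pos] = v_val

The same assignment also covers multi-token prefill: with input_pos = torch.arange(S) and k_val of shape (B, H, S, D), all S entries are written at once.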
@@ -272,15 +253,7 @@ def forward(
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
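
The kept branch selects the rows of a precomputed causal mask for the current positions and relies on broadcasting over the batch and head dimensions. A small sketch, assuming a (max_seq_len, max_seq_len) mask buffer of the kind llama-style models typically register:

import torch

max_seq_len = 16
# Hypothetical causal mask: 0 on/below the diagonal, -inf above it.
mask = torch.triu(
    torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1
)

# Decoding the token at position 5: take that row and add two leading
# singleton dims so it broadcasts over (batch, heads, q_len, kv_len).
input_pos = torch.tensor([5])
attn_mask = mask[None, None, input_pos]  # shape (1, 1, 1, max_seq_len)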
@@ -514,22 +487,12 @@ def forward(
                 input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"

-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
+            # When not using dynamic shapes, calling .item() produces SymInts,
+            # since it queries data from the tensor. Plain indexing avoids
+            # that for the MPS backend, although the MPS backend can probably
+            # support dynamic shapes.
+            freqs_cos = self.freqs_cos[input_pos]
+            freqs_sin = self.freqs_sin[input_pos]

         else:
             assert input_pos is None, "input_pos is unused when use_kv_cache is False"
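
The retained rotary path is the same idea: freqs_cos and freqs_sin are precomputed per-position tables, and indexing them with input_pos fetches the rows for the current step without any .item() call. A sketch under assumed shapes (this table construction is illustrative, not this file's actual precompute step):

import torch

# Hypothetical rotary tables, one row per position: (max_seq_len, head_dim // 2).
max_seq_len, half_dim = 32, 4
positions = torch.arange(max_seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(half_dim, dtype=torch.float32) / half_dim))
freqs = torch.outer(positions, inv_freq)
freqs_cos, freqs_sin = freqs.cos(), freqs.sin()

# With a KV cache, each step looks up its rows by plain tensor indexing,
# exactly the pattern the diff keeps.
input_pos = torch.tensor([5])
step_cos = freqs_cos[input_pos]  # shape (1, half_dim)
step_sin = freqs_sin[input_pos]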