Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# Please refer to README.md in the same folder for more information.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

Expand Down Expand Up @@ -264,17 +265,27 @@ def forward(
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
mask = self.mask[None, None, input_pos]
# mask = self.mask[None, None, input_pos]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0
)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(q)
output = torch.matmul(scores, v) # (bs, n_local_heads, seqlen, head_dim)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)
return output
# y = F.scaled_dot_product_attention(
# q, k, v, attn_mask=mask, dropout_p=0.0
# )

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
# y = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = self.wo(y)
# y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

# y = self.wo(y)
return y
else:
from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache # noqa
Expand All @@ -300,16 +311,13 @@ def forward(
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

assert hasattr(self, "mask")

mask = self.mask[:seqlen, :seqlen]

output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(q)
output = torch.matmul(scores, v) # (bs, n_local_heads, seqlen, head_dim)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)

return output


Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama2/params/demo_config.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512}
{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 2, "norm_eps": 1e-05, "vocab_size": 512}
5 changes: 3 additions & 2 deletions examples/portable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def _core_aten_to_edge(
constant_methods=edge_constant_methods,
compile_config=edge_compile_config,
)
if verbose:
logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}")
# if verbose:
# logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}")
edge_manager.exported_program().graph_module.print_readable()
return edge_manager


Expand Down