Commit 72918b7

[examples] Add torch ingress for llama example. (#41)
1 parent 8922c63 commit 72918b7

File tree

2 files changed (+50, -11 lines)

examples/llama/ref_model.py

Lines changed: 16 additions & 11 deletions
@@ -116,22 +116,24 @@ def __init__(self, args: ModelArgs):
             bias=False,
         )
 
-        self.cache_k = torch.zeros(
+        cache_k = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
             )
         )
-        self.cache_v = torch.zeros(
+        cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
             )
         )
+        self.register_buffer("cache_k", cache_k, persistent=False)
+        self.register_buffer("cache_v", cache_v, persistent=False)
 
     def forward(
         self,
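Note on the hunk above: the KV caches are now created as local tensors and registered via register_buffer(..., persistent=False) instead of being assigned as plain attributes. A minimal standalone sketch, not part of this commit, of what that buys: the buffer follows module-level .to(...) moves and conversions and stays out of state_dict(), which suits a cache that holds no trained weights.

# Minimal sketch, not part of this commit: behavior of a non-persistent buffer.
import torch
import torch.nn as nn


class CacheDemo(nn.Module):
    def __init__(self):
        super().__init__()
        cache = torch.zeros(2, 4)
        # Registered buffers travel with the module on .to(...)/.cuda() and are
        # visible to tracing; persistent=False keeps them out of state_dict().
        self.register_buffer("cache", cache, persistent=False)


m = CacheDemo()
assert "cache" not in m.state_dict()   # cache is not serialized with the weights
m = m.to(torch.float16)                # buffer follows module-wide conversions
assert m.cache.dtype == torch.float16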
@@ -149,14 +151,17 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        self.cache_k = self.cache_k.to(xq)
-        self.cache_v = self.cache_v.to(xq)
-
-        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+        # TODO: the original implementation doesn't work with export.
+        # Local tensors instead of in-place buffer updates to please it.
+        cache_k_updated = self.cache_k.index_copy(
+            1, torch.arange(start_pos, start_pos + seqlen, device=xk.device), xk
+        )
+        cache_v_updated = self.cache_v.index_copy(
+            1, torch.arange(start_pos, start_pos + seqlen, device=xv.device), xv
+        )
 
-        keys = self.cache_k[:bsz, : start_pos + seqlen]
-        values = self.cache_v[:bsz, : start_pos + seqlen]
+        keys = cache_k_updated[:bsz, : start_pos + seqlen]
+        values = cache_v_updated[:bsz, : start_pos + seqlen]
 
         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(
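The TODO in the hunk above is the core of the change: torch.export does not handle the original in-place slice writes into the cache buffers, so the forward pass now builds updated copies with the out-of-place Tensor.index_copy. A small self-contained sketch, with illustrative shapes that are not from the commit, showing the two update styles agree:

# Small sketch, illustrative shapes only: in-place slice write vs. out-of-place
# index_copy for a sequence-dimension cache update.
import torch

max_seq_len, heads, head_dim = 8, 2, 4
cache = torch.zeros(1, max_seq_len, heads, head_dim)
xk = torch.randn(1, 3, heads, head_dim)   # seqlen == 3 new key states
start_pos = 2

# Original eager-mode version: mutates the cache tensor in place.
cache_inplace = cache.clone()
cache_inplace[:, start_pos : start_pos + 3] = xk

# Export-friendly version: index_copy (no trailing underscore) returns a new
# tensor, so the module never mutates a registered buffer during tracing.
idx = torch.arange(start_pos, start_pos + 3)
cache_functional = cache.index_copy(1, idx, xk)

assert torch.equal(cache_inplace, cache_functional)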
@@ -246,17 +251,17 @@ def __init__(self, params: ModelArgs):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
 
-        self.freqs_cis = precompute_freqs_cis(
+        freqs_cis = precompute_freqs_cis(
             params.dim // params.n_heads,
             params.max_seq_len * 2,
             params.rope_theta,
         )
+        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
 
     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
 
         mask = None
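Same pattern for the rotary table: freqs_cis is now a non-persistent buffer, so the per-call self.freqs_cis.to(h.device) reassignment inside forward can be dropped; moving the model moves the table with it. A tiny sketch, with a stand-in table that is not from the commit:

# Tiny sketch, stand-in values: a buffered rotary table moves with the module.
import torch
import torch.nn as nn


class RopeTableDemo(nn.Module):
    def __init__(self):
        super().__init__()
        freqs_cis = torch.polar(torch.ones(8, 4), torch.zeros(8, 4))  # complex table
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)


m = RopeTableDemo()
if torch.cuda.is_available():
    m = m.to("cuda")
    assert m.freqs_cis.device.type == "cuda"  # no per-forward .to(h.device) needed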

examples/llama/torch_ingress.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# RUN: %PYTHON %s
+# REQUIRES: torch
+
+import os
+from pathlib import Path
+
+import torch
+
+from lighthouse.ingress.torch import import_from_model
+from ref_model import ModelArgs, Transformer
+
+script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
+model_path = script_dir / "ref_model.py"
+
+model_args = ModelArgs(
+    dim=512,
+    n_layers=2,
+    n_heads=8,
+    vocab_size=10000,
+    max_batch_size=1,
+    max_seq_len=128,
+)
+
+model = Transformer(model_args)
+sample_input = (torch.randint(0, model_args.vocab_size, (1, model_args.max_seq_len)), 0)
+
+mlir_module_str = import_from_model(
+    model, sample_args=sample_input, dialect="linalg-on-tensors"
+)
+
+dense_resource_idx = mlir_module_str.find("\n{-#\n dialect_resources: {")
+assert dense_resource_idx != -1
+
+print(mlir_module_str[:dense_resource_idx])
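Usage note on the new file: the lit-style # RUN / # REQUIRES lines run the script under the Python interpreter only when torch is available, and the find("\n{-#...") slice drops the trailing dialect_resources section (the dense_resource blobs holding the serialized weights) before printing. A hedged sketch of reusing the same ingress call outside the test; the smaller ModelArgs values and the output path are illustrative, not from the commit:

# Hedged sketch, not part of this commit: reuse the same import_from_model call
# on a smaller configuration and write only the IR text (weights blob dropped).
import torch

from lighthouse.ingress.torch import import_from_model
from ref_model import ModelArgs, Transformer

args = ModelArgs(
    dim=64, n_layers=1, n_heads=4, vocab_size=256, max_batch_size=1, max_seq_len=16
)
model = Transformer(args)
tokens = torch.randint(0, args.vocab_size, (1, args.max_seq_len))

mlir = import_from_model(model, sample_args=(tokens, 0), dialect="linalg-on-tensors")

# The serialized weights sit in an MLIR "{-# dialect_resources ... #-}" trailer
# appended after the module; keep just the IR when inspecting it by hand.
cut = mlir.find("\n{-#")
with open("llama_small.mlir", "w") as f:
    f.write(mlir if cut == -1 else mlir[:cut])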
