@@ -116,22 +116,24 @@ def __init__(self, args: ModelArgs):
             bias=False,
         )
 
-        self.cache_k = torch.zeros(
+        cache_k = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
             )
         )
-        self.cache_v = torch.zeros(
+        cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
             )
         )
+        self.register_buffer("cache_k", cache_k, persistent=False)
+        self.register_buffer("cache_v", cache_v, persistent=False)
 
     def forward(
         self,
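Why the caches become buffers: register_buffer makes cache_k and cache_v part of the module's tracked state, so they follow Module.to(...) / .cuda() together with the parameters instead of being loose Python attributes, and persistent=False keeps these large, purely-derived tensors out of the state_dict. A minimal sketch of that behavior, using a hypothetical toy module and toy shapes rather than the real attention layer:

import torch
import torch.nn as nn

class CacheDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: tracked by the module, excluded from checkpoints.
        self.register_buffer("cache_k", torch.zeros(2, 8, 1, 4), persistent=False)

m = CacheDemo()
print("cache_k" in m.state_dict())  # False: persistent=False keeps it out of the state_dict
m = m.to(torch.float16)             # buffers follow module-wide dtype/device moves
print(m.cache_k.dtype)              # torch.float16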
@@ -149,14 +151,17 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        self.cache_k = self.cache_k.to(xq)
-        self.cache_v = self.cache_v.to(xq)
-
-        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+        # TODO: the original implementation doesn't work with export.
+        # Local tensors instead of in-place buffer updates to please it.
+        cache_k_updated = self.cache_k.index_copy(
+            1, torch.arange(start_pos, start_pos + seqlen, device=xk.device), xk
+        )
+        cache_v_updated = self.cache_v.index_copy(
+            1, torch.arange(start_pos, start_pos + seqlen, device=xv.device), xv
+        )
 
-        keys = self.cache_k[:bsz, : start_pos + seqlen]
-        values = self.cache_v[:bsz, : start_pos + seqlen]
+        keys = cache_k_updated[:bsz, : start_pos + seqlen]
+        values = cache_v_updated[:bsz, : start_pos + seqlen]
 
         # repeat k/v heads if n_kv_heads < n_heads
         keys = repeat_kv(
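The removed slice assignment mutates the cache buffers in place, which is what the TODO above says breaks export; Tensor.index_copy (no trailing underscore) is the out-of-place variant, returning a fresh tensor with the new keys/values written at positions start_pos .. start_pos + seqlen along the sequence dimension while leaving the buffer itself untouched. A small sketch of those semantics, using toy shapes rather than the model's real dimensions:

import torch

# Stand-ins for a (max_batch_size, max_seq_len, n_kv_heads, head_dim) cache and new keys.
cache = torch.zeros(2, 8, 1, 4)
xk = torch.ones(2, 3, 1, 4)
start_pos, seqlen = 2, 3

idx = torch.arange(start_pos, start_pos + seqlen)
updated = cache.index_copy(1, idx, xk)      # out-of-place: returns a new tensor

print(cache[:, 2:5].abs().sum().item())     # 0.0  -- original buffer unchanged
print(updated[:, 2:5].abs().sum().item())   # 24.0 -- the copy holds the new keys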
@@ -246,17 +251,17 @@ def __init__(self, params: ModelArgs):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
 
-        self.freqs_cis = precompute_freqs_cis(
+        freqs_cis = precompute_freqs_cis(
             params.dim // params.n_heads,
             params.max_seq_len * 2,
             params.rope_theta,
         )
+        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
 
     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
 
         mask = None
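With freqs_cis registered as a non-persistent buffer in __init__, it is already placed on the right device by Module.to(...) / .cuda(), so the per-call self.freqs_cis = self.freqs_cis.to(h.device) reassignment in forward can be dropped and a plain slice suffices. A tiny sketch of that pattern, with a hypothetical stand-in for precompute_freqs_cis:

import torch
import torch.nn as nn

class RopeDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for precompute_freqs_cis(...): complex rotation factors.
        freqs_cis = torch.polar(torch.ones(16, 4), torch.zeros(16, 4))
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, start_pos: int, seqlen: int):
        # The buffer already lives on the module's device; just slice it.
        return self.freqs_cis[start_pos : start_pos + seqlen]

m = RopeDemo()
print(m(2, 3).shape)  # torch.Size([3, 4])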