Thank you for your excellent work and for open-sourcing this project.
I have encountered some confusion while trying to reproduce the results:
In `pyramidkv_utils.py`, the following snippet is used in multiple methods to select indices and gather the corresponding key and value states:
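```python
# Keep the (max_capacity_prompt - window_size) highest-scoring positions per head.
indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices
# Broadcast the indices over head_dim so gather selects whole key/value vectors.
indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)

if self.merge is not None:
    key_states, value_states = merge_kv(key_states, value_states, indices, self.window_size, self.merge)
    return key_states, value_states

# Gather the selected entries from the prefix (everything before the observation window).
k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
```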
Here, `indices` comes from `topk` over the scores in `attn_cache`, so it is ordered by score (descending, since `topk` returns sorted results by default) rather than by original token position. My concern is whether gathering `k_past_compress` and `v_past_compress` with these indices leaves the compressed KV cache in the wrong order. Could this cause issues in subsequent inference?
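To make the concern concrete, here is a small sketch. It uses NumPy as a stand-in for the PyTorch calls (`np.argsort` on negated scores in place of `topk`, plain indexing in place of `gather` for a single head). The helper names `topk_indices` and `attention` are hypothetical. The sketch shows two things. First, the selected indices come out in score order, and `np.sort` restores positional order. Second, for plain softmax attention, gathering K and V with the same permuted indices gives the same output as gathering them in positional order. This assumes positional information (e.g. RoPE) has already been applied to the cached keys. That is an assumption about the surrounding model code, not something the snippet shows. The sketch does not cover code that infers positions from cache slot order.

```python
import numpy as np

def topk_indices(scores, k):
    # Hypothetical stand-in for torch.topk(scores, k, dim=-1).indices (sorted=True):
    # indices of the k largest scores, ordered by score, descending.
    return np.argsort(-scores, axis=-1, kind="stable")[..., :k]

def attention(q, k, v):
    # Hypothetical single-query softmax attention: q (d,), k (n, d), v (n, d).
    # Assumes positions are already encoded in k (e.g. RoPE applied before caching).
    logits = k @ q / np.sqrt(q.shape[-1])
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return weights @ v

rng = np.random.default_rng(0)
n, d, budget = 8, 4, 5
keys = rng.standard_normal((n, d))
values = rng.standard_normal((n, d))
query = rng.standard_normal(d)
scores = rng.standard_normal(n)  # stand-in for one head's attn_cache scores

idx = topk_indices(scores, budget)  # score order, not position order
idx_sorted = np.sort(idx)           # positional order restored

# Equivalent of gather(dim=2, index=indices) for a single head, using the
# same indices for keys and values.
k_score, v_score = keys[idx], values[idx]
k_pos, v_pos = keys[idx_sorted], values[idx_sorted]

out_score = attention(query, k_score, v_score)
out_pos = attention(query, k_pos, v_pos)
print("score-ordered indices:", idx)
print("position-ordered indices:", idx_sorted)
print("same attention output:", np.allclose(out_score, out_pos))
```

Under these assumptions, only the storage order differs, because the softmax-weighted sum over key/value pairs does not depend on the order of the pairs. Whether anything else in the pipeline relies on cache order is the part that would need checking in the actual code.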
I would appreciate any insights or clarifications on this matter. Thank you!