fix: refactor kvcache structure and support kernel_block_size#724

Open
SJTUGavinLiu wants to merge 6 commits into main from develop/chanyin/hybrid_kvcache_fix

Conversation

Collaborator

@SJTUGavinLiu SJTUGavinLiu commented Feb 27, 2026

Background & Motivation

This PR systematically refactors the KVCache data structure and related attention operator interfaces to address two core issues:

  1. Per-layer view inconsistency in hybrid cache scenarios: The original KVCache struct stored all layers' KV caches as a monolithic tensor. In MHA/MLA hybrid layouts, each layer may have a different shape and cannot be directly indexed by layer id, forcing callers to handle this with ad-hoc logic.
  2. Coupling between KV management granularity and attention kernel granularity: Attention operators (FlashInfer, TRTLLM, etc.) have preferences on block size, but the KV manager's physical block size could not be adjusted independently, limiting optimization opportunities on the kernel side.

Key Changes

1. KVCache Data Structure Refactoring

  • Introduced a standalone LayerKVCache struct as a per-layer KV cache view, holding kv_cache_base, kv_scale_base, seq_size_per_block, and layer_id for a single layer.
  • Removed the legacy monolithic kv_cache_base / kv_scale_base tensor fields from KVCache, replacing them with per-layer kv_cache_base_by_layer / kv_scale_base_by_layer vectors. Added metadata fields: num_kv_heads, head_dim, use_mla, kv_lora_rank, rope_head_dim.
  • getLayerCache(idx) now returns a LayerKVCache, automatically reshaping raw 2D buffers in hybrid cache mode: MHA layers → [block_num, 2, kv_heads, seq_size, head_dim], MLA layers → [block_num, seq_size, lora_rank + rope_dim].
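
The per-layer reshape rule above can be sketched as follows. This is a minimal, self-contained illustration of the shape logic only; the enum values, struct layout, and function name mirror the PR description but the exact signatures in the codebase are assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the real rtp_llm types; only the shape
// arithmetic from the PR description is modeled here.
enum class CacheGroupType { FULL, MLA };

// Computes the per-layer view shape that getLayerCache(idx) would apply
// to the raw 2D buffer in hybrid cache mode:
//   MHA layers -> [block_num, 2, kv_heads, seq_size, head_dim]
//   MLA layers -> [block_num, seq_size, lora_rank + rope_dim]
std::vector<int64_t> layerViewShape(CacheGroupType type,
                                    int64_t      block_num,
                                    int64_t      num_kv_heads,
                                    int64_t      seq_size_per_block,
                                    int64_t      head_dim,
                                    int64_t      kv_lora_rank,
                                    int64_t      rope_head_dim) {
    if (type == CacheGroupType::MLA) {
        return {block_num, seq_size_per_block, kv_lora_rank + rope_head_dim};
    }
    // FULL (MHA) layout: separate K and V planes, hence the literal 2.
    return {block_num, 2, num_kv_heads, seq_size_per_block, head_dim};
}
```

With `kv_lora_rank = 512` and `rope_head_dim = 64`, an MLA layer's last dimension is 576, while an MHA layer with the same buffer gets the five-dimensional K/V view instead.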

2. Introduce kernel_block_size to Decouple Logical and Kernel Blocks

  • Added kernel_seq_size_per_block to CacheConfig, configurable via --kernel_seq_size_per_block launch argument or KERNEL_SEQ_SIZE_PER_BLOCK environment variable. When smaller than seq_size_per_block, each physical KV block is split into multiple kernel blocks at the operator level.
  • Refactored BlockIds with a dual-index mechanism: block_indices (physical blocks) and kernel_block_indices_ (kernel blocks) are always kept in sync. A kernelBlocks() accessor is exposed, and all mutation operations (add(), remove(), swap(), setAt(), resize(), popBack()) automatically update both index arrays, replacing direct raw vector manipulation.
  • KVCacheResource and BatchKVCacheResource interfaces extended to propagate blocks_per_kv_block and group_types.
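
The dual-index mechanism can be illustrated with a minimal sketch. Only two mutators are shown, and the index math (kernel id = physical id × ratio + offset) is an assumption about the layout; the real BlockIds keeps both arrays in sync inside every mutation operation listed above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified sketch of the dual-index BlockIds described above.
// blocks_per_kv_block = seq_size_per_block / kernel_seq_size_per_block.
class BlockIds {
public:
    explicit BlockIds(size_t blocks_per_kv_block)
        : blocks_per_kv_block_(blocks_per_kv_block) {}

    // Adding one physical block appends the corresponding run of
    // kernel blocks, keeping both index arrays in sync.
    void add(int physical_block) {
        block_indices_.push_back(physical_block);
        const int ratio = static_cast<int>(blocks_per_kv_block_);
        for (int i = 0; i < ratio; ++i) {
            kernel_block_indices_.push_back(physical_block * ratio + i);
        }
    }

    // Removing one physical block drops its whole run of kernel blocks.
    void popBack() {
        block_indices_.pop_back();
        kernel_block_indices_.resize(kernel_block_indices_.size() - blocks_per_kv_block_);
    }

    const std::vector<int>& blocks() const { return block_indices_; }
    const std::vector<int>& kernelBlocks() const { return kernel_block_indices_; }

private:
    size_t           blocks_per_kv_block_;
    std::vector<int> block_indices_;         // physical KV blocks
    std::vector<int> kernel_block_indices_;  // kernel-granularity blocks
};
```

For example, with `seq_size_per_block = 64` and `kernel_seq_size_per_block = 16` (ratio 4), adding physical block 3 yields kernel blocks 12 through 15, which is what `kernelBlocks()` would hand to the attention operator.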

@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch 8 times, most recently from 2c961cf to 787e60d Compare March 4, 2026 06:44
@SJTUGavinLiu SJTUGavinLiu requested a review from xinfei-shi March 6, 2026 08:42
@SJTUGavinLiu SJTUGavinLiu changed the title from "fix: refactor KVCache and add LayerKVCache" to "fix: refactor kvcache structure and support kernel_block_size" Mar 6, 2026
@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch 4 times, most recently from 39a24cd to 61d985d Compare March 8, 2026 09:41
const size_t old_size = block_indices.size();
block_indices.resize(new_size, value);
if (!is_full_) {
    kernel_block_indices_.resize(new_size, value);
}
Collaborator

For non-FULL layers, if blocks_per_kv_block_ is simply fixed to 1, wouldn't that remove the need for branching in every code path?

Collaborator Author

Yes, that works. I'll remove the is_full flag then.

const auto type = (static_cast<size_t>(gid) < layout.group_types.size())
                      ? layout.group_types[static_cast<size_t>(gid)]
                      : rtp_llm::CacheGroupType::FULL;
kv_cache.layer_attn_types.push_back(type);
Collaborator

Could the kv cache manager's get layout interface provide the layer_attn_types field directly?

Collaborator Author

SJTUGavinLiu commented Mar 9, 2026

I'll handle it in CacheConfig the same way as group_types, and add this field to CacheLayout as well.

f"kv_cache_base={self.kv_cache.kv_cache_base.shape if self.kv_cache.kv_cache_base is not None else None}, "
f"kv_scale_base={self.kv_cache.kv_scale_base.shape if self.kv_cache.kv_scale_base is not None else None}, "
f"num_kv_layers={num_layers}, "
f"layer0_kv_cache_shape={layer0_shape}, "
Collaborator

This looks a bit odd: it specifically prints layer 0's shape. Maybe just drop it, since different layers can have different shapes anyway.

Collaborator Author

That was added during testing; I'll go through and remove all of these.

@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch 8 times, most recently from c1e2ebc to 8f99312 Compare March 16, 2026 09:00
@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch from 8f99312 to 056d1ef Compare March 16, 2026 14:38
@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch 3 times, most recently from 29bafc3 to 5c45cc0 Compare March 17, 2026 07:08
@SJTUGavinLiu SJTUGavinLiu force-pushed the develop/chanyin/hybrid_kvcache_fix branch from 5c45cc0 to 5460176 Compare March 17, 2026 08:29