Skip to content

FP8 kv cache quantization #4563

Draft

CUHKSZzxy wants to merge 8 commits into InternLM:main from CUHKSZzxy:feat/fp8-kv-cache-quant

Conversation

@CUHKSZzxy (Collaborator) commented Apr 29, 2026

Motivation

This PR adds FP8 KV-cache quantization support for the PyTorch backend, with the normal fp8 behavior aligned to vLLM-style scalar-scale FP8 KV cache.

The goal is to reduce KV-cache memory usage while avoiding the performance overhead of the dynamic per-token/head scale metadata path for the common non-calibrated FP8 case.

Modification

Quant policy split

  • fp8 / fp8_e4m3: normal scalar-scale FP8 E4M3 KV cache
  • fp8_e5m2: normal scalar-scale FP8 E5M2 KV cache
  • fp8_per_token_head / fp8_e4m3_per_token_head: dynamic per-token/head FP8 E4M3 KV cache
  • fp8_e5m2_per_token_head: dynamic per-token/head FP8 E5M2 KV cache
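
In short, the split is a mapping from policy alias to FP8 dtype and scale granularity. A rough sketch of that mapping (illustrative only; the actual QuantPolicy plumbing in this PR may use different names or enum values):

```python
import torch

# Illustrative mapping of the aliases listed above to (FP8 dtype, per-token/head scales?).
# The real policy enum/config wiring in this PR may differ.
FP8_KV_POLICIES = {
    'fp8':                     (torch.float8_e4m3fn, False),  # normal scalar-scale
    'fp8_e4m3':                (torch.float8_e4m3fn, False),
    'fp8_e5m2':                (torch.float8_e5m2,   False),
    'fp8_per_token_head':      (torch.float8_e4m3fn, True),   # dynamic per-token/head
    'fp8_e4m3_per_token_head': (torch.float8_e4m3fn, True),
    'fp8_e5m2_per_token_head': (torch.float8_e5m2,   True),
}
```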

Normal FP8 scalar-scale path

Normal FP8 now uses scalar K/V scales owned by each attention layer:

  • default non-calibrated scales are k_scale = v_scale = 1.0
  • no k_scales_zeros / v_scales_zeros metadata cache is allocated
  • cache fill quantizes with scalar scales
  • decode applies scalar dequant inside paged attention
  • prefill flatten applies scalar dequant when recovering flattened K/V

This avoids per-token/head scale writes, reads, and metadata traffic for normal fp8.
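
A minimal sketch of the scalar-scale round trip in plain PyTorch (the actual fill/decode paths in this PR are Triton kernels; the helper names here are hypothetical):

```python
import torch

def quant_kv_scalar(kv: torch.Tensor, scale: float,
                    dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Cache fill: quantize a K or V tile with a single scalar scale.
    finfo = torch.finfo(dtype)
    return (kv.float() / scale).clamp(finfo.min, finfo.max).to(dtype)

def dequant_kv_scalar(kv_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Decode / prefill flatten: scalar dequant, no per-token metadata needed.
    return kv_fp8.to(torch.float32) * scale

k = torch.randn(4, 8, 128)                 # [tokens, heads, head_dim]
k_cache = quant_kv_scalar(k, scale=1.0)    # default non-calibrated scale
k_approx = dequant_kv_scalar(k_cache, 1.0)
```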

Per-token/head FP8 path

The previous dynamic FP8 implementation is kept under explicit per-token/head policies:

  • per-token/head scales are computed at cache fill time
  • scale metadata is stored as extra cache tensors
  • decode/flatten consume the per-token/head scale cache
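
Conceptually, this path computes a symmetric, zero-point-free scale per token and head at fill time and stores it next to the FP8 cache. A plain-PyTorch sketch (the real kernel operates on blocked cache layouts):

```python
import torch

def quant_kv_per_token_head(kv: torch.Tensor,
                            dtype: torch.dtype = torch.float8_e4m3fn):
    # kv: [tokens, heads, head_dim] -> (fp8 tile, scales of shape [tokens, heads, 1])
    finfo = torch.finfo(dtype)
    amax = kv.float().abs().amax(dim=-1, keepdim=True)   # one amax per token and head
    scales = (amax / finfo.max).clamp(min=1e-12)
    kv_fp8 = (kv.float() / scales).clamp(finfo.min, finfo.max).to(dtype)
    return kv_fp8, scales

# decode/flatten later dequantize with the stored scale cache:
# kv_fp32 = kv_fp8.to(torch.float32) * scales
```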

Runtime scale calculation

Added optional calculate_kv_scales support for normal FP8:

  • computes scalar K/V scales from the first normal-FP8 forward pass
  • freezes the scales after the first calculation
  • disables CUDA graph capture while scale calculation is pending, matching vLLM’s eager first-pass behavior
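
A rough sketch of that one-shot behavior (attribute names are hypothetical; the PR wires this through the attention layers and the cache/engine config):

```python
import torch

class ScalarKVScales:
    """Per-layer scalar K/V scales, computed once on the first normal-FP8 forward."""

    def __init__(self, calculate_kv_scales: bool = False):
        self.k_scale = 1.0                   # default non-calibrated scales
        self.v_scale = 1.0
        self.pending = calculate_kv_scales   # CUDA graph capture stays disabled while True

    def maybe_calculate(self, key: torch.Tensor, value: torch.Tensor) -> None:
        if not self.pending:
            return
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        self.k_scale = (key.float().abs().amax() / fp8_max).item()
        self.v_scale = (value.float().abs().amax() / fp8_max).item()
        self.pending = False                 # freeze after the first calculation
```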

CLI/config updates

  • Added CLI aliases for normal and per-token/head FP8 policies.
  • Added --calculate-kv-scales for PyTorch backend.
  • Updated PyTorch cache config and engine config plumbing.
  • Normal FP8 returns no quant metadata descriptors; per-token/head FP8 keeps scale metadata descriptors.
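
Putting the flags together, serving with the new policies might look like this (the alias strings mirror the list above; treat this as a sketch of the intended usage, not the verified final CLI):

```
# normal scalar-scale FP8 KV cache with one-shot runtime scale calibration
lmdeploy serve api_server <model_path> --backend pytorch \
    --quant-policy fp8 --calculate-kv-scales

# previous dynamic behavior, now behind an explicit alias
lmdeploy serve api_server <model_path> --backend pytorch \
    --quant-policy fp8_per_token_head
```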

BC-breaking

No intended BC-breaking change for existing non-FP8 behavior.

For FP8 specifically, fp8 now denotes the normal scalar-scale path. The previous dynamic per-token/head behavior is available through fp8_per_token_head.

Test Plan

  • python -m py_compile on changed runtime, kernel, and test files
  • python -m pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py
  • git diff --check
  • Pre-commit hooks during commit

CUHKSZzxy and others added 8 commits April 23, 2026 14:57
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using
torch.float8_e4m3fn with per-token symmetric scale (no zero point).

Key design:
- Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim
  for per-token scale semantics in the fill path
- Dequant in flatten_kv_cache and paged_attention via x.to(f32)*scale
- Scale tensor shape [..., 1]: symmetric, no zero point
- No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT)

Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused
by calling .max() on float8_e4m3fn tensors (cast to float32 first).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
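
The .max() fix mentioned above amounts to upcasting before reducing, since the reduction fails on float8_e4m3fn tensors (illustrative):

```python
import torch

x_fp8 = torch.randn(8, 128).to(torch.float8_e4m3fn)

# x_fp8.max() fails for float8_e4m3fn tensors (the failure the commit describes);
# cast to float32 first, then reduce:
amax = x_fp8.to(torch.float32).abs().max()
```
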
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move default paged-decode FP8 dequant scaling across the attention dot products. K scales are applied after QK, and V scales are applied to probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
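
The fused placement works because the scales are per key/value token (and head), so they commute with the attention dot products. A quick numerical check of that algebra in plain PyTorch (single head shown; this is not the Triton kernel):

```python
import torch

seq, d = 16, 64
q = torch.randn(1, d)
k = torch.randn(seq, d).to(torch.float8_e4m3fn).float()   # dequantized FP8 K values
v = torch.randn(seq, d).to(torch.float8_e4m3fn).float()   # dequantized FP8 V values
k_scale = torch.rand(seq, 1) + 0.5                        # per-token K scales
v_scale = torch.rand(seq, 1) + 0.5                        # per-token V scales

# Reference: dequantize the full K/V tiles, then attend.
ref_logits = q @ (k * k_scale).T
ref_out = torch.softmax(ref_logits, dim=-1) @ (v * v_scale)

# Fused: K scales applied after QK, V scales applied to probabilities before PV.
logits = (q @ k.T) * k_scale.T
out = (torch.softmax(logits, dim=-1) * v_scale.T) @ v

assert torch.allclose(ref_logits, logits, atol=1e-5)
assert torch.allclose(ref_out, out, atol=1e-5)
```
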
Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage. Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.