FP8 KV cache quantization #4563
Draft
CUHKSZzxy wants to merge 8 commits into InternLM:main from
Conversation
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using torch.float8_e4m3fn with a per-token symmetric scale (no zero point). Key design:

- Reuses the existing fill_kv_cache_blocked_fp8() with group_size=head_dim for per-token scale semantics in the fill path
- Dequant in flatten_kv_cache and paged_attention via x.to(f32) * scale
- Scale tensor shape [..., 1]: symmetric, no zero point
- No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT)

Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused by calling .max() on float8_e4m3fn tensors (cast to float32 first).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
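The scale algebra described above can be sketched in plain Python. This is a simplified model, not the PR's kernel code: the real path stores torch.float8_e4m3fn values and this sketch skips FP8 mantissa rounding, keeping only the scale math; the helper names are illustrative. 448.0 is E4M3's largest finite value.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quant_per_token(tokens):
    """Symmetric per-token quantization, no zero point.

    One scale per token (group_size == head_dim): all head_dim values of a
    token share a scale, stored with a trailing singleton dim to match the
    [..., 1] scale tensor shape. FP8 mantissa rounding is omitted here.
    """
    q_tokens, scales = [], []
    for vec in tokens:
        amax = max(abs(v) for v in vec)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        # Scaled values fit the FP8 dynamic range; clamp guards fp rounding.
        q_tokens.append([max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale))
                         for v in vec])
        scales.append([scale])  # shape [..., 1]: symmetric, no zero point
    return q_tokens, scales

def dequant(q_tokens, scales):
    """Dequant as in flatten_kv_cache / paged_attention: x.to(f32) * scale."""
    return [[qv * s[0] for qv in vec] for vec, s in zip(q_tokens, scales)]
```

Because head_dim is unchanged and no bits are packed, the cache layout matches the unquantized path apart from the dtype and the extra scale tensor.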
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move the default paged-decode FP8 dequant scaling across the attention dot products: K scales are applied after QK, and V scales are applied to the probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
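The fused scale placement works because the scales factor out of the dot products by associativity. A small pure-Python check of that equivalence (function names are illustrative, not the kernel's API):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

def attn_dequant_first(q, kq, vq, ks, vs):
    """Reference path: fully dequantize each K/V tile, then attend."""
    scores = [sum(qi * (kij * ks[j]) for qi, kij in zip(q, kq[j]))
              for j in range(len(kq))]
    p = softmax(scores)
    return [sum(p[j] * (vq[j][d] * vs[j]) for j in range(len(vq)))
            for d in range(len(vq[0]))]

def attn_fused_scales(q, kq, vq, ks, vs):
    """Fused placement: K scales after QK, V scales folded into the probs."""
    scores = [ks[j] * sum(qi * kij for qi, kij in zip(q, kq[j]))
              for j in range(len(kq))]
    p = [pj * vs[j] for j, pj in enumerate(softmax(scores))]
    return [sum(p[j] * vq[j][d] for j in range(len(vq)))
            for d in range(len(vq[0]))]
```

Both paths compute the same output up to floating-point rounding, so the hot decode loop never has to materialize dequantized K/V tiles.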
Split the normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while the fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage.

Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.
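The scalar-scale lifecycle with one-shot scale calculation can be sketched as a small state machine. All names here (ScalarFP8KVCache, capture_allowed, fill) are hypothetical illustrations of the behavior described above, not the PR's actual API, and FP8 rounding is again omitted:

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

class ScalarFP8KVCache:
    """Sketch of the normal (scalar-scale) FP8 KV-cache path.

    k_scale/v_scale default to 1.0; with calculate_kv_scales enabled they are
    computed once from the first batch (eager pass) and then frozen, so CUDA
    graph capture can proceed afterwards.
    """

    def __init__(self, calculate_kv_scales=False):
        self.k_scale = 1.0
        self.v_scale = 1.0
        self._needs_scales = calculate_kv_scales

    @property
    def capture_allowed(self):
        # Guard CUDA graph capture while scale calculation is pending.
        return not self._needs_scales

    def fill(self, k, v):
        if self._needs_scales:
            # One-shot: derive scalar scales from the first observed batch.
            self.k_scale = max(abs(x) for x in k) / FP8_E4M3_MAX
            self.v_scale = max(abs(x) for x in v) / FP8_E4M3_MAX
            self._needs_scales = False
        # Quantize on cache write; decode/prefill later multiply the
        # scalar scale back in, with no per-token metadata tensors.
        kq = [x / self.k_scale for x in k]
        vq = [x / self.v_scale for x in v]
        return kq, vq
```

With calculate_kv_scales off, the scales stay at 1.0 and capture is never blocked, which is the default non-calibrated path.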
Motivation
This PR adds FP8 KV-cache quantization support for the PyTorch backend, with the normal fp8 behavior aligned to vLLM-style scalar-scale FP8 KV cache. The goal is to reduce KV-cache memory usage while avoiding the performance overhead of the dynamic per-token/head scale metadata path for the common non-calibrated FP8 case.
Modification
Quant policy split
- fp8/fp8_e4m3: normal scalar-scale FP8 E4M3 KV cache
- fp8_e5m2: normal scalar-scale FP8 E5M2 KV cache
- fp8_per_token_head/fp8_e4m3_per_token_head: dynamic per-token/head FP8 E4M3 KV cache
- fp8_e5m2_per_token_head: dynamic per-token/head FP8 E5M2 KV cache

Normal FP8 scalar-scale path
Normal FP8 now uses scalar K/V scales owned by each attention layer:
- k_scale = v_scale = 1.0 by default
- No k_scales_zeros/v_scales_zeros metadata cache is allocated

This avoids per-token/head scale writes, reads, and metadata traffic for normal fp8.

Per-token/head FP8 path
The previous dynamic FP8 implementation is kept under the explicit *_per_token_head policies listed above.
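As a rough illustration of the split, the two families differ in where scales live and what shape they have. The helper name kv_scale_shape is hypothetical, used only to make the contrast concrete:

```python
def kv_scale_shape(policy, num_tokens, num_heads):
    """Illustrative only: scale layout implied by each FP8 policy family.

    Scalar-scale FP8 keeps one k_scale/v_scale per attention layer (no
    metadata cache), while the per-token/head variants keep a scale cache
    with a trailing singleton dim per (token, head).
    """
    if policy in ("fp8", "fp8_e4m3", "fp8_e5m2"):
        return ()  # scalar scales owned by the layer, no metadata tensors
    if policy.endswith("_per_token_head"):
        return (num_tokens, num_heads, 1)  # [..., 1]: symmetric, no zero point
    raise ValueError(f"not an FP8 KV-cache policy: {policy}")
```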
Runtime scale calculation
Added optional calculate_kv_scales support for normal FP8: when enabled, K/V scales are computed once on an eager first pass, and CUDA graph capture is guarded while the calculation is pending (mirroring vLLM's behavior).

CLI/config updates
- --calculate-kv-scales for the PyTorch backend

BC-breaking
No intended BC-breaking change for existing non-FP8 behavior.
For FP8 specifically, fp8 now denotes the normal scalar-scale path; the previous dynamic per-token/head behavior is available through fp8_per_token_head.

Test Plan
- python -m py_compile on changed runtime, kernel, and test files
- python -m pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py
- git diff --check