57 changes: 38 additions & 19 deletions fastdeploy/cache_manager/v1/cache_controller.py
@@ -282,31 +282,41 @@ def initialize_kv_cache(
logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
cache_kvs_list = []

# Quantized KV cache (int8/fp8/etc.) uses uint8 storage (1 byte per element).
# Non-quantized cache uses the model's compute dtype (e.g., bfloat16).
cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype

for i in range(self._num_layers):
# Generate cache names
cache_names = self._get_cache_names(i)

logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")

# Create key cache and value cache
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
# Create key cache
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])
if value_cache_shape:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])
else:
cache_kvs_list.extend([key_cache])

# Create scale caches for block_wise_fp8 quantization
if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
cache_kvs_list.extend([key_cache_scales, val_cache_scales])
if value_cache_shape:
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
cache_kvs_list.extend([key_cache_scales, val_cache_scales])
else:
cache_kvs_list.extend([key_cache_scales])

paddle.device.cuda.empty_cache()
logger.info("kv cache is initialized!")
@@ -360,26 +370,35 @@ def initialize_mtp_kv_cache(
)
cache_kvs_list = []

# Quantized KV cache uses uint8 storage; non-quantized uses model compute dtype.
cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype

for i in range(layer_offset, layer_offset + num_mtp_layers):
cache_names = self._get_cache_names(i)

key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])
if value_cache_shape:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])
else:
cache_kvs_list.extend([key_cache])

if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
cache_kvs_list.extend([key_cache_scales, val_cache_scales])
if value_cache_shape:
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
cache_kvs_list.extend([key_cache_scales, val_cache_scales])
else:
cache_kvs_list.extend([key_cache_scales])

paddle.device.cuda.empty_cache()
logger.info("[CacheController] MTP kv cache initialized!")
148 changes: 148 additions & 0 deletions tests/cache_manager/v1/test_cache_controller.py
@@ -723,5 +723,153 @@ def test_free_gpu_cache_noop_when_empty(self):
self.assertEqual(len(self.controller.cache_kvs_map), 0)


# ============================================================================
# initialize_kv_cache / initialize_mtp_kv_cache dtype Tests (PR #7757)
# ============================================================================


def make_mock_attn_backend(key_shape=(10, 4, 16, 64), val_shape=None, val_shape_is_none=False):
"""Create a mock attn_backend with a fixed get_kv_cache_shape."""
if val_shape_is_none:
# Simulate MLA variants (e.g., DeepSeek) that return None for value_cache_shape
backend = MagicMock()
backend.get_kv_cache_shape.return_value = (list(key_shape), None)
return backend
if val_shape is None:
val_shape = key_shape
backend = MagicMock()
backend.get_kv_cache_shape.return_value = (list(key_shape), list(val_shape))
return backend


class TestInitializeKVCacheDtype(unittest.TestCase):
"""
Tests for the cache_dtype logic introduced in PR #7757:
cache_dtype = "uint8" if kv_cache_quant_type is not None else model_config.dtype
"""

def _make_controller(self, model_dtype="bfloat16", num_layers=2):
config = get_default_test_fd_config()
config.cache_config.num_cpu_blocks = 0 # skip host cache init
config.model_config.num_hidden_layers = num_layers
config.model_config.dtype = model_dtype
from fastdeploy.cache_manager.v1.cache_controller import CacheController

return CacheController(config, local_rank=0, device_id=0)

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type):
"""When kv_cache_quant_type is None, cache tensors use model_config.dtype."""
mock_quant_type.return_value = None
controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
backend = make_mock_attn_backend()

cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)

self.assertEqual(len(cache_list), 4) # 2 layers * (key + value)
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.bfloat16")

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_kv_cache_quantized_uses_uint8(self, mock_quant_type):
"""When kv_cache_quant_type is set, cache tensors use uint8 regardless of model dtype."""
mock_quant_type.return_value = "int8"
controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
backend = make_mock_attn_backend()

cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)

self.assertEqual(len(cache_list), 4)
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.uint8")

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_kv_cache_fp8_quantized_uses_uint8(self, mock_quant_type):
"""When kv_cache_quant_type is block_wise_fp8, non-scale cache tensors use uint8."""
mock_quant_type.return_value = "block_wise_fp8"
controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
backend = make_mock_attn_backend()

cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)

# fp8 path also creates scale tensors (float32); filter to only key/value caches
kv_tensors = [t for t in cache_list if str(t.dtype) == "paddle.uint8"]
self.assertEqual(len(kv_tensors), 4) # 2 layers * (key + value)

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_mtp_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type):
"""When kv_cache_quant_type is None, MTP cache tensors use model_config.dtype."""
mock_quant_type.return_value = None
controller = self._make_controller(model_dtype="float16", num_layers=4)
backend = make_mock_attn_backend()

cache_list = controller.initialize_mtp_kv_cache(
attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
)

self.assertEqual(len(cache_list), 4) # 2 mtp layers * (key + value)
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.float16")

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_mtp_kv_cache_quantized_uses_uint8(self, mock_quant_type):
"""When kv_cache_quant_type is set, MTP cache tensors use uint8."""
mock_quant_type.return_value = "int8"
controller = self._make_controller(model_dtype="bfloat16", num_layers=4)
backend = make_mock_attn_backend()

cache_list = controller.initialize_mtp_kv_cache(
attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
)

self.assertEqual(len(cache_list), 4)
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.uint8")

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_kv_cache_populates_cache_kvs_map(self, mock_quant_type):
"""Tensors created in initialize_kv_cache are stored in cache_kvs_map with correct dtype."""
mock_quant_type.return_value = "int8"
controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
backend = make_mock_attn_backend()

controller.initialize_kv_cache(backend, num_gpu_blocks=10)

for name, tensor in controller.cache_kvs_map.items():
if "scale" not in name:
self.assertEqual(str(tensor.dtype), "paddle.uint8", f"wrong dtype for {name}")

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_kv_cache_null_value_cache_shape(self, mock_quant_type):
"""MLA variant: when value_cache_shape is None, only key cache is created."""
mock_quant_type.return_value = None
controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
backend = make_mock_attn_backend(val_shape_is_none=True)

cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)

self.assertEqual(len(cache_list), 2) # 2 layers * key only
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.bfloat16")
# Verify no value entries in cache_kvs_map
for name in controller.cache_kvs_map:
self.assertNotIn("value", name)

@patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
def test_initialize_mtp_kv_cache_null_value_cache_shape(self, mock_quant_type):
"""MLA variant: when value_cache_shape is None, only key cache is created for MTP."""
mock_quant_type.return_value = None
controller = self._make_controller(model_dtype="bfloat16", num_layers=4)
backend = make_mock_attn_backend(val_shape_is_none=True)

cache_list = controller.initialize_mtp_kv_cache(
attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
)

self.assertEqual(len(cache_list), 2) # 2 mtp layers * key only
for tensor in cache_list:
self.assertEqual(str(tensor.dtype), "paddle.bfloat16")


if __name__ == "__main__":
unittest.main()
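For reference, the two modes of make_mock_attn_backend above map directly onto the two allocation paths under test. A quick illustration, assuming the definitions from the test file are in scope (not part of the suite itself):

# The default mode mirrors a standard backend: value shape equals key shape.
standard = make_mock_attn_backend()
assert standard.get_kv_cache_shape() == ([10, 4, 16, 64], [10, 4, 16, 64])

# val_shape_is_none=True mirrors an MLA-style backend: no value shape,
# which drives the key-cache-only allocation path in the controller.
mla_like = make_mock_attn_backend(val_shape_is_none=True)
key_shape, value_shape = mla_like.get_kv_cache_shape()
assert value_shape is None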