diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py
index 53b7292179f..cb55edd1c4c 100644
--- a/fastdeploy/cache_manager/v1/cache_controller.py
+++ b/fastdeploy/cache_manager/v1/cache_controller.py
@@ -282,31 +282,41 @@ def initialize_kv_cache(
         logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
 
         cache_kvs_list = []
+        # Quantized KV cache (int8/fp8/etc.) uses uint8 storage (1 byte per element).
+        # Non-quantized cache uses the model's compute dtype (e.g., bfloat16).
+        cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype
+
         for i in range(self._num_layers):
             # Generate cache names
             cache_names = self._get_cache_names(i)
             logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
 
-            # Create key cache and value cache
-            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
+            # Create key cache
+            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
             self.cache_kvs_map[cache_names["key"]] = key_cache
-            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
-            self.cache_kvs_map[cache_names["value"]] = val_cache
-            cache_kvs_list.extend([key_cache, val_cache])
+            if value_cache_shape:
+                val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
+                self.cache_kvs_map[cache_names["value"]] = val_cache
+                cache_kvs_list.extend([key_cache, val_cache])
+            else:
+                cache_kvs_list.extend([key_cache])
 
             # Create scale caches for block_wise_fp8 quantization
             if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
                 key_cache_scales = paddle.full(
                     shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                 )
-                val_cache_scales = paddle.full(
-                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                )
                 self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
-                self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
-                cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                if value_cache_shape:
+                    val_cache_scales = paddle.full(
+                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                    )
+                    self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
+                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                else:
+                    cache_kvs_list.extend([key_cache_scales])
 
         paddle.device.cuda.empty_cache()
         logger.info("kv cache is initialized!")
@@ -360,26 +370,35 @@ def initialize_mtp_kv_cache(
         )
 
         cache_kvs_list = []
+        # Quantized KV cache uses uint8 storage; non-quantized uses model compute dtype.
+ cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype + for i in range(layer_offset, layer_offset + num_mtp_layers): cache_names = self._get_cache_names(i) - key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) + else: + cache_kvs_list.extend([key_cache]) if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) - val_cache_scales = paddle.full( - shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() - ) self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales - self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales - cache_kvs_list.extend([key_cache_scales, val_cache_scales]) + if value_cache_shape: + val_cache_scales = paddle.full( + shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + ) + self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales + cache_kvs_list.extend([key_cache_scales, val_cache_scales]) + else: + cache_kvs_list.extend([key_cache_scales]) paddle.device.cuda.empty_cache() logger.info("[CacheController] MTP kv cache initialized!") diff --git a/tests/cache_manager/v1/test_cache_controller.py b/tests/cache_manager/v1/test_cache_controller.py index 858dbf69b56..a901ad57e9a 100644 --- a/tests/cache_manager/v1/test_cache_controller.py +++ b/tests/cache_manager/v1/test_cache_controller.py @@ -723,5 +723,153 @@ def test_free_gpu_cache_noop_when_empty(self): self.assertEqual(len(self.controller.cache_kvs_map), 0) +# ============================================================================ +# initialize_kv_cache / initialize_mtp_kv_cache dtype Tests (PR #7757) +# ============================================================================ + + +def make_mock_attn_backend(key_shape=(10, 4, 16, 64), val_shape=None, val_shape_is_none=False): + """Create a mock attn_backend with a fixed get_kv_cache_shape.""" + if val_shape_is_none: + # Simulate MLA variants (e.g., DeepSeek) that return None for value_cache_shape + backend = MagicMock() + backend.get_kv_cache_shape.return_value = (list(key_shape), None) + return backend + if val_shape is None: + val_shape = key_shape + backend = MagicMock() + backend.get_kv_cache_shape.return_value = (list(key_shape), list(val_shape)) + return backend + + +class TestInitializeKVCacheDtype(unittest.TestCase): + """ + Tests for the cache_dtype logic introduced in PR #7757: + cache_dtype = "uint8" if kv_cache_quant_type is not None else model_config.dtype + """ + + def _make_controller(self, model_dtype="bfloat16", num_layers=2): + config = get_default_test_fd_config() + config.cache_config.num_cpu_blocks = 0 # skip host cache init + config.model_config.num_hidden_layers = num_layers + config.model_config.dtype = model_dtype + from fastdeploy.cache_manager.v1.cache_controller import CacheController + + return CacheController(config, local_rank=0, 
device_id=0) + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type): + """When kv_cache_quant_type is None, cache tensors use model_config.dtype.""" + mock_quant_type.return_value = None + controller = self._make_controller(model_dtype="bfloat16", num_layers=2) + backend = make_mock_attn_backend() + + cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10) + + self.assertEqual(len(cache_list), 4) # 2 layers * (key + value) + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.bfloat16") + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_kv_cache_quantized_uses_uint8(self, mock_quant_type): + """When kv_cache_quant_type is set, cache tensors use uint8 regardless of model dtype.""" + mock_quant_type.return_value = "int8" + controller = self._make_controller(model_dtype="bfloat16", num_layers=2) + backend = make_mock_attn_backend() + + cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10) + + self.assertEqual(len(cache_list), 4) + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.uint8") + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_kv_cache_fp8_quantized_uses_uint8(self, mock_quant_type): + """When kv_cache_quant_type is block_wise_fp8, non-scale cache tensors use uint8.""" + mock_quant_type.return_value = "block_wise_fp8" + controller = self._make_controller(model_dtype="bfloat16", num_layers=2) + backend = make_mock_attn_backend() + + cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10) + + # fp8 path also creates scale tensors (float32); filter to only key/value caches + kv_tensors = [t for t in cache_list if str(t.dtype) == "paddle.uint8"] + self.assertEqual(len(kv_tensors), 4) # 2 layers * (key + value) + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_mtp_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type): + """When kv_cache_quant_type is None, MTP cache tensors use model_config.dtype.""" + mock_quant_type.return_value = None + controller = self._make_controller(model_dtype="float16", num_layers=4) + backend = make_mock_attn_backend() + + cache_list = controller.initialize_mtp_kv_cache( + attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4 + ) + + self.assertEqual(len(cache_list), 4) # 2 mtp layers * (key + value) + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.float16") + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_mtp_kv_cache_quantized_uses_uint8(self, mock_quant_type): + """When kv_cache_quant_type is set, MTP cache tensors use uint8.""" + mock_quant_type.return_value = "int8" + controller = self._make_controller(model_dtype="bfloat16", num_layers=4) + backend = make_mock_attn_backend() + + cache_list = controller.initialize_mtp_kv_cache( + attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4 + ) + + self.assertEqual(len(cache_list), 4) + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.uint8") + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_kv_cache_populates_cache_kvs_map(self, mock_quant_type): 
+ """Tensors created in initialize_kv_cache are stored in cache_kvs_map with correct dtype.""" + mock_quant_type.return_value = "int8" + controller = self._make_controller(model_dtype="bfloat16", num_layers=2) + backend = make_mock_attn_backend() + + controller.initialize_kv_cache(backend, num_gpu_blocks=10) + + for name, tensor in controller.cache_kvs_map.items(): + if "scale" not in name: + self.assertEqual(str(tensor.dtype), "paddle.uint8", f"wrong dtype for {name}") + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_kv_cache_null_value_cache_shape(self, mock_quant_type): + """MLA variant: when value_cache_shape is None, only key cache is created.""" + mock_quant_type.return_value = None + controller = self._make_controller(model_dtype="bfloat16", num_layers=2) + backend = make_mock_attn_backend(val_shape_is_none=True) + + cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10) + + self.assertEqual(len(cache_list), 2) # 2 layers * key only + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.bfloat16") + # Verify no value entries in cache_kvs_map + for name in controller.cache_kvs_map: + self.assertNotIn("value", name) + + @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type") + def test_initialize_mtp_kv_cache_null_value_cache_shape(self, mock_quant_type): + """MLA variant: when value_cache_shape is None, only key cache is created for MTP.""" + mock_quant_type.return_value = None + controller = self._make_controller(model_dtype="bfloat16", num_layers=4) + backend = make_mock_attn_backend(val_shape_is_none=True) + + cache_list = controller.initialize_mtp_kv_cache( + attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4 + ) + + self.assertEqual(len(cache_list), 2) # 2 mtp layers * key only + for tensor in cache_list: + self.assertEqual(str(tensor.dtype), "paddle.bfloat16") + + if __name__ == "__main__": unittest.main()