From 189c2eb9a871bcf42952c09586ca885e7f28c917 Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Sat, 9 May 2026 13:58:30 +0800
Subject: [PATCH 1/4] [KVCache][BugFix] fix cache_manager_v1 allocating kv
 cache with wrong dtype when kv_cache_quant_type is set

When enable_cache_manager_v1=True and kv_cache_quant_type is configured
(e.g., int8), cache_controller.v1 allocated KV cache tensors using the
model compute dtype (bfloat16) instead of uint8. This caused a C++ dtype
mismatch crash in append_attention_gpu, because the attention kernel
accesses int8/fp8 quantized caches as uint8_t* internally.

Fix: use "uint8" as the cache allocation dtype whenever
kv_cache_quant_type is not None, consistent with how gpu_model_runner
handles this in the non-v1 code path.

Affected: initialize_kv_cache() and initialize_mtp_kv_cache() in
CacheController.
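A minimal sketch of the selection rule (the standalone function and its
name are illustrative, not the controller's actual API):

    # int8/fp8 quantized values occupy one byte each, and the attention
    # kernel reinterprets the cache buffer as uint8_t*, so quantized
    # caches must be allocated as raw uint8 storage.
    def select_cache_dtype(kv_cache_quant_type, model_dtype):
        return "uint8" if kv_cache_quant_type is not None else model_dtype

    assert select_cache_dtype("int8", "bfloat16") == "uint8"
    assert select_cache_dtype(None, "bfloat16") == "bfloat16"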
---
 fastdeploy/cache_manager/v1/cache_controller.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py
index 53b7292179f..feae55ed0c5 100644
--- a/fastdeploy/cache_manager/v1/cache_controller.py
+++ b/fastdeploy/cache_manager/v1/cache_controller.py
@@ -282,6 +282,10 @@ def initialize_kv_cache(
         logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
         cache_kvs_list = []

+        # Quantized KV cache (int8/fp8/etc.) uses uint8 storage (1 byte per element).
+        # Non-quantized cache uses the model's compute dtype (e.g., bfloat16).
+        cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype
+
         for i in range(self._num_layers):
             # Generate cache names
             cache_names = self._get_cache_names(i)
@@ -289,10 +293,10 @@ def initialize_kv_cache(
             logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")

             # Create key cache and value cache
-            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
+            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
             self.cache_kvs_map[cache_names["key"]] = key_cache

-            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
+            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
             self.cache_kvs_map[cache_names["value"]] = val_cache
             cache_kvs_list.extend([key_cache, val_cache])

@@ -360,13 +364,16 @@ def initialize_mtp_kv_cache(
         )
         cache_kvs_list = []

+        # Quantized KV cache uses uint8 storage; non-quantized uses model compute dtype.
+        cache_dtype = "uint8" if kv_cache_quant_type is not None else self.model_config.dtype
+
         for i in range(layer_offset, layer_offset + num_mtp_layers):
             cache_names = self._get_cache_names(i)

-            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
+            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
             self.cache_kvs_map[cache_names["key"]] = key_cache

-            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
+            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
             self.cache_kvs_map[cache_names["value"]] = val_cache
             cache_kvs_list.extend([key_cache, val_cache])

From afa6623b6a39b7e1cf34eed2ddbece540ee446ea Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Mon, 11 May 2026 11:25:01 +0800
Subject: [PATCH 2/4] [KVCache][Test] add unit tests for kv cache quantization
 dtype selection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

PR #7757 changed initialize_kv_cache and initialize_mtp_kv_cache so
that quantized scenarios (kv_cache_quant_type is not None) use uint8
storage while non-quantized scenarios use model_config.dtype. This adds
the corresponding unit tests.

## Modifications

Add a TestInitializeKVCacheDtype test class (6 cases):
- without quantization, initialize_kv_cache uses model_config.dtype
  (bfloat16/float16)
- with int8 quantization, initialize_kv_cache uses uint8
- with block_wise_fp8 quantization, the key/value tensors created by
  initialize_kv_cache use uint8
- without quantization, initialize_mtp_kv_cache uses model_config.dtype
- with int8 quantization, initialize_mtp_kv_cache uses uint8
- with quantization, the tensors stored in cache_kvs_map are also uint8

All cases assert on str(tensor.dtype); see the sketch below for the
Paddle behavior they rely on.
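As a quick sanity check of that assertion style, Paddle stringifies
tensor dtypes as "paddle.<name>" (a sketch assuming a recent Paddle
build, the same behavior the tests below depend on):

    import paddle

    t = paddle.full(shape=[2, 2], fill_value=0, dtype="uint8")
    assert str(t.dtype) == "paddle.uint8"

    t = paddle.full(shape=[2, 2], fill_value=0, dtype="bfloat16")
    assert str(t.dtype) == "paddle.bfloat16"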
---
 tests/cache_manager/v1/test_cache_controller.py | 112 ++++++++++++++++++
 1 file changed, 112 insertions(+)

diff --git a/tests/cache_manager/v1/test_cache_controller.py b/tests/cache_manager/v1/test_cache_controller.py
index 858dbf69b56..b9905436137 100644
--- a/tests/cache_manager/v1/test_cache_controller.py
+++ b/tests/cache_manager/v1/test_cache_controller.py
@@ -723,5 +723,117 @@ def test_free_gpu_cache_noop_when_empty(self):
         self.assertEqual(len(self.controller.cache_kvs_map), 0)


+# ============================================================================
+# initialize_kv_cache / initialize_mtp_kv_cache dtype Tests (PR #7757)
+# ============================================================================
+
+
+def make_mock_attn_backend(key_shape=(10, 4, 16, 64), val_shape=None):
+    """Create a mock attn_backend with a fixed get_kv_cache_shape."""
+    if val_shape is None:
+        val_shape = key_shape
+    backend = MagicMock()
+    backend.get_kv_cache_shape.return_value = (list(key_shape), list(val_shape))
+    return backend
+
+
+class TestInitializeKVCacheDtype(unittest.TestCase):
+    """
+    Tests for the cache_dtype logic introduced in PR #7757:
+        cache_dtype = "uint8" if kv_cache_quant_type is not None else model_config.dtype
+    """
+
+    def _make_controller(self, model_dtype="bfloat16", num_layers=2):
+        config = get_default_test_fd_config()
+        config.cache_config.num_cpu_blocks = 0  # skip host cache init
+        config.model_config.num_hidden_layers = num_layers
+        config.model_config.dtype = model_dtype
+        from fastdeploy.cache_manager.v1.cache_controller import CacheController
+
+        return CacheController(config, local_rank=0, device_id=0)
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type):
+        """When kv_cache_quant_type is None, cache tensors use model_config.dtype."""
+        mock_quant_type.return_value = None
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
+        backend = make_mock_attn_backend()
+
+        cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)
+
+        self.assertEqual(len(cache_list), 4)  # 2 layers * (key + value)
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.bfloat16")
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_kv_cache_quantized_uses_uint8(self, mock_quant_type):
+        """When kv_cache_quant_type is set, cache tensors use uint8 regardless of model dtype."""
+        mock_quant_type.return_value = "int8"
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
+        backend = make_mock_attn_backend()
+
+        cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)
+
+        self.assertEqual(len(cache_list), 4)
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.uint8")
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_kv_cache_fp8_quantized_uses_uint8(self, mock_quant_type):
+        """When kv_cache_quant_type is block_wise_fp8, non-scale cache tensors use uint8."""
+        mock_quant_type.return_value = "block_wise_fp8"
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
+        backend = make_mock_attn_backend()
+
+        cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)
+
+        # fp8 path also creates scale tensors (float32); filter to only key/value caches
+        kv_tensors = [t for t in cache_list if str(t.dtype) == "paddle.uint8"]
+        self.assertEqual(len(kv_tensors), 4)  # 2 layers * (key + value)
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_mtp_kv_cache_non_quantized_uses_model_dtype(self, mock_quant_type):
+        """When kv_cache_quant_type is None, MTP cache tensors use model_config.dtype."""
+        mock_quant_type.return_value = None
+        controller = self._make_controller(model_dtype="float16", num_layers=4)
+        backend = make_mock_attn_backend()
+
+        cache_list = controller.initialize_mtp_kv_cache(
+            attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
+        )
+
+        self.assertEqual(len(cache_list), 4)  # 2 mtp layers * (key + value)
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.float16")
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_mtp_kv_cache_quantized_uses_uint8(self, mock_quant_type):
+        """When kv_cache_quant_type is set, MTP cache tensors use uint8."""
+        mock_quant_type.return_value = "int8"
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=4)
+        backend = make_mock_attn_backend()
+
+        cache_list = controller.initialize_mtp_kv_cache(
+            attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
+        )
+
+        self.assertEqual(len(cache_list), 4)
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.uint8")
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_kv_cache_populates_cache_kvs_map(self, mock_quant_type):
+        """Tensors created in initialize_kv_cache are stored in cache_kvs_map with correct dtype."""
+        mock_quant_type.return_value = "int8"
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
+        backend = make_mock_attn_backend()
+
+        controller.initialize_kv_cache(backend, num_gpu_blocks=10)
+
+        for name, tensor in controller.cache_kvs_map.items():
+            if "scale" not in name:
+                self.assertEqual(str(tensor.dtype), "paddle.uint8", f"wrong dtype for {name}")
+
+
 if __name__ == "__main__":
     unittest.main()

From cf22cf2f7cd4ab5a9990e5e501dd7c5be5ff7e8f Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Mon, 11 May 2026 11:11:18 +0800
Subject: [PATCH 3/4] [KVCache][BugFix] Fix cache_controller_v1 crash when
 value_cache_shape is None

Motivation

initialize_kv_cache and initialize_mtp_kv_cache in CacheControllerV1
unconditionally create a value cache tensor, which causes a crash (None
shape) for attention backends that return value_cache_shape=None (e.g.
MLA variants).

Modifications

- initialize_kv_cache: handle get_kv_cache_shape returning None for
  value_cache_shape; only create val_cache / val_cache_scales when
  value_cache_shape is not None; cache_kvs_list order now matches
  gpu_model_runner.py: [key] or [key, val], as sketched below.
- initialize_mtp_kv_cache: apply the same fix for MTP layers.
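A minimal sketch of the resulting per-layer allocation (the standalone
function and its names are illustrative, not the controller's API; the
real change is in the diff below):

    import paddle

    def build_layer_caches(key_shape, value_shape, dtype):
        # MLA-style backends report value_cache_shape=None because the
        # layer keeps a single compressed KV tensor, so only a key cache
        # is allocated and the list is [key]; otherwise [key, value].
        key_cache = paddle.full(shape=key_shape, fill_value=0, dtype=dtype)
        if value_shape:
            value_cache = paddle.full(shape=value_shape, fill_value=0, dtype=dtype)
            return [key_cache, value_cache]
        return [key_cache]

    assert len(build_layer_caches([4, 2], [4, 2], "uint8")) == 2
    assert len(build_layer_caches([4, 2], None, "uint8")) == 1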
"scale" not in name: + self.assertEqual(str(tensor.dtype), "paddle.uint8", f"wrong dtype for {name}") + + if __name__ == "__main__": unittest.main() From cf22cf2f7cd4ab5a9990e5e501dd7c5be5ff7e8f Mon Sep 17 00:00:00 2001 From: kevincheng2 Date: Mon, 11 May 2026 11:11:18 +0800 Subject: [PATCH 3/4] [KVCache][BugFix] Fix cache_controller_v1 crash when value_cache_shape is None Motivation initialize_kv_cache and initialize_mtp_kv_cache in CacheControllerV1 always unconditionally create a value cache tensor, which causes a crash (None shape) for attention backends that return value_cache_shape=None (e.g. MLA variants). Modifications - initialize_kv_cache: handle get_kv_cache_shape returning None value_cache_shape; only create val_cache / val_cache_scales when value_cache_shape is not None; cache_kvs_list order now matches gpu_model_runner.py: [key] or [key, val]. - initialize_mtp_kv_cache: apply the same fix for MTP layers. --- .../cache_manager/v1/cache_controller.py | 46 ++++++++++++------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index feae55ed0c5..cb55edd1c4c 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -292,25 +292,31 @@ def initialize_kv_cache( logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") - # Create key cache and value cache + # Create key cache key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) + else: + cache_kvs_list.extend([key_cache]) # Create scale caches for block_wise_fp8 quantization if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) - val_cache_scales = paddle.full( - shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() - ) self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales - self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales - cache_kvs_list.extend([key_cache_scales, val_cache_scales]) + if value_cache_shape: + val_cache_scales = paddle.full( + shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + ) + self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales + cache_kvs_list.extend([key_cache_scales, val_cache_scales]) + else: + cache_kvs_list.extend([key_cache_scales]) paddle.device.cuda.empty_cache() logger.info("kv cache is initialized!") @@ -373,20 +379,26 @@ def initialize_mtp_kv_cache( key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = 
+                cache_kvs_list.extend([key_cache, val_cache])
+            else:
+                cache_kvs_list.extend([key_cache])

             if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
                 key_cache_scales = paddle.full(
                     shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                 )
-                val_cache_scales = paddle.full(
-                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
-                )
                 self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
-                self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
-                cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                if value_cache_shape:
+                    val_cache_scales = paddle.full(
+                        shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
+                    )
+                    self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
+                    cache_kvs_list.extend([key_cache_scales, val_cache_scales])
+                else:
+                    cache_kvs_list.extend([key_cache_scales])

         paddle.device.cuda.empty_cache()
         logger.info("[CacheController] MTP kv cache initialized!")

From 92e9052dfc0f7abeec26fa3069a77b6411479ef0 Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Mon, 11 May 2026 14:23:40 +0800
Subject: [PATCH 4/4] [KVCache][Test] add unit tests for null value_cache_shape
 (MLA/DeepSeek)
---
 tests/cache_manager/v1/test_cache_controller.py | 38 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/tests/cache_manager/v1/test_cache_controller.py b/tests/cache_manager/v1/test_cache_controller.py
index b9905436137..a901ad57e9a 100644
--- a/tests/cache_manager/v1/test_cache_controller.py
+++ b/tests/cache_manager/v1/test_cache_controller.py
@@ -728,8 +728,13 @@ def test_free_gpu_cache_noop_when_empty(self):
 # ============================================================================


-def make_mock_attn_backend(key_shape=(10, 4, 16, 64), val_shape=None):
+def make_mock_attn_backend(key_shape=(10, 4, 16, 64), val_shape=None, val_shape_is_none=False):
     """Create a mock attn_backend with a fixed get_kv_cache_shape."""
+    if val_shape_is_none:
+        # Simulate MLA variants (e.g., DeepSeek) that return None for value_cache_shape
+        backend = MagicMock()
+        backend.get_kv_cache_shape.return_value = (list(key_shape), None)
+        return backend
     if val_shape is None:
         val_shape = key_shape
     backend = MagicMock()
@@ -834,6 +839,37 @@ def test_initialize_kv_cache_populates_cache_kvs_map(self, mock_quant_type):
             if "scale" not in name:
                 self.assertEqual(str(tensor.dtype), "paddle.uint8", f"wrong dtype for {name}")

+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_kv_cache_null_value_cache_shape(self, mock_quant_type):
+        """MLA variant: when value_cache_shape is None, only key cache is created."""
+        mock_quant_type.return_value = None
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=2)
+        backend = make_mock_attn_backend(val_shape_is_none=True)
+
+        cache_list = controller.initialize_kv_cache(backend, num_gpu_blocks=10)
+
+        self.assertEqual(len(cache_list), 2)  # 2 layers * key only
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.bfloat16")
+        # Verify no value entries in cache_kvs_map
+        for name in controller.cache_kvs_map:
+            self.assertNotIn("value", name)
+
+    @patch("fastdeploy.cache_manager.v1.cache_controller.CacheController._get_kv_cache_quant_type")
+    def test_initialize_mtp_kv_cache_null_value_cache_shape(self, mock_quant_type):
+        """MLA variant: when value_cache_shape is None, only key cache is created for MTP."""
+        mock_quant_type.return_value = None
+        controller = self._make_controller(model_dtype="bfloat16", num_layers=4)
+        backend = make_mock_attn_backend(val_shape_is_none=True)
+
+        cache_list = controller.initialize_mtp_kv_cache(
+            attn_backend=backend, num_gpu_blocks=10, num_mtp_layers=2, layer_offset=4
+        )
+
+        self.assertEqual(len(cache_list), 2)  # 2 mtp layers * key only
+        for tensor in cache_list:
+            self.assertEqual(str(tensor.dtype), "paddle.bfloat16")
+

 if __name__ == "__main__":
     unittest.main()