From b0cbb896d61ff9dc88f90dd6e1f8e5ae03ee8972 Mon Sep 17 00:00:00 2001 From: Moonchild1227 Date: Wed, 6 May 2026 15:40:55 +0800 Subject: [PATCH 1/4] feat: support dsa for v1 cache manager --- .../cache_manager/v1/cache_controller.py | 122 +++++++++---- .../cache_manager/v1/transfer_manager.py | 168 ++++++++++-------- 2 files changed, 186 insertions(+), 104 deletions(-) diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index 53b7292179f..53c39074e7f 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -104,6 +104,9 @@ def __init__(self, config: "FDConfig", local_rank: int, device_id: int): # NUMA binding flag self._numa_bound = False + self._is_mla = getattr(self.model_config, 'kv_lora_rank', 0) > 0 + self._is_dsa = self._is_mla and getattr(self.model_config, 'index_head_dim', 0) > 0 + @property def write_policy(self) -> Optional[str]: """Get the write policy for cache operations.""" @@ -230,13 +233,22 @@ def _get_cache_names(self, layer_idx: int) -> Dict[str, str]: """ local_rank = self._local_rank % self.parallel_config.tensor_parallel_size - return { - "key": f"key_caches_{layer_idx}_rank{local_rank}.device{self._device_id}", - "value": f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}", - "key_scale": f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}", - "value_scale": f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}", + names = { + "key": f"key_cache_{layer_idx}_rank{local_rank}.device{self._device_id}", } + if self._is_dsa: + names["indexer"] = f"indexer_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" + elif self._is_mla: + pass # MLA: only key, no value, no indexer + else: + # GQA/MHA: key + value + optional scales + names["value"] = f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" + names["key_scale"] = f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}" + names["value_scale"] = f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}" + + return names + # ============ KV Cache Management ============ def get_kv_caches(self) -> Optional[Dict[str, Any]]: @@ -270,34 +282,53 @@ def initialize_kv_cache( kv_cache_quant_type = self._get_kv_cache_quant_type() # Get kv cache shape - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type - ) + if self._is_dsa: + kv_cache_quant_type = "uint8" + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = "uint8" + else: + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + indexer_cache_shape = [] + cache_dtype = self.model_config.dtype # Get scale shape for block_wise_fp8 quantization kv_cache_scale_shape = None - if self._is_fp8_quantization(kv_cache_quant_type): + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] - logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}") + logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}," + f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}") cache_kvs_list = [] for i in range(self._num_layers): # Generate cache names cache_names = self._get_cache_names(i) - logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") + logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}") # Create key cache and value cache - key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) + elif indexer_cache_shape: + # DSA: key + indexer + indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["indexer"]] = indexer_cache + cache_kvs_list.extend([key_cache, indexer_cache]) + else: + # MLA: only key, no value, no indexer + cache_kvs_list.append(key_cache) # Create scale caches for block_wise_fp8 quantization - if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -346,31 +377,50 @@ def initialize_mtp_kv_cache( """ kv_cache_quant_type = self._get_kv_cache_quant_type() - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type - ) + if self._is_dsa: + kv_cache_quant_type = "uint8" + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = "uint8" + else: + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + indexer_cache_shape = [] + cache_dtype = self.model_config.dtype kv_cache_scale_shape = None - if self._is_fp8_quantization(kv_cache_quant_type): + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] logger.info( f"[CacheController] Initializing MTP kv cache for {num_mtp_layers} layers " f"(layer_offset={layer_offset}, num_gpu_blocks={num_gpu_blocks})." + f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}." ) cache_kvs_list = [] for i in range(layer_offset, layer_offset + num_mtp_layers): cache_names = self._get_cache_names(i) - key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) - - if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) + elif indexer_cache_shape: + # DSA: key + indexer + indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["indexer"]] = indexer_cache + cache_kvs_list.extend([key_cache, indexer_cache]) + else: + # MLA: only key, no value, no indexer + cache_kvs_list.append(key_cache) + + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -542,9 +592,16 @@ def initialize_host_cache( kv_cache_quant_type = self._get_kv_cache_quant_type() # Get kv cache shape (pass num_host_blocks as max_num_blocks for host cache) - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type - ) + if self._is_dsa: + kv_cache_quant_type = "uint8" + key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + value_cache_shape = [] + else: + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type + ) # Calculate cache sizes (elements per block per layer) key_cache_size = key_cache_shape[1] * key_cache_shape[2] * key_cache_shape[3] @@ -554,8 +611,11 @@ def initialize_host_cache( value_cache_size = 0 # Get cache dtype and bytes per element - cache_dtype = self.cache_config.cache_dtype - cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype) + if self._is_dsa: + cache_item_bytes = 1 + else: + cache_dtype = self.cache_config.cache_dtype + cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype) # Calculate total bytes to allocate key_need_to_allocate_bytes = num_host_blocks * cache_item_bytes * key_cache_size diff --git a/fastdeploy/cache_manager/v1/transfer_manager.py b/fastdeploy/cache_manager/v1/transfer_manager.py index f4ed0bb6539..b3138ced73e 100644 --- a/fastdeploy/cache_manager/v1/transfer_manager.py +++ b/fastdeploy/cache_manager/v1/transfer_manager.py @@ -130,6 +130,11 @@ def __init__( self._storage_connector = create_storage_connector(self.cache_config) self._transfer_connector = create_transfer_connector(self.cache_config) + # ============ MLA & DSA ============ + self._is_mla = getattr(config.model_config, 'kv_lora_rank', 0) > 0 + self._is_dsa = self._is_mla and getattr(config.model_config, 'index_head_dim', 0) > 0 + + # ============ Cache Map Setters ============ @property @@ -169,17 +174,24 @@ def _build_device_layer_indices(self) -> None: self._device_value_scales = [] for layer_idx in range(self._num_layers): - key_name = f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - + key_name = f"key_cache_{layer_idx}_rank{self._local_rank}.device{self._device_id}" self._device_key_caches.append(self._cache_kvs_map.get(key_name)) - self._device_value_caches.append(self._cache_kvs_map.get(val_name)) - if self._is_fp8_quantization(): - self._device_key_scales.append(self._cache_kvs_map.get(key_scale_name)) - self._device_value_scales.append(self._cache_kvs_map.get(val_scale_name)) + if self._is_dsa: + # DSA: indexer treated as "value" for swap purposes + idx_name = f"indexer_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_value_caches.append(self._cache_kvs_map.get(idx_name)) + elif not self._is_mla: + # GQA: has value caches + val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_value_caches.append(self._cache_kvs_map.get(val_name)) + + if self._is_fp8_quantization(): + key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_key_scales.append(self._cache_kvs_map.get(key_scale_name)) + self._device_value_scales.append(self._cache_kvs_map.get(val_scale_name)) + # MLA: no value caches to add @property def host_cache_kvs_map(self) -> Dict[str, Any]: @@ -215,17 +227,24 @@ def _build_host_layer_indices(self) -> None: self._host_value_scales_ptrs = [] for layer_idx in range(self._num_layers): - key_name = f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - + key_name = f"key_cache_{layer_idx}_rank{self._local_rank}.device{self._device_id}" self._host_key_ptrs.append(self._host_cache_kvs_map.get(key_name, 0)) - self._host_value_ptrs.append(self._host_cache_kvs_map.get(val_name, 0)) - if self._is_fp8_quantization(): - self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0)) - self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0)) + if self._is_dsa: + # DSA: indexer treated as "value" for swap purposes + idx_name = f"indexer_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_value_ptrs.append(self._host_cache_kvs_map.get(idx_name, 0)) + elif not self._is_mla: + # GQA: has value host cache + val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_value_ptrs.append(self._host_cache_kvs_map.get(val_name, 0)) + + if self._is_fp8_quantization(): + key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0)) + self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0)) + # MLA: no value host pointers to add # ============ Metadata Properties ============ @@ -329,16 +348,19 @@ def _swap_all_layers( self._device_id, mode, ) - swap_cache_all_layers( - self._device_value_caches, - self._host_value_ptrs, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) - if self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + # Value cache is only used in GQA + if not self._is_mla and self._device_value_caches: + swap_cache_all_layers( + self._device_value_caches, + self._host_value_ptrs, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) + # Scale cache is only used in GQA + fp8 quantization + if not self._is_mla and self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, @@ -389,33 +411,26 @@ def _swap_single_layer( try: key_cache = self.get_device_key_cache(layer_idx) - value_cache = self.get_device_value_cache(layer_idx) - if key_cache is None or value_cache is None: + if key_cache is None: return False - key_ptr = self.get_host_key_ptr(layer_idx) - value_ptr = self.get_host_value_ptr(layer_idx) - if key_ptr == 0 or value_ptr == 0: + if key_ptr == 0: return False swap_cache_per_layer( - key_cache, - key_ptr, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) - swap_cache_per_layer( - value_cache, - value_ptr, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, + key_cache, key_ptr, self._num_host_blocks, + device_block_ids, host_block_ids, self._device_id, mode, ) + + if not self._is_mla or self._is_dsa: + value_cache = self.get_device_value_cache(layer_idx) + value_ptr = self.get_host_value_ptr(layer_idx) + if value_cache is None or value_ptr == 0: + return False + swap_cache_per_layer( + value_cache, value_ptr, self._num_host_blocks, + device_block_ids, host_block_ids, self._device_id, mode, + ) return True except Exception: import traceback @@ -466,16 +481,19 @@ def _swap_all_layers_async( self._device_id, mode, ) - swap_cache_all_layers( - self._device_value_caches, - self._host_value_ptrs, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) - if self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + # Value/indexer cache: GQA has value, DSA has indexer (both in _device_value_caches) + # MLA has neither, so _device_value_caches is empty + if self._device_value_caches and self._host_value_ptrs: + swap_cache_all_layers( + self._device_value_caches, + self._host_value_ptrs, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) + if not self._is_mla and self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, @@ -527,13 +545,10 @@ def _swap_single_layer_async( stream = self._output_stream if mode == 0 else self._input_stream key_cache = self.get_device_key_cache(layer_idx) - value_cache = self.get_device_value_cache(layer_idx) - if key_cache is None or value_cache is None: + if key_cache is None: return False - key_ptr = self.get_host_key_ptr(layer_idx) - value_ptr = self.get_host_value_ptr(layer_idx) - if key_ptr == 0 or value_ptr == 0: + if key_ptr == 0: return False try: @@ -548,15 +563,22 @@ def _swap_single_layer_async( self._device_id, mode, ) - swap_cache_per_layer_async( - value_cache, - value_ptr, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) + + if not self._is_mla or self._is_dsa: + # GQA: swap value; DSA: swap indexer (stored in value slot) + value_cache = self.get_device_value_cache(layer_idx) + value_ptr = self.get_host_value_ptr(layer_idx) + if value_cache is None or value_ptr == 0: + return False + swap_cache_per_layer_async( + value_cache, + value_ptr, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) return True except Exception: import traceback From 8c043d4c19c936036a68cef123c5fe7b46329eb0 Mon Sep 17 00:00:00 2001 From: Moonchild1227 Date: Wed, 6 May 2026 15:45:07 +0800 Subject: [PATCH 2/4] style: pre-commit --- .../cache_manager/v1/cache_controller.py | 16 +++++--- .../cache_manager/v1/transfer_manager.py | 37 ++++++++++++++----- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index 53c39074e7f..2b96f8d8911 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -104,8 +104,8 @@ def __init__(self, config: "FDConfig", local_rank: int, device_id: int): # NUMA binding flag self._numa_bound = False - self._is_mla = getattr(self.model_config, 'kv_lora_rank', 0) > 0 - self._is_dsa = self._is_mla and getattr(self.model_config, 'index_head_dim', 0) > 0 + self._is_mla = getattr(self.model_config, "kv_lora_rank", 0) > 0 + self._is_dsa = self._is_mla and getattr(self.model_config, "index_head_dim", 0) > 0 @property def write_policy(self) -> Optional[str]: @@ -240,7 +240,7 @@ def _get_cache_names(self, layer_idx: int) -> Dict[str, str]: if self._is_dsa: names["indexer"] = f"indexer_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" elif self._is_mla: - pass # MLA: only key, no value, no indexer + pass # MLA: only key, no value, no indexer else: # GQA/MHA: key + value + optional scales names["value"] = f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" @@ -300,15 +300,19 @@ def initialize_kv_cache( if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] - logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}," - f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}") + logger.info( + f"Initializing kv cache for all layers. num_layers={self._num_layers}," + f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}" + ) cache_kvs_list = [] for i in range(self._num_layers): # Generate cache names cache_names = self._get_cache_names(i) - logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}") + logger.info( + f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}" + ) # Create key cache and value cache key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) diff --git a/fastdeploy/cache_manager/v1/transfer_manager.py b/fastdeploy/cache_manager/v1/transfer_manager.py index b3138ced73e..85992abbf96 100644 --- a/fastdeploy/cache_manager/v1/transfer_manager.py +++ b/fastdeploy/cache_manager/v1/transfer_manager.py @@ -131,9 +131,8 @@ def __init__( self._transfer_connector = create_transfer_connector(self.cache_config) # ============ MLA & DSA ============ - self._is_mla = getattr(config.model_config, 'kv_lora_rank', 0) > 0 - self._is_dsa = self._is_mla and getattr(config.model_config, 'index_head_dim', 0) > 0 - + self._is_mla = getattr(config.model_config, "kv_lora_rank", 0) > 0 + self._is_dsa = self._is_mla and getattr(config.model_config, "index_head_dim", 0) > 0 # ============ Cache Map Setters ============ @@ -360,7 +359,12 @@ def _swap_all_layers( mode, ) # Scale cache is only used in GQA + fp8 quantization - if not self._is_mla and self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + if ( + not self._is_mla + and self._is_fp8_quantization() + and self._device_key_scales + and self._host_key_scales_ptrs + ): swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, @@ -418,8 +422,13 @@ def _swap_single_layer( return False swap_cache_per_layer( - key_cache, key_ptr, self._num_host_blocks, - device_block_ids, host_block_ids, self._device_id, mode, + key_cache, + key_ptr, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, ) if not self._is_mla or self._is_dsa: @@ -428,8 +437,13 @@ def _swap_single_layer( if value_cache is None or value_ptr == 0: return False swap_cache_per_layer( - value_cache, value_ptr, self._num_host_blocks, - device_block_ids, host_block_ids, self._device_id, mode, + value_cache, + value_ptr, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, ) return True except Exception: @@ -493,7 +507,12 @@ def _swap_all_layers_async( self._device_id, mode, ) - if not self._is_mla and self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + if ( + not self._is_mla + and self._is_fp8_quantization() + and self._device_key_scales + and self._host_key_scales_ptrs + ): swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, From f54eaa70b664a3ff85863542113d06a58b7566cb Mon Sep 17 00:00:00 2001 From: Moonchild1227 Date: Wed, 6 May 2026 16:53:28 +0800 Subject: [PATCH 3/4] feat: Support DSA and reserved pooled interface for DSA offloading. --- .../cache_manager/v1/cache_controller.py | 147 ++++++++++++------ 1 file changed, 103 insertions(+), 44 deletions(-) diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index 2b96f8d8911..ffab8ebbcc4 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -267,7 +267,7 @@ def initialize_kv_cache( num_gpu_blocks: int, ) -> List[Any]: """ - Initialize KV Cache tensors. + Initialize KV Cache tensors (GQA/MHA only). Create KV Cache tensors on GPU for storing attention Key and Value. @@ -278,61 +278,41 @@ def initialize_kv_cache( Returns: cache_kvs_list: KV Cache tensor list in [key_cache_layer0, value_cache_layer0, ...] order. """ - # Get kv cache quantization type - kv_cache_quant_type = self._get_kv_cache_quant_type() - - # Get kv cache shape + # Dispatch to specialized initializers for MLA/DSA if self._is_dsa: - kv_cache_quant_type = "uint8" - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type - ) - cache_dtype = "uint8" - else: - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type - ) - indexer_cache_shape = [] - cache_dtype = self.model_config.dtype + return self.initialize_dsa_kv_cache(attn_backend, num_gpu_blocks) + elif self._is_mla: + return self.initialize_mla_kv_cache(attn_backend, num_gpu_blocks) + + # GQA/MHA path + kv_cache_quant_type = self._get_kv_cache_quant_type() + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = self.model_config.dtype - # Get scale shape for block_wise_fp8 quantization + # Scale shape for block_wise_fp8 quantization kv_cache_scale_shape = None - if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type): + if self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] logger.info( - f"Initializing kv cache for all layers. num_layers={self._num_layers}," - f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}" + f"Initializing GQA kv cache: num_layers={self._num_layers}, " + f"key_shape={key_cache_shape}, value_shape={value_cache_shape}" ) cache_kvs_list = [] for i in range(self._num_layers): - # Generate cache names cache_names = self._get_cache_names(i) - logger.info( - f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}, indexer:{indexer_cache_shape}" - ) - - # Create key cache and value cache key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - if value_cache_shape: - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) - elif indexer_cache_shape: - # DSA: key + indexer - indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) - self.cache_kvs_map[cache_names["indexer"]] = indexer_cache - cache_kvs_list.extend([key_cache, indexer_cache]) - else: - # MLA: only key, no value, no indexer - cache_kvs_list.append(key_cache) + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) - # Create scale caches for block_wise_fp8 quantization - if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: + if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -344,14 +324,93 @@ def initialize_kv_cache( cache_kvs_list.extend([key_cache_scales, val_cache_scales]) paddle.device.cuda.empty_cache() - logger.info("kv cache is initialized!") + logger.info("GQA kv cache initialized!") - # Share cache_kvs_map with transfer manager for data transfer operations self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) - - # Initialize host cache self.initialize_host_cache(attn_backend) + return cache_kvs_list + + def initialize_mla_kv_cache( + self, + attn_backend: Any, + num_gpu_blocks: int, + ) -> List[Any]: + """ + Initialize MLA KV Cache tensors (key only, no value). + + Args: + attn_backend: Attention backend instance for getting kv cache shape. + num_gpu_blocks: Maximum number of blocks on GPU. + Returns: + cache_kvs_list: KV Cache tensor list in [key_layer0, key_layer1, ...] order. + """ + kv_cache_quant_type = self._get_kv_cache_quant_type() + key_cache_shape, _ = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = self.model_config.dtype + + logger.info(f"Initializing MLA kv cache: num_layers={self._num_layers}, " f"key_shape={key_cache_shape}") + cache_kvs_list = [] + + for i in range(self._num_layers): + cache_names = self._get_cache_names(i) + + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["key"]] = key_cache + cache_kvs_list.append(key_cache) + + paddle.device.cuda.empty_cache() + logger.info("MLA kv cache initialized!") + + self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) + return cache_kvs_list + + def initialize_dsa_kv_cache( + self, + attn_backend: Any, + num_gpu_blocks: int, + ) -> List[Any]: + """ + Initialize DSA KV Cache tensors (key + indexer, two pools). + + Creates interleaved [key, indexer, key, indexer, ...] layout. + Future HiSparse extension: add host_blocks parameter for key host backup. + + Args: + attn_backend: Attention backend instance for getting kv cache shape. + num_gpu_blocks: Maximum number of blocks on GPU. + + Returns: + cache_kvs_list: KV Cache tensor list in [key_layer0, indexer_layer0, ...] order. + """ + key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type="uint8" + ) + cache_dtype = "uint8" + + logger.info( + f"Initializing DSA kv cache: num_layers={self._num_layers}, " + f"key_shape={key_cache_shape}, indexer_shape={indexer_cache_shape}" + ) + cache_kvs_list = [] + + for i in range(self._num_layers): + cache_names = self._get_cache_names(i) + + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["key"]] = key_cache + + indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["indexer"]] = indexer_cache + + cache_kvs_list.extend([key_cache, indexer_cache]) + + paddle.device.cuda.empty_cache() + logger.info("DSA kv cache initialized!") + + self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) return cache_kvs_list def initialize_mtp_kv_cache( From 7b1dc24f4a4d916fd635b58e1f2143da5f75535f Mon Sep 17 00:00:00 2001 From: Moonchild1227 Date: Mon, 11 May 2026 14:27:41 +0800 Subject: [PATCH 4/4] Align V1 KV cache init with V0 by calling set_data_ipc --- fastdeploy/cache_manager/v1/cache_controller.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index ffab8ebbcc4..049a261ae2b 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -351,6 +351,12 @@ def initialize_mla_kv_cache( ) cache_dtype = self.model_config.dtype + # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot + # reuse/migrate it. Without pinning, CUDAGraph capture records a + # data_ptr that allocator may later mark reusable, corrupting replay. + # Align with V0 path (gpu_model_runner.initialize_kv_cache). + from fastdeploy.model_executor.ops.gpu import set_data_ipc + logger.info(f"Initializing MLA kv cache: num_layers={self._num_layers}, " f"key_shape={key_cache_shape}") cache_kvs_list = [] @@ -358,6 +364,7 @@ def initialize_mla_kv_cache( cache_names = self._get_cache_names(i) key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(key_cache, cache_names["key"]) self.cache_kvs_map[cache_names["key"]] = key_cache cache_kvs_list.append(key_cache) @@ -390,6 +397,12 @@ def initialize_dsa_kv_cache( ) cache_dtype = "uint8" + # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot + # reuse/migrate it. Without pinning, CUDAGraph capture records a + # data_ptr that allocator may later mark reusable, corrupting replay. + # Align with V0 path (gpu_model_runner.initialize_kv_cache). + from fastdeploy.model_executor.ops.gpu import set_data_ipc + logger.info( f"Initializing DSA kv cache: num_layers={self._num_layers}, " f"key_shape={key_cache_shape}, indexer_shape={indexer_cache_shape}" @@ -400,9 +413,11 @@ def initialize_dsa_kv_cache( cache_names = self._get_cache_names(i) key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(key_cache, cache_names["key"]) self.cache_kvs_map[cache_names["key"]] = key_cache indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(indexer_cache, cache_names["indexer"]) self.cache_kvs_map[cache_names["indexer"]] = indexer_cache cache_kvs_list.extend([key_cache, indexer_cache])