Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 175 additions & 37 deletions fastdeploy/cache_manager/v1/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def __init__(self, config: "FDConfig", local_rank: int, device_id: int):
# NUMA binding flag
self._numa_bound = False

self._is_mla = getattr(self.model_config, "kv_lora_rank", 0) > 0
self._is_dsa = self._is_mla and getattr(self.model_config, "index_head_dim", 0) > 0

@property
def write_policy(self) -> Optional[str]:
"""Get the write policy for cache operations."""
Expand Down Expand Up @@ -230,13 +233,22 @@ def _get_cache_names(self, layer_idx: int) -> Dict[str, str]:
"""
local_rank = self._local_rank % self.parallel_config.tensor_parallel_size

return {
"key": f"key_caches_{layer_idx}_rank{local_rank}.device{self._device_id}",
"value": f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}",
"key_scale": f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}",
"value_scale": f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}",
names = {
"key": f"key_cache_{layer_idx}_rank{local_rank}.device{self._device_id}",
}

if self._is_dsa:
names["indexer"] = f"indexer_caches_{layer_idx}_rank{local_rank}.device{self._device_id}"
elif self._is_mla:
Comment on lines +236 to +242
pass # MLA: only key, no value, no indexer
else:
# GQA/MHA: key + value + optional scales
names["value"] = f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}"
names["key_scale"] = f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}"
names["value_scale"] = f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}"
Comment on lines +236 to +248

return names

# ============ KV Cache Management ============

def get_kv_caches(self) -> Optional[Dict[str, Any]]:
Expand All @@ -255,7 +267,7 @@ def initialize_kv_cache(
num_gpu_blocks: int,
) -> List[Any]:
"""
Initialize KV Cache tensors.
Initialize KV Cache tensors (GQA/MHA only).

Create KV Cache tensors on GPU for storing attention Key and Value.

Expand All @@ -266,37 +278,40 @@ def initialize_kv_cache(
Returns:
cache_kvs_list: KV Cache tensor list in [key_cache_layer0, value_cache_layer0, ...] order.
"""
# Get kv cache quantization type
kv_cache_quant_type = self._get_kv_cache_quant_type()
# Dispatch to specialized initializers for MLA/DSA
if self._is_dsa:
return self.initialize_dsa_kv_cache(attn_backend, num_gpu_blocks)
elif self._is_mla:
return self.initialize_mla_kv_cache(attn_backend, num_gpu_blocks)

# Get kv cache shape
# GQA/MHA path
kv_cache_quant_type = self._get_kv_cache_quant_type()
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
cache_dtype = self.model_config.dtype

# Get scale shape for block_wise_fp8 quantization
# Scale shape for block_wise_fp8 quantization
kv_cache_scale_shape = None
if self._is_fp8_quantization(kv_cache_quant_type):
kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]

logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
logger.info(
f"Initializing GQA kv cache: num_layers={self._num_layers}, "
f"key_shape={key_cache_shape}, value_shape={value_cache_shape}"
)
cache_kvs_list = []

for i in range(self._num_layers):
# Generate cache names
cache_names = self._get_cache_names(i)

logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")

# Create key cache and value cache
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])

# Create scale caches for block_wise_fp8 quantization
if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
Expand All @@ -309,14 +324,108 @@ def initialize_kv_cache(
cache_kvs_list.extend([key_cache_scales, val_cache_scales])

paddle.device.cuda.empty_cache()
logger.info("kv cache is initialized!")
logger.info("GQA kv cache initialized!")

# Share cache_kvs_map with transfer manager for data transfer operations
self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map)

# Initialize host cache
self.initialize_host_cache(attn_backend)
return cache_kvs_list

def initialize_mla_kv_cache(
    self,
    attn_backend: Any,
    num_gpu_blocks: int,
) -> List[Any]:
    """
    Allocate the MLA KV cache: one latent key tensor per layer, no value cache.

    Args:
        attn_backend: Attention backend used to compute the cache shape.
        num_gpu_blocks: Maximum number of cache blocks on the GPU.

    Returns:
        Per-layer cache tensors in [key_layer0, key_layer1, ...] order.
    """
    quant_type = self._get_kv_cache_quant_type()
    key_shape, _ = attn_backend.get_kv_cache_shape(
        max_num_blocks=num_gpu_blocks, kv_cache_quant_type=quant_type
    )
    dtype = self.model_config.dtype

    # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot
    # reuse/migrate it. Without pinning, CUDAGraph capture records a
    # data_ptr that allocator may later mark reusable, corrupting replay.
    # Align with V0 path (gpu_model_runner.initialize_kv_cache).
    from fastdeploy.model_executor.ops.gpu import set_data_ipc

    logger.info(f"Initializing MLA kv cache: num_layers={self._num_layers}, " f"key_shape={key_shape}")

    caches: List[Any] = []
    for layer_idx in range(self._num_layers):
        names = self._get_cache_names(layer_idx)

        latent_cache = paddle.full(shape=key_shape, fill_value=0, dtype=dtype)
        set_data_ipc(latent_cache, names["key"])
        self.cache_kvs_map[names["key"]] = latent_cache
        caches.append(latent_cache)

    paddle.device.cuda.empty_cache()
    logger.info("MLA kv cache initialized!")

    # Share the cache map with the transfer manager for data movement.
    self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map)
    return caches

def initialize_dsa_kv_cache(
    self,
    attn_backend: Any,
    num_gpu_blocks: int,
) -> List[Any]:
    """
    Allocate the DSA KV cache: a key pool plus an indexer pool per layer.

    The returned list interleaves the two pools as
    [key_layer0, indexer_layer0, key_layer1, indexer_layer1, ...].
    Future HiSparse extension: add host_blocks parameter for key host backup.

    Args:
        attn_backend: Attention backend used to compute the cache shapes.
        num_gpu_blocks: Maximum number of cache blocks on the GPU.

    Returns:
        Per-layer cache tensors in [key_layer0, indexer_layer0, ...] order.
    """
    # DSA caches are always stored as uint8.
    key_shape, _, indexer_shape = attn_backend.get_kv_cache_shape(
        max_num_blocks=num_gpu_blocks, kv_cache_quant_type="uint8"
    )
    dtype = "uint8"

    # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot
    # reuse/migrate it. Without pinning, CUDAGraph capture records a
    # data_ptr that allocator may later mark reusable, corrupting replay.
    # Align with V0 path (gpu_model_runner.initialize_kv_cache).
    from fastdeploy.model_executor.ops.gpu import set_data_ipc

    logger.info(
        f"Initializing DSA kv cache: num_layers={self._num_layers}, "
        f"key_shape={key_shape}, indexer_shape={indexer_shape}"
    )

    caches: List[Any] = []
    for layer_idx in range(self._num_layers):
        names = self._get_cache_names(layer_idx)

        key_pool = paddle.full(shape=key_shape, fill_value=0, dtype=dtype)
        set_data_ipc(key_pool, names["key"])
        self.cache_kvs_map[names["key"]] = key_pool

        indexer_pool = paddle.full(shape=indexer_shape, fill_value=0, dtype=dtype)
        set_data_ipc(indexer_pool, names["indexer"])
        self.cache_kvs_map[names["indexer"]] = indexer_pool

        caches.extend([key_pool, indexer_pool])

    paddle.device.cuda.empty_cache()
    logger.info("DSA kv cache initialized!")

    # Share the cache map with the transfer manager for data movement.
    self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map)
    return caches

def initialize_mtp_kv_cache(
Expand Down Expand Up @@ -346,31 +455,50 @@ def initialize_mtp_kv_cache(
"""
kv_cache_quant_type = self._get_kv_cache_quant_type()

key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if self._is_dsa:
kv_cache_quant_type = "uint8"
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug initialize_mtp_kv_cache 的 DSA 分支解包值数量错误,且 indexer_cache_shape 未定义

DSA backend 的 get_kv_cache_shape 返回 3 个值(如 initialize_dsa_kv_cacheinitialize_host_cache 中正确解包为 key_cache_shape, _, indexer_cache_shape 所示),但此处只解包 2 个值,将抛出 ValueError: too many values to unpack

此外,indexer_cache_shape_is_dsa=True 分支中从未赋值,后续 elif indexer_cache_shape: 会触发 NameError

建议修复:

if self._is_dsa:
    kv_cache_quant_type = "uint8"
    key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape(
        max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
    )
    value_cache_shape = []  # DSA 没有 value cache
    cache_dtype = "uint8"
else:
    key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
        max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
    )
    indexer_cache_shape = []
    cache_dtype = self.model_config.dtype

max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
cache_dtype = "uint8"
else:
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
indexer_cache_shape = []
cache_dtype = self.model_config.dtype
Comment on lines +458 to +469

kv_cache_scale_shape = None
if self._is_fp8_quantization(kv_cache_quant_type):
if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type):
kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]

logger.info(
f"[CacheController] Initializing MTP kv cache for {num_mtp_layers} layers "
f"(layer_offset={layer_offset}, num_gpu_blocks={num_gpu_blocks})."
f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}."
)
cache_kvs_list = []

for i in range(layer_offset, layer_offset + num_mtp_layers):
cache_names = self._get_cache_names(i)

key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["key"]] = key_cache

val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])

if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
if value_cache_shape:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["value"]] = val_cache
cache_kvs_list.extend([key_cache, val_cache])
elif indexer_cache_shape:
# DSA: key + indexer
indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype)
self.cache_kvs_map[cache_names["indexer"]] = indexer_cache
cache_kvs_list.extend([key_cache, indexer_cache])
else:
# MLA: only key, no value, no indexer
cache_kvs_list.append(key_cache)

if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
)
Expand Down Expand Up @@ -542,9 +670,16 @@ def initialize_host_cache(
kv_cache_quant_type = self._get_kv_cache_quant_type()

# Get kv cache shape (pass num_host_blocks as max_num_blocks for host cache)
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if self._is_dsa:
kv_cache_quant_type = "uint8"
key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type
)
value_cache_shape = []
else:
Comment on lines 672 to +679
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type
)

# Calculate cache sizes (elements per block per layer)
key_cache_size = key_cache_shape[1] * key_cache_shape[2] * key_cache_shape[3]
Expand All @@ -554,8 +689,11 @@ def initialize_host_cache(
value_cache_size = 0

# Get cache dtype and bytes per element
cache_dtype = self.cache_config.cache_dtype
cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype)
if self._is_dsa:
cache_item_bytes = 1
else:
cache_dtype = self.cache_config.cache_dtype
cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype)

# Calculate total bytes to allocate
key_need_to_allocate_bytes = num_host_blocks * cache_item_bytes * key_cache_size
Expand Down
Loading
Loading