220 changes: 220 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
@@ -1106,3 +1106,223 @@ def forward(self, x: paddle.Tensor, proj_type: str = "k") -> paddle.Tensor:
return self.forward_v_b(x)
else:
raise ValueError(f"proj_type must be 'k' or 'v', got {proj_type}")


class QKVGateParallelLinear(ColumnParallelLinear):
"""
QKVGateParallelLinear
"""

def __init__(
self,
fd_config,
prefix,
with_bias=False,
num_heads: Optional[int] = None,
kv_num_heads: Optional[int] = None,
hidden_size: Optional[int] = None,
head_dim: Optional[int] = None,
skip_quant: bool = False,
weight_dtype: str = "",
):
self.prefix = prefix

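# The checkpoint stores the QKV and gate projections as separate tensors; derive their
# keys from this fused layer's prefix (assumed to end in "qkvg_proj").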
self.qkv_weight_key = f"{prefix}.weight".replace("qkvg", "qkv")
self.gate_weight_key = f"{prefix}.weight".replace("qkvg_proj", "gate")
self.qkv_bias_key = f"{prefix}.bias".replace("qkvg", "qkv")
self.gate_bias_key = f"{prefix}.bias".replace("qkvg_proj", "gate")

self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.num_heads_per_rank = divide(self.num_heads, self.tp_size)

if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
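# Fewer KV heads than TP ranks: give each rank a single KV head and replicate
# heads across ranks, so the k/v output is sized for tp_size heads.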
self.kv_num_heads_per_rank = 1
self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
output_size = (2 * self.num_heads + 2 * self.tp_size) * self.head_dim
else:
self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
self.num_kv_head_replicas = 1
output_size = (2 * self.num_heads + 2 * self.kv_num_heads) * self.head_dim
input_size = self.hidden_size
super().__init__(
fd_config=fd_config,
prefix=prefix,
input_size=input_size,
output_size=output_size,
with_bias=with_bias,
skip_quant=skip_quant,
weight_dtype=weight_dtype,
)

def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
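"""Return the per-rank output width (in elements) of the q, k or v shard."""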
shard_size_mapping = {
"q": self.num_heads_per_rank * head_dim,
"k": self.kv_num_heads_per_rank * head_dim,
"v": self.kv_num_heads_per_rank * head_dim,
}
return shard_size_mapping.get(loaded_shard_id)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
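"""Dispatch to the QKV or gate loader based on `loaded_shard_id`."""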
assert loaded_shard_id in [
"qkv",
"gate",
], f"loaded_shard_id must be one of ['qkv', 'gate'], but got {loaded_shard_id}"

if loaded_shard_id == "qkv":
self.qkv_weight_loader(param, loaded_weight, None)
else:
self.gate_weight_loader(param, loaded_weight)

def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
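"""Load the q/k/v portion of the fused parameter, sharding it across TP ranks."""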
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0

# The fused param packs q + gate + k + v heads per rank; divide by the total head count to recover head_dim
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if loaded_shard_id is None:
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Avoid redundant transpose of fused weights when weight_loader is called iteratively
param.weight_need_transpose = False
# Loaded weight is already fused on disk
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.num_heads * head_dim),
("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.qkv_weight_loader(param, loaded_weight_shard, shard_id)
else:
# split q k v
assert loaded_shard_id in ["q", "k", "v"]
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
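# q is sharded one head group per rank; k/v map every num_kv_head_replicas ranks to the same shard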
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

if not param._is_initialized():
param.initialize()

if loaded_shard_id == "q":
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * head_dim
elif loaded_shard_id == "k":
param_shard_offset = self.num_heads_per_rank * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.kv_num_heads_per_rank * head_dim
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
h2d_copy(param, loaded_weight)

def gate_weight_loader(self, param, loaded_weight):
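"""Load the gate portion of the fused parameter, sharding it across TP ranks."""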
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
dim = -1 if output_dim else 0
# The fused param packs q + gate + k + v heads per rank; divide by the total head count to recover head_dim
head_dim = param.shape[dim] // (2 * self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
weight_need_transpose = getattr(param, "weight_need_transpose", False)

if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])

# Tensor parallelism splits the weight along the output_dim
if self.tp_size > 1 and output_dim is not None:
block_size = self.num_heads_per_rank * head_dim
shard_offset = self.local_rank * block_size
shard_size = block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)

if not param._is_initialized():
param.initialize()

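# The gate shard is packed after the q, k and v shards in the fused parameter.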
param_shard_offset = (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim
param_shard_size = self.num_heads_per_rank * head_dim

if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
h2d_copy(param, loaded_weight)

def load_weight(self, state_dict: dict):
"""
Load the weight from the state dictionary.

Args:
state_dict (dict): A dictionary containing the weights
"""
qkv_weight_tensor = get_tensor(state_dict.pop(self.qkv_weight_key))
gate_weight_tensor = get_tensor(state_dict.pop(self.gate_weight_key))
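# Fuse the separately stored QKV and gate weights along the output dim before handing them to the quant method.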
qkvg_weight_tensor = paddle.concat([qkv_weight_tensor, gate_weight_tensor], axis=-1)

self.quant_method.process_loaded_weights(self, qkvg_weight_tensor)

def load_state_dict(self, state_dict: dict):
"""
Load the checkpoint state dictionary into the layer.

Args:
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
# weight
assert (
self.qkv_weight_key in state_dict.keys() and self.gate_weight_key in state_dict.keys()
), f"{self.qkv_weight_key} or {self.gate_weight_key} not found in state_dict"

if self.is_quantized:
self.load_prequant_weight(state_dict)
else:
self.load_weight(state_dict)

# bias
if self.with_bias:
assert (
self.qkv_bias_key in state_dict.keys() and self.gate_bias_key in state_dict.keys()
), f"{self.qkv_bias_key} or {self.gate_bias_key} not found in state_dict"
qkv_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.qkv_bias_key)))
gate_bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.gate_bias_key)))
bias_tensor = paddle.concat([qkv_bias_tensor, gate_bias_tensor], axis=-1)

self.bias.set_value(bias_tensor)
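For context, a minimal usage sketch of the new layer (not part of this diff). It assumes an fd_config object exposing the model_config/parallel_config attributes the constructor reads, and a checkpoint state dict that stores the QKV and gate weights under separate keys; the prefix below is purely illustrative:

# Hypothetical example; `fd_config` and `checkpoint_state_dict` are assumed to come from the model builder.
layer = QKVGateParallelLinear(
    fd_config=fd_config,
    prefix="model.layers.0.self_attn.qkvg_proj",
    with_bias=False,
)
# Keys derived in __init__ from the prefix:
#   layer.qkv_weight_key  -> "model.layers.0.self_attn.qkv_proj.weight"
#   layer.gate_weight_key -> "model.layers.0.self_attn.gate.weight"
# load_state_dict() fuses the two tensors along the output dim and passes the result to the
# layer's quant method (or loads pre-quantized weights when is_quantized is set).
layer.load_state_dict(checkpoint_state_dict)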
@@ -23,6 +23,7 @@
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVGateParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.moe import FusedMoE
@@ -160,6 +161,7 @@ def create_weights(self, layer, **extra_weight_attrs):
isinstance(layer, MergedColumnParallelLinear)
or isinstance(layer, QKVParallelLinear)
or isinstance(layer, MergedReplicatedLinear)
or isinstance(layer, QKVGateParallelLinear)
):
tensor_output_dim = (self.model_format == "torch") ^ quant_attrs.get("output_dim", True)
quant_attrs = {