class MLALayerOptimized(nn.Module):
"""
一个纯粹的、无位置编码 (NoPE) 且完全向量化的
Multi-head Latent Attention (MLA) 的优化实现。
- 训练/Prefill模式: 使用 F.scaled_dot_product_attention 以获得最佳性能 (支持 Flash Attention)。
- 推理模式: 实现论文中描述的、通过恒等变换达成的 MQA 式计算优化。
- 支持 Prefill 和单步解码。
- 解决了 c_norm 在不同路径下的逻辑一致性问题。
"""
def __init__(self, d_model: int, num_heads: int, d_latent: int, d_head: int = None, output_dim: int = None, **kwargs):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_latent = d_latent
self.d_head = d_head if d_head is not None else d_model // num_heads
# 确保 d_head * num_heads 不会出错
self.inner_dim = self.num_heads * self.d_head
# 投影矩阵
self.W_q = nn.Linear(d_model, self.inner_dim, bias=False)
self.W_c = nn.Linear(d_model, d_latent, bias=False)
self.W_k = nn.Linear(d_latent, self.inner_dim, bias=False)
self.W_v = nn.Linear(d_latent, self.inner_dim, bias=False)
self.W_o = nn.Linear(self.inner_dim, d_model if not output_dim else output_dim, bias=False)
self.c_norm = nn.RMSNorm(d_latent)
self.q_norm = nn.RMSNorm(self.inner_dim)
def forward(self, x: torch.Tensor, use_cache: bool = False, cache: torch.Tensor = None, attn_mask=None, **kwargs):
batch_size, seq_len, _ = x.shape
# ------------------------------------------------------------------
# 路径 1: 单步解码 (Decoding) - 当且仅当 use_cache=True 且 cache 已存在
# ------------------------------------------------------------------
if use_cache and cache is not None:
if seq_len != 1:
raise ValueError(f"Decoding with cache requires seq_len=1. cache: {cache}")
# 1. 计算当前 token 的 c 并更新 cache
c = self.W_c(x) # x shape: (B, 1, d_model) -> c shape: (B, 1, d_latent)
# if hasattr(self, 'c_norm'):
c = self.c_norm(c)
c_full = torch.cat([cache, c], dim=1) if cache is not None else c
# 2. 计算当前 token 的 Q
q = self.W_q(x) # (B, 1, inner_dim)
q = self.q_norm(q) # Apply q_norm: 这一步是安全的
q_current = q.view(batch_size, 1, self.num_heads, self.d_head)
# 3. 核心优化:实现 q' = q @ Wk.T
# q_current: (B, 1, H, D_h)
# W_k.weight: (H*D_h, D_l) -> (H, D_h, D_l)
# q_prime: (B, 1, H, D_l)
W_k_reshaped = self.W_k.weight.view(self.num_heads, self.d_head, self.d_latent)
q_prime = torch.einsum('bqhd,hdl->bqhl', q_current, W_k_reshaped)
# 4. 计算注意力分数 q' @ c.T
# q_prime: (B, 1, H, D_l)
# c_full: (B, L, D_l)
# attn_scores: (B, 1, H, L)
attn_scores = torch.einsum('bqhl,bkl->bqhk', q_prime, c_full) / math.sqrt(self.d_head)
# 5. 计算权重并对 c 进行加权求和 ("先求和")
# attn_weights: (B, 1, H, L)
# intermediate: (B, 1, H, D_l)
attn_weights = F.softmax(attn_scores - attn_scores.max(dim=-1, keepdim=True)[0] , dim=-1)
intermediate = torch.einsum('bqhk,bkl->bqhl', attn_weights, c_full)
# 6. 用 Wv 对中间结果进行变换 ("后变换")
# W_v.weight: (H*D_h, D_l) -> (H, D_h, D_l)
# head_output: (B, 1, H, D_h)
W_v_reshaped = self.W_v.weight.view(self.num_heads, self.d_head, self.d_latent)
head_output = torch.einsum('bqhl,hdl->bqhd', intermediate, W_v_reshaped)
# 7. 合并 head 并输出
combined_heads = head_output.contiguous().view(batch_size, 1, -1)
output = self.W_o(combined_heads)
return output, c_full
# ------------------------------------------------------------------
# 路径 2: 并行处理 (训练 或 推理的 Prefill 阶段)
# ------------------------------------------------------------------
c = self.W_c(x) # (B, L, d_latent)
q = self.W_q(x)
q = q.view(batch_size, seq_len, self.num_heads, self.d_head)
k = self.W_k(c).view(batch_size, seq_len, self.num_heads, self.d_head)
v = self.W_v(c).view(batch_size, seq_len, self.num_heads, self.d_head)
# 使用 PyTorch 2.0+ 的高效实现,is_causal=True 会自动应用因果掩码
# (B, L, H, D) -> (B, H, L, D) for SDPA
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# F.scaled_dot_product_attention 内部处理 softmax 和缩放
head_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=True)
combined_heads = head_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
output = self.W_o(combined_heads)
cache_to_return = c if use_cache else None
return output, cache_to_return
helm/helm/modules/hmla.py
Lines 126 to 176 in e8b4821
There, I didn't see anything like the example below:
Example NoPE MLA Code
I mean, HMLA doesn't look like it does anything about the KV cache; it looks more like MHA without a KV-cache inference path.