diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 72ce31438d6..07d2cf98aec 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -898,18 +898,11 @@ def _init_wo(self, config: ModelArgs) -> None:
 
     def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
         if self.use_qk_norm:
-            if getattr(config, "qk_norm_affine", True):
-                self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
-                if is_kv_shared_layer:
-                    self.k_norm = nn.Identity()
-                else:
-                    self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            if is_kv_shared_layer:
+                self.k_norm = nn.Identity()
             else:
-                self.q_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
-                if is_kv_shared_layer:
-                    self.k_norm = nn.Identity()
-                else:
-                    self.k_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
+                self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
         else:
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()