From 04a94dac7ae27dc3c074b1a568215cbc77b89bdc Mon Sep 17 00:00:00 2001 From: Jonas Rohweder Date: Wed, 15 Jan 2025 17:03:07 +0100 Subject: [PATCH 1/5] possible implementation --- .../components/abstract_attention.py | 20 +-- .../components/rotary_embeddings.py | 132 ++++++++++++++++++ .../factories/rotary_embedding_factory.py | 11 ++ 3 files changed, 150 insertions(+), 13 deletions(-) create mode 100644 transformer_lens/components/rotary_embeddings.py create mode 100644 transformer_lens/factories/rotary_embedding_factory.py diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 009d2cfb8..9c9a647a5 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -15,8 +15,7 @@ from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear -from transformer_lens.utils import get_offset_position_ids - +from transformer_lens.factories.rotary_embedding_factory import RotaryEmbeddingFactory if is_bitsandbytes_available(): import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit @@ -123,14 +122,7 @@ def __init__( self.hook_rot_q = HookPoint() if self.cfg.rotary_dim is None: # keep mypy happy raise ValueError("Rotary dim must be provided for rotary positional embeddings") - sin, cos = self.calculate_sin_cos_rotary( - self.cfg.rotary_dim, - self.cfg.n_ctx, - base=self.cfg.rotary_base, - dtype=self.cfg.dtype, - ) - self.register_buffer("rotary_sin", sin) - self.register_buffer("rotary_cos", cos) + self.rotary_module = RotaryEmbeddingFactory.create_rotary(self.cfg) elif self.cfg.positional_embedding_type == "alibi": # ALiBi bias wil be constructed on the first forward pass. # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. @@ -204,10 +196,12 @@ def forward( kv_cache_pos_offset = 0 if self.cfg.positional_embedding_type == "rotary": - q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask)) + q = self.hook_rot_q( + self.rotary_module(q, kv_cache_pos_offset, attention_mask) + ) k = self.hook_rot_k( - self.apply_rotary(k, 0, attention_mask) - ) # keys are cached so no offset + self.rotary_module(k, 0, attention_mask) + ) if self.cfg.dtype not in [torch.float32, torch.float64]: # If using 16 bits, increase the precision to avoid numerical instabilities diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py new file mode 100644 index 000000000..f8f8ea87d --- /dev/null +++ b/transformer_lens/components/rotary_embeddings.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from jaxtyping import Float, Int +from typing import Dict, Optional, Tuple, Union +from transformer_lens.utils import get_offset_position_ids +import einops + +class RotaryEmbedding(nn.Module): + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + sin, cos = self.calculate_sin_cos_rotary( + rotary_dim=cfg.rotary_dim, + n_ctx=cfg.n_ctx, + base=cfg.rotary_base, + dtype=cfg.dtype + ) + self.register_buffer("rotary_sin", sin) + self.register_buffer("rotary_cos", cos) + + def calculate_sin_cos_rotary( + self, + rotary_dim: int, + n_ctx: int, + base: int = 10000, + dtype: torch.dtype = torch.float32, + ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: + """ + Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details + + Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. + To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. + """ + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + dim = torch.arange(rotary_dim // 2, dtype=high_precision) + freq = base ** (dim / (rotary_dim / 2)) + if self.cfg.rotary_adjacent_pairs: + freq = einops.repeat(freq, "d -> (d 2)") + else: + freq = einops.repeat(freq, "d -> (2 d)") + # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency + angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + def forward( + self, + x: Float[torch.Tensor, "batch pos head_index d_head"], + past_kv_pos_offset=0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos head_index d_head"]: + # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) + x_pos = x.size(1) + x_rot = x[..., : self.cfg.rotary_dim] + x_pass = x[..., self.cfg.rotary_dim :] + x_flip = self.rotate_every_two(x_rot) + + if attention_mask is None: + rotary_cos = self.rotary_cos[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + rotary_sin = self.rotary_sin[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + x_rotated = x_rot * rotary_cos + x_flip * rotary_sin + else: + offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) + offset_position_ids = offset_position_ids.to(self.rotary_cos.device) + mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] + mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] + x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin + + return torch.cat([x_rotated, x_pass], dim=-1) + + + def rotate_every_two( + self, x: Float[torch.Tensor, "... rotary_dim"] + ) -> Float[torch.Tensor, "... rotary_dim"]: + """ + Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] + + The final axis of x must have even length. + + GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. + """ + rot_x = x.clone() + if self.cfg.rotary_adjacent_pairs: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + else: + n = x.size(-1) // 2 + rot_x[..., :n] = -x[..., n:] + rot_x[..., n:] = x[..., :n] + + return rot_x + + + +class DynamicNTKScalingRotary(RotaryEmbedding): + + def calculate_sin_cos(self, rotary_dim, n_ctx, base, dtype, factor, low_freq_factor, high_freq_factor): + # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 + # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) + ) + factor = self.cfg.NTK_by_parts_factor + low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor + high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor + old_context_len = n_ctx + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + freq = 1 / inv_freq_llama + if self.cfg.rotary_adjacent_pairs: + freq = einops.repeat(freq, "d -> (d 2)") + angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + diff --git a/transformer_lens/factories/rotary_embedding_factory.py b/transformer_lens/factories/rotary_embedding_factory.py new file mode 100644 index 000000000..a5fb62175 --- /dev/null +++ b/transformer_lens/factories/rotary_embedding_factory.py @@ -0,0 +1,11 @@ +from transformer_lens.components.rotary_embeddings import RotaryEmbedding, DynamicNTKScalingRotary +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +class RotaryEmbeddingFactory: + @staticmethod + def create_rotary(cfg: HookedTransformerConfig) -> RotaryEmbedding: + if cfg.use_NTK_by_parts_rope: + return DynamicNTKScalingRotary(cfg) + else: + return RotaryEmbedding(cfg) From db236822d7d5c301ae705e2f27089943a6629e7f Mon Sep 17 00:00:00 2001 From: Jonas Rohweder Date: Wed, 15 Jan 2025 19:31:57 +0100 Subject: [PATCH 2/5] fully split rotary embeddings with attention --- .../components/abstract_attention.py | 61 ++------------- .../components/rotary_embeddings.py | 78 +++++++++---------- .../factories/rotary_embedding_factory.py | 5 +- 3 files changed, 48 insertions(+), 96 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 9c9a647a5..d3658c4cc 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -11,11 +11,12 @@ from transformers.utils import is_bitsandbytes_available from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.factories.rotary_embedding_factory import RotaryEmbeddingFactory from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry from transformer_lens.utilities.attention import complex_attn_linear, simple_attn_linear -from transformer_lens.factories.rotary_embedding_factory import RotaryEmbeddingFactory + if is_bitsandbytes_available(): import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit @@ -122,7 +123,7 @@ def __init__( self.hook_rot_q = HookPoint() if self.cfg.rotary_dim is None: # keep mypy happy raise ValueError("Rotary dim must be provided for rotary positional embeddings") - self.rotary_module = RotaryEmbeddingFactory.create_rotary(self.cfg) + self.rotary_module = RotaryEmbeddingFactory.create_rotary(self.cfg) elif self.cfg.positional_embedding_type == "alibi": # ALiBi bias wil be constructed on the first forward pass. # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. @@ -196,12 +197,8 @@ def forward( kv_cache_pos_offset = 0 if self.cfg.positional_embedding_type == "rotary": - q = self.hook_rot_q( - self.rotary_module(q, kv_cache_pos_offset, attention_mask) - ) - k = self.hook_rot_k( - self.rotary_module(k, 0, attention_mask) - ) + q = self.hook_rot_q(self.rotary_module(q, kv_cache_pos_offset, attention_mask)) + k = self.hook_rot_k(self.rotary_module(k, 0, attention_mask)) if self.cfg.dtype not in [torch.float32, torch.float64]: # If using 16 bits, increase the precision to avoid numerical instabilities @@ -518,55 +515,7 @@ def calculate_sin_cos_rotary( angles = pos[:, None] / freq[None, :] return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - def rotate_every_two( - self, x: Float[torch.Tensor, "... rotary_dim"] - ) -> Float[torch.Tensor, "... rotary_dim"]: - """ - Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] - - The final axis of x must have even length. - - GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. - """ - rot_x = x.clone() - if self.cfg.rotary_adjacent_pairs: - rot_x[..., ::2] = -x[..., 1::2] - rot_x[..., 1::2] = x[..., ::2] - else: - n = x.size(-1) // 2 - rot_x[..., :n] = -x[..., n:] - rot_x[..., n:] = x[..., :n] - - return rot_x - - def apply_rotary( - self, - x: Float[torch.Tensor, "batch pos head_index d_head"], - past_kv_pos_offset=0, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ) -> Float[torch.Tensor, "batch pos head_index d_head"]: - # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) - x_pos = x.size(1) - x_rot = x[..., : self.cfg.rotary_dim] - x_pass = x[..., self.cfg.rotary_dim :] - x_flip = self.rotate_every_two(x_rot) - - if attention_mask is None: - rotary_cos = self.rotary_cos[ - None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : - ] - rotary_sin = self.rotary_sin[ - None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : - ] - x_rotated = x_rot * rotary_cos + x_flip * rotary_sin - else: - offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) - offset_position_ids = offset_position_ids.to(self.rotary_cos.device) - mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] - mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] - x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin - return torch.cat([x_rotated, x_pass], dim=-1) @staticmethod def create_alibi_slope( diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py index f8f8ea87d..8b9874605 100644 --- a/transformer_lens/components/rotary_embeddings.py +++ b/transformer_lens/components/rotary_embeddings.py @@ -1,20 +1,21 @@ +import math +from typing import Optional, Tuple + +import einops import torch import torch.nn as nn -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from jaxtyping import Float, Int -from typing import Dict, Optional, Tuple, Union + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import get_offset_position_ids -import einops + class RotaryEmbedding(nn.Module): def __init__(self, cfg: HookedTransformerConfig): super().__init__() self.cfg = cfg sin, cos = self.calculate_sin_cos_rotary( - rotary_dim=cfg.rotary_dim, - n_ctx=cfg.n_ctx, - base=cfg.rotary_base, - dtype=cfg.dtype + rotary_dim=cfg.rotary_dim, n_ctx=cfg.n_ctx, base=cfg.rotary_base, dtype=cfg.dtype ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) @@ -36,14 +37,11 @@ def calculate_sin_cos_rotary( pos = torch.arange(n_ctx, dtype=high_precision) dim = torch.arange(rotary_dim // 2, dtype=high_precision) freq = base ** (dim / (rotary_dim / 2)) - if self.cfg.rotary_adjacent_pairs: - freq = einops.repeat(freq, "d -> (d 2)") - else: - freq = einops.repeat(freq, "d -> (2 d)") + freq = einops.repeat(freq, "d -> (d 2)") # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency angles = pos[:, None] / freq[None, :] return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - + def forward( self, x: Float[torch.Tensor, "batch pos head_index d_head"], @@ -73,35 +71,40 @@ def forward( return torch.cat([x_rotated, x_pass], dim=-1) - def rotate_every_two( - self, x: Float[torch.Tensor, "... rotary_dim"] - ) -> Float[torch.Tensor, "... rotary_dim"]: - """ - Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] - - The final axis of x must have even length. - - GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. - """ - rot_x = x.clone() - if self.cfg.rotary_adjacent_pairs: - rot_x[..., ::2] = -x[..., 1::2] - rot_x[..., 1::2] = x[..., ::2] - else: - n = x.size(-1) // 2 - rot_x[..., :n] = -x[..., n:] - rot_x[..., n:] = x[..., :n] + self, x: Float[torch.Tensor, "... rotary_dim"] + ) -> Float[torch.Tensor, "... rotary_dim"]: + """ + Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] - return rot_x + The final axis of x must have even length. - + GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. + """ + rot_x = x.clone() + if self.cfg.rotary_adjacent_pairs: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + else: + n = x.size(-1) // 2 + rot_x[..., :n] = -x[..., n:] + rot_x[..., n:] = x[..., :n] + return rot_x class DynamicNTKScalingRotary(RotaryEmbedding): - - def calculate_sin_cos(self, rotary_dim, n_ctx, base, dtype, factor, low_freq_factor, high_freq_factor): - # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 + def calculate_sin_cos_rotary( + self, + rotary_dim: int, + n_ctx: int, + base: int = 10000, + dtype: torch.dtype = torch.float32, + ): + # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 + print("Using NTK-by-Parts Rotary Embedding") + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + inv_freq = 1.0 / ( base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) ) @@ -124,9 +127,6 @@ def calculate_sin_cos(self, rotary_dim, n_ctx, base, dtype, factor, low_freq_fac is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) freq = 1 / inv_freq_llama - if self.cfg.rotary_adjacent_pairs: - freq = einops.repeat(freq, "d -> (d 2)") + freq = einops.repeat(freq, "d -> (d 2)") angles = pos[:, None] / freq[None, :] return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - - diff --git a/transformer_lens/factories/rotary_embedding_factory.py b/transformer_lens/factories/rotary_embedding_factory.py index a5fb62175..b53d91db5 100644 --- a/transformer_lens/factories/rotary_embedding_factory.py +++ b/transformer_lens/factories/rotary_embedding_factory.py @@ -1,4 +1,7 @@ -from transformer_lens.components.rotary_embeddings import RotaryEmbedding, DynamicNTKScalingRotary +from transformer_lens.components.rotary_embeddings import ( + DynamicNTKScalingRotary, + RotaryEmbedding, +) from transformer_lens.HookedTransformerConfig import HookedTransformerConfig From 42a9f31cc1ee1a70097854852818cdd222531671 Mon Sep 17 00:00:00 2001 From: Jonas Rohweder Date: Wed, 15 Jan 2025 19:49:03 +0100 Subject: [PATCH 3/5] ran format --- transformer_lens/components/abstract_attention.py | 2 -- transformer_lens/components/rotary_embeddings.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index d3658c4cc..aa8ed7bf2 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -515,8 +515,6 @@ def calculate_sin_cos_rotary( angles = pos[:, None] / freq[None, :] return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - - @staticmethod def create_alibi_slope( n_ctx: int, device: Optional[Union[str, torch.device]] = None diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py index 8b9874605..ad33cd750 100644 --- a/transformer_lens/components/rotary_embeddings.py +++ b/transformer_lens/components/rotary_embeddings.py @@ -89,7 +89,8 @@ def rotate_every_two( n = x.size(-1) // 2 rot_x[..., :n] = -x[..., n:] rot_x[..., n:] = x[..., :n] - return rot_x + return rot_x + class DynamicNTKScalingRotary(RotaryEmbedding): def calculate_sin_cos_rotary( From 9fba9ebcff5e49326d6bc7b135b505c40c321ddc Mon Sep 17 00:00:00 2001 From: Jonas Rohweder Date: Thu, 16 Jan 2025 14:19:04 +0100 Subject: [PATCH 4/5] guarantee cos and sin are available from AbstractAttention --- .../components/test_abstract_attention.py | 31 +++++++++- transformer_lens/components/__init__.py | 1 + .../components/abstract_attention.py | 61 +++---------------- .../components/rotary_embeddings.py | 7 ++- 4 files changed, 43 insertions(+), 57 deletions(-) diff --git a/tests/unit/components/test_abstract_attention.py b/tests/unit/components/test_abstract_attention.py index 7820c1690..912188169 100644 --- a/tests/unit/components/test_abstract_attention.py +++ b/tests/unit/components/test_abstract_attention.py @@ -1,6 +1,7 @@ import torch -from transformer_lens.components import AbstractAttention +from transformer_lens.components import AbstractAttention, RotaryEmbedding +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig def test_create_alibi_slope(): @@ -38,3 +39,31 @@ def test_create_alibi_bias(): assert torch.equal( torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1) ) + + +def test_rotary_attribute_access(): + cfg = HookedTransformerConfig( + n_layers=12, + d_model=512, + n_ctx=1024, + d_head=64, + n_heads=8, + load_in_4bit=False, + dtype=torch.float32, + act_fn="relu", + rotary_dim=64, + rotary_base=10000, + rotary_adjacent_pairs=True, + ) + + rotary_module = RotaryEmbedding(cfg) + + class DummyAttention(AbstractAttention): + def __init__(self): + super().__init__(cfg) + self.rotary_module = rotary_module + + attention = DummyAttention() + + assert torch.equal(attention.rotary_sin, rotary_module.rotary_sin), "rotary_sin does not match!" + assert torch.equal(attention.rotary_cos, rotary_module.rotary_cos), "rotary_cos does not match!" diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py index 3b908fefb..4652ece53 100644 --- a/transformer_lens/components/__init__.py +++ b/transformer_lens/components/__init__.py @@ -22,6 +22,7 @@ from .grouped_query_attention import GroupedQueryAttention from .mlps.gated_mlp import GatedMLP from .mlps.mlp import MLP +from .rotary_embeddings import RotaryEmbedding # Interdependent modules from .bert_block import BertBlock diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index aa8ed7bf2..df8a0250b 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -1,4 +1,3 @@ -import math from abc import ABC from typing import Dict, Optional, Tuple, Union @@ -133,6 +132,14 @@ def __init__( # will be overwritten by the child T5Attention class self.has_relative_attention_bias = False + @property + def rotary_sin(self): + return self.rotary_module.rotary_sin + + @property + def rotary_cos(self): + return self.rotary_module.rotary_cos + @property def OV(self) -> FactoredMatrix: """ @@ -463,58 +470,6 @@ def apply_causal_mask( attn_scores = attn_scores.to(final_mask.device) return torch.where(final_mask, attn_scores, self.IGNORE) - def calculate_sin_cos_rotary( - self, - rotary_dim: int, - n_ctx: int, - base: int = 10000, - dtype: torch.dtype = torch.float32, - ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: - """ - Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details - - Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. - To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. - """ - high_precision = torch.float32 if dtype != torch.float64 else torch.float64 - pos = torch.arange(n_ctx, dtype=high_precision) - dim = torch.arange(rotary_dim // 2, dtype=high_precision) - - # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 - # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 - if self.cfg.use_NTK_by_parts_rope: - inv_freq = 1.0 / ( - base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) - ) - factor = self.cfg.NTK_by_parts_factor - low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor - high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor - old_context_len = n_ctx - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - wavelen = 2 * math.pi / inv_freq - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) - smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smoothed_inv_freq = ( - 1 - smooth_factor - ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - freq = 1 / inv_freq_llama - else: - freq = base ** (dim / (rotary_dim / 2)) - if self.cfg.rotary_adjacent_pairs: - freq = einops.repeat(freq, "d -> (d 2)") - else: - freq = einops.repeat(freq, "d -> (2 d)") - # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency - angles = pos[:, None] / freq[None, :] - return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - @staticmethod def create_alibi_slope( n_ctx: int, device: Optional[Union[str, torch.device]] = None diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py index ad33cd750..38aa243a3 100644 --- a/transformer_lens/components/rotary_embeddings.py +++ b/transformer_lens/components/rotary_embeddings.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, cast import einops import torch @@ -13,9 +13,10 @@ class RotaryEmbedding(nn.Module): def __init__(self, cfg: HookedTransformerConfig): super().__init__() - self.cfg = cfg + self.cfg: HookedTransformerConfig = cfg + rotary_dim = cast(int, self.cfg.rotary_dim) sin, cos = self.calculate_sin_cos_rotary( - rotary_dim=cfg.rotary_dim, n_ctx=cfg.n_ctx, base=cfg.rotary_base, dtype=cfg.dtype + rotary_dim=rotary_dim, n_ctx=cfg.n_ctx, base=cfg.rotary_base, dtype=cfg.dtype ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) From 51c123a0c1e96aaefa6c846f57b13bff5ef56246 Mon Sep 17 00:00:00 2001 From: Jonas Rohw <40701485+jonasrohw@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:49:13 +0100 Subject: [PATCH 5/5] Remove print statement --- transformer_lens/components/rotary_embeddings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_lens/components/rotary_embeddings.py b/transformer_lens/components/rotary_embeddings.py index 38aa243a3..93fef12e7 100644 --- a/transformer_lens/components/rotary_embeddings.py +++ b/transformer_lens/components/rotary_embeddings.py @@ -103,7 +103,6 @@ def calculate_sin_cos_rotary( ): # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 - print("Using NTK-by-Parts Rotary Embedding") high_precision = torch.float32 if dtype != torch.float64 else torch.float64 pos = torch.arange(n_ctx, dtype=high_precision)