From f241feb09be27975a2b73403f5aaefa4ada2723d Mon Sep 17 00:00:00 2001 From: Israel Adewuyi Date: Tue, 21 Jan 2025 21:49:02 +0300 Subject: [PATCH 1/5] added list of apple openelm models --- transformer_lens/loading_from_pretrained.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 17d32e8c7..dca1ee3ac 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -250,6 +250,14 @@ "google-t5/t5-base", "google-t5/t5-large", "ai-forever/mGPT", + "apple/OpenELM-270M", + "apple/OpenELM-450M", + "apple/OpenELM-1_1B", + "apple/OpenELM-3B", + "apple/OpenELM-270M-Instruct", + "apple/OpenELM-450M-Instruct", + "apple/OpenELM-1_1B-Instruct", + "apple/OpenELM-3B-Instruct", ] """Official model names for models on HuggingFace.""" From e0037a1c87b13e323f86ce305ad7e0f24855e037 Mon Sep 17 00:00:00 2001 From: Israel Adewuyi Date: Mon, 3 Feb 2025 22:13:52 +0300 Subject: [PATCH 2/5] Added weight conversions for OpenELM --- .../pretrained/weight_conversions/__init__.py | 1 + .../pretrained/weight_conversions/openelm.py | 53 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 transformer_lens/pretrained/weight_conversions/openelm.py diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index b13850ee0..c9eda6c61 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -18,3 +18,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .openelm import convert_openelm_weights diff --git a/transformer_lens/pretrained/weight_conversions/openelm.py b/transformer_lens/pretrained/weight_conversions/openelm.py new file mode 100644 index 000000000..850d76d62 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/openelm.py @@ -0,0 +1,53 @@ +import torch +import einops + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + +def convert_openelm_weights(openelm, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.d_mlp is not None + assert cfg.n_key_value_heads is not None + + state_dict["embed.W_E"] = openelm.transformer.token_embeddings.weight + + for l in range(cfg.n_layers): + WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[:(cfg.n_query_heads[l] * cfg.d_head)] + WK = openelm.transformer.layers[l].attn.qkv_proj.weight[(cfg.n_query_heads[l] * cfg.d_head) : ((cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head)] + WV = openelm.transformer.layers[l].attn.qkv_proj.weight[-cfg.n_key_value_heads[l] * cfg.d_head:] + + WQ = einops.rearrange(WQ, "(n h) m->n m h", n=cfg.n_query_heads[l]) + WK = einops.rearrange(WK, "(n h) m->n m h", n=cfg.n_key_value_heads[l]) + WV = einops.rearrange(WV, "(n h) m->n m h", n=cfg.n_key_value_heads[l]) + + state_dict[f"blocks.{l}.attn.W_Q"] = WQ + state_dict[f"blocks.{l}.attn._W_K"] = WK + state_dict[f"blocks.{l}.attn._W_V"] = WV + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype + ) + + WO = openelm.transformer.layers[l].attn.out_proj.weight + WO = einops.rearrange(WO, "m (n h)->n h m", n=cfg.n_query_heads[l]) + state_dict[f"blocks.{l}.attn.W_O"] = WO + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.ln2.w"] = openelm.transformer.layers[l].attn_norm.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = openelm.transformer.layers[l].ffn.proj_1.weight[:cfg.d_mlps[l], :].T + state_dict[f"blocks.{l}.mlp.W_gate"] = openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l]:, :].T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlps[l], dtype=cfg.dtype) + state_dict[f"blocks.{l}.mlp.W_out"] = openelm.transformer.layers[l].ffn.proj_2.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.ln3.w"] = openelm.transformer.layers[l].ffn_norm.weight + + state_dict["ln_final.w"] = openelm.transformer.norm.weight + + state_dict["unembed.W_U"] = openelm.transformer.token_embeddings.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) \ No newline at end of file From 4d0dab60b756eed9ff961d162af1d21256ee8ece Mon Sep 17 00:00:00 2001 From: Israel Adewuyi Date: Mon, 3 Feb 2025 22:25:19 +0300 Subject: [PATCH 3/5] Added OpenELM to loading_from_pretrained --- transformer_lens/loading_from_pretrained.py | 51 ++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index dca1ee3ac..3b2632521 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -41,6 +41,7 @@ convert_qwen2_weights, convert_qwen_weights, convert_t5_weights, + convert_openelm_weights, ) OFFICIAL_MODEL_NAMES = [ @@ -1345,6 +1346,49 @@ def convert_hf_model_config(model_name: str, **kwargs): "parallel_attn_mlp": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } + elif architecture == "OpenELMForCausalLM": + def make_divisible( + v: Union[float, int], + divisor: Optional[int] = 8, + min_value: Optional[Union[float, int]] = None, + ) -> Union[float, int]: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by the divisor + It can be seen at: + https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62 + Args: + v: input value + divisor: default to 8 + min_value: minimum divisor value + Returns: + new_v: new divisible value + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + cfg_dict = { + "d_model": hf_config.model_dim, + "d_head": hf_config.head_dim, + "n_heads": 64,# is this variable too? , + "n_layers": hf_config.num_transformer_layers, + "n_ctx": hf_config.max_context_length, + "eps": 23, # what is going on here?? + "d_vocab": hf_config.vocab_size, + "act_fn": "silu", + "initializer_range": hf_config.initializer_range, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "trust_remote_code": True, + "n_key_value_heads": hf_config.num_kv_heads, + "n_query_heads": hf_config.num_query_heads, + "d_mlps": [(2 * int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor))) for val in hf_config.ffn_multipliers], + } elif official_model_name.startswith("google/gemma-2b"): # Architecture for Gemma 2b and Gemma 2b Instruct models @@ -1496,7 +1540,10 @@ def convert_hf_model_config(model_name: str, **kwargs): # All of these models use LayerNorm cfg_dict["original_architecture"] = architecture # The name such that AutoTokenizer.from_pretrained works - cfg_dict["tokenizer_name"] = official_model_name + if architecture == "OpenELMForCausalLM": + cfg_dict["tokenizer_name"] = "meta-llama/Llama-2-7b-hf" + else: + cfg_dict["tokenizer_name"] = official_model_name if kwargs.get("trust_remote_code", False): cfg_dict["trust_remote_code"] = True return cfg_dict @@ -1890,6 +1937,8 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma2ForCausalLM": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture == "OpenELMForCausalLM": + state_dict = convert_openelm_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." From 2006d6a7ea1405718e488d9e153f5269623d3aca Mon Sep 17 00:00:00 2001 From: Israel Adewuyi Date: Mon, 3 Feb 2025 22:27:15 +0300 Subject: [PATCH 4/5] Extended args of the HookedTransformerConfig class --- transformer_lens/HookedTransformerConfig.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 4458705de..eff721816 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -243,7 +243,7 @@ class HookedTransformerConfig: default_prepend_bos: bool = True dtype: torch.dtype = torch.float32 tokenizer_prepends_bos: Optional[bool] = None - n_key_value_heads: Optional[int] = None + n_key_value_heads: Optional[List[int]] = None post_embedding_ln: bool = False rotary_base: int = 10000 trust_remote_code: bool = False @@ -262,6 +262,8 @@ class HookedTransformerConfig: NTK_by_parts_low_freq_factor: float = 1.0 NTK_by_parts_high_freq_factor: float = 4.0 NTK_by_parts_factor: float = 8.0 + n_query_heads: Optional[List[int]] = None + d_mlps: Optional[List[int]] = None def __post_init__(self): if self.n_heads == -1: From 1b214aece3cd3fc057217cd5485b4bdbc8eea266 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 19 Jun 2025 20:36:13 +0200 Subject: [PATCH 5/5] ran format --- transformer_lens/loading_from_pretrained.py | 18 ++++++----- .../pretrained/weight_conversions/openelm.py | 31 ++++++++++++++----- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 520e5c41c..7e281ddbf 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -37,6 +37,7 @@ convert_neel_solu_old_weights, convert_neo_weights, convert_neox_weights, + convert_openelm_weights, convert_opt_weights, convert_phi3_weights, convert_phi_weights, @@ -44,7 +45,6 @@ convert_qwen3_weights, convert_qwen_weights, convert_t5_weights, - convert_openelm_weights, ) OFFICIAL_MODEL_NAMES = [ @@ -1446,6 +1446,7 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, } elif architecture == "OpenELMForCausalLM": + def make_divisible( v: Union[float, int], divisor: Optional[int] = 8, @@ -1470,23 +1471,26 @@ def make_divisible( if new_v < 0.9 * v: new_v += divisor return new_v - + cfg_dict = { "d_model": hf_config.model_dim, "d_head": hf_config.head_dim, - "n_heads": 64,# is this variable too? , + "n_heads": 64, # is this variable too? , "n_layers": hf_config.num_transformer_layers, "n_ctx": hf_config.max_context_length, - "eps": 23, # what is going on here?? + "eps": 23, # what is going on here?? "d_vocab": hf_config.vocab_size, "act_fn": "silu", - "initializer_range": hf_config.initializer_range, + "initializer_range": hf_config.initializer_range, "normalization_type": "RMS", - "positional_embedding_type": "rotary", + "positional_embedding_type": "rotary", "trust_remote_code": True, "n_key_value_heads": hf_config.num_kv_heads, "n_query_heads": hf_config.num_query_heads, - "d_mlps": [(2 * int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor))) for val in hf_config.ffn_multipliers], + "d_mlps": [ + (2 * int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor))) + for val in hf_config.ffn_multipliers + ], } elif official_model_name.startswith("google/gemma-2b"): diff --git a/transformer_lens/pretrained/weight_conversions/openelm.py b/transformer_lens/pretrained/weight_conversions/openelm.py index 850d76d62..129e9fad4 100644 --- a/transformer_lens/pretrained/weight_conversions/openelm.py +++ b/transformer_lens/pretrained/weight_conversions/openelm.py @@ -1,8 +1,9 @@ -import torch import einops +import torch from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + def convert_openelm_weights(openelm, cfg: HookedTransformerConfig): state_dict = {} @@ -12,9 +13,17 @@ def convert_openelm_weights(openelm, cfg: HookedTransformerConfig): state_dict["embed.W_E"] = openelm.transformer.token_embeddings.weight for l in range(cfg.n_layers): - WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[:(cfg.n_query_heads[l] * cfg.d_head)] - WK = openelm.transformer.layers[l].attn.qkv_proj.weight[(cfg.n_query_heads[l] * cfg.d_head) : ((cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head)] - WV = openelm.transformer.layers[l].attn.qkv_proj.weight[-cfg.n_key_value_heads[l] * cfg.d_head:] + WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[ + : (cfg.n_query_heads[l] * cfg.d_head) + ] + WK = openelm.transformer.layers[l].attn.qkv_proj.weight[ + (cfg.n_query_heads[l] * cfg.d_head) : ( + (cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head + ) + ] + WV = openelm.transformer.layers[l].attn.qkv_proj.weight[ + -cfg.n_key_value_heads[l] * cfg.d_head : + ] WQ = einops.rearrange(WQ, "(n h) m->n m h", n=cfg.n_query_heads[l]) WK = einops.rearrange(WK, "(n h) m->n m h", n=cfg.n_key_value_heads[l]) @@ -39,15 +48,21 @@ def convert_openelm_weights(openelm, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) state_dict[f"blocks.{l}.ln2.w"] = openelm.transformer.layers[l].attn_norm.weight - state_dict[f"blocks.{l}.mlp.W_in"] = openelm.transformer.layers[l].ffn.proj_1.weight[:cfg.d_mlps[l], :].T - state_dict[f"blocks.{l}.mlp.W_gate"] = openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l]:, :].T + state_dict[f"blocks.{l}.mlp.W_in"] = ( + openelm.transformer.layers[l].ffn.proj_1.weight[: cfg.d_mlps[l], :].T + ) + state_dict[f"blocks.{l}.mlp.W_gate"] = ( + openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l] :, :].T + ) state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlps[l], dtype=cfg.dtype) state_dict[f"blocks.{l}.mlp.W_out"] = openelm.transformer.layers[l].ffn.proj_2.weight.T - state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype) + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype + ) state_dict[f"blocks.{l}.mlp.ln3.w"] = openelm.transformer.layers[l].ffn_norm.weight state_dict["ln_final.w"] = openelm.transformer.norm.weight state_dict["unembed.W_U"] = openelm.transformer.token_embeddings.weight.T - state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) \ No newline at end of file + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)