2 changes: 2 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -264,6 +264,8 @@ class HookedTransformerConfig:
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
NTK_original_ctx_len: int = 8192
n_query_heads: Optional[List[int]] = None
d_mlps: Optional[List[int]] = None

def __post_init__(self):
if self.n_heads == -1:
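Both new fields are per-layer lists because OpenELM varies the number of query heads and the MLP width from layer to layer. A rough sketch of what they end up holding, with purely illustrative values that are not taken from any real checkpoint:

# Illustrative only: per-layer values for a hypothetical 4-layer OpenELM-style config.
n_query_heads = [12, 12, 16, 16]  # number of query heads in each layer
d_mlps = [768, 1280, 2560, 5120]  # MLP hidden width of each layer, replacing a single d_mlp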
63 changes: 62 additions & 1 deletion transformer_lens/loading_from_pretrained.py
@@ -37,6 +37,7 @@
convert_neel_solu_old_weights,
convert_neo_weights,
convert_neox_weights,
convert_openelm_weights,
convert_opt_weights,
convert_phi3_weights,
convert_phi_weights,
@@ -263,6 +264,14 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"apple/OpenELM-270M",
"apple/OpenELM-450M",
"apple/OpenELM-1_1B",
"apple/OpenELM-3B",
"apple/OpenELM-270M-Instruct",
"apple/OpenELM-450M-Instruct",
"apple/OpenELM-1_1B-Instruct",
"apple/OpenELM-3B-Instruct",
]
"""Official model names for models on HuggingFace."""

@@ -1436,6 +1445,53 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
"parallel_attn_mlp": False,
"rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
}
elif architecture == "OpenELMForCausalLM":

def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by the divisor
It can be seen at:
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
Args:
v: input value
divisor: default to 8
min_value: minimum divisor value
Returns:
new_v: new divisible value
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v

cfg_dict = {
"d_model": hf_config.model_dim,
"d_head": hf_config.head_dim,
"n_heads": 64, # is this variable too? ,
"n_layers": hf_config.num_transformer_layers,
"n_ctx": hf_config.max_context_length,
"eps": 23, # what is going on here??
"d_vocab": hf_config.vocab_size,
"act_fn": "silu",
"initializer_range": hf_config.initializer_range,
"normalization_type": "RMS",
"positional_embedding_type": "rotary",
"trust_remote_code": True,
"n_key_value_heads": hf_config.num_kv_heads,
"n_query_heads": hf_config.num_query_heads,
"d_mlps": [
(2 * int(make_divisible(val * hf_config.model_dim, hf_config.ffn_dim_divisor)))
for val in hf_config.ffn_multipliers
],
}

elif official_model_name.startswith("google/gemma-2b"):
# Architecture for Gemma 2b and Gemma 2b Instruct models
@@ -1587,7 +1643,10 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
# All of these models use LayerNorm
cfg_dict["original_architecture"] = architecture
# The name such that AutoTokenizer.from_pretrained works
cfg_dict["tokenizer_name"] = official_model_name
if architecture == "OpenELMForCausalLM":
cfg_dict["tokenizer_name"] = "meta-llama/Llama-2-7b-hf"
else:
cfg_dict["tokenizer_name"] = official_model_name
if kwargs.get("trust_remote_code", False):
cfg_dict["trust_remote_code"] = True
return cfg_dict
@@ -1986,6 +2045,8 @@ def get_pretrained_state_dict(
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "Gemma2ForCausalLM":
state_dict = convert_gemma_weights(hf_model, cfg)
elif cfg.original_architecture == "OpenELMForCausalLM":
state_dict = convert_openelm_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
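As a sanity check on the config branch above, here is a small self-contained sketch of how the per-layer MLP widths are derived from an OpenELM-style Hugging Face config. The model_dim, ffn_dim_divisor and ffn_multipliers values below are illustrative placeholders, not read from a real checkpoint:

# Sketch of the d_mlps computation used in the OpenELM branch of convert_hf_model_config.
# All numeric values here are hypothetical; real checkpoints define their own multipliers.
def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

model_dim = 1280
ffn_dim_divisor = 256
ffn_multipliers = [0.5, 1.0, 2.0, 4.0]  # one entry per layer

d_mlps = [make_divisible(m * model_dim, ffn_dim_divisor) for m in ffn_multipliers]
print(d_mlps)  # [768, 1280, 2560, 5120]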
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
@@ -19,3 +19,4 @@
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
from .neel_solu_old import convert_neel_solu_old_weights
from .openelm import convert_openelm_weights
68 changes: 68 additions & 0 deletions transformer_lens/pretrained/weight_conversions/openelm.py
@@ -0,0 +1,68 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_openelm_weights(openelm, cfg: HookedTransformerConfig):
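# Map an OpenELM checkpoint (fused QKV projection, per-layer head counts and FFN widths,
# RMSNorm, tied embeddings) onto TransformerLens' state-dict layout.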
state_dict = {}

assert cfg.d_mlps is not None
assert cfg.n_query_heads is not None
assert cfg.n_key_value_heads is not None

state_dict["embed.W_E"] = openelm.transformer.token_embeddings.weight

for l in range(cfg.n_layers):
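# qkv_proj fuses the per-layer query, key and value projections along its output
# dimension in the order [query heads | key heads | value heads], so it is sliced
# into W_Q / W_K / W_V using this layer's head counts.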
WQ = openelm.transformer.layers[l].attn.qkv_proj.weight[
: (cfg.n_query_heads[l] * cfg.d_head)
]
WK = openelm.transformer.layers[l].attn.qkv_proj.weight[
(cfg.n_query_heads[l] * cfg.d_head) : (
(cfg.n_query_heads[l] + cfg.n_key_value_heads[l]) * cfg.d_head
)
]
WV = openelm.transformer.layers[l].attn.qkv_proj.weight[
-cfg.n_key_value_heads[l] * cfg.d_head :
]

WQ = einops.rearrange(WQ, "(n h) m->n m h", n=cfg.n_query_heads[l])
WK = einops.rearrange(WK, "(n h) m->n m h", n=cfg.n_key_value_heads[l])
WV = einops.rearrange(WV, "(n h) m->n m h", n=cfg.n_key_value_heads[l])

state_dict[f"blocks.{l}.attn.W_Q"] = WQ
state_dict[f"blocks.{l}.attn._W_K"] = WK
state_dict[f"blocks.{l}.attn._W_V"] = WV

state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype
)
state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
cfg.n_key_value_heads[l], cfg.d_head, dtype=cfg.dtype
)

WO = openelm.transformer.layers[l].attn.out_proj.weight
WO = einops.rearrange(WO, "m (n h)->n h m", n=cfg.n_query_heads[l])
state_dict[f"blocks.{l}.attn.W_O"] = WO

state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
state_dict[f"blocks.{l}.ln2.w"] = openelm.transformer.layers[l].attn_norm.weight

state_dict[f"blocks.{l}.mlp.W_in"] = (
openelm.transformer.layers[l].ffn.proj_1.weight[: cfg.d_mlps[l], :].T
)
state_dict[f"blocks.{l}.mlp.W_gate"] = (
openelm.transformer.layers[l].ffn.proj_1.weight[cfg.d_mlps[l] :, :].T
)
state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlps[l], dtype=cfg.dtype)
state_dict[f"blocks.{l}.mlp.W_out"] = openelm.transformer.layers[l].ffn.proj_2.weight.T
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
openelm.transformer.layers[l].ffn.proj_2.weight.shape[0], dtype=cfg.dtype
)

state_dict[f"blocks.{l}.mlp.ln3.w"] = openelm.transformer.layers[l].ffn_norm.weight

state_dict["ln_final.w"] = openelm.transformer.norm.weight

state_dict["unembed.W_U"] = openelm.transformer.token_embeddings.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)