Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ def from_pretrained(
default_padding_side: Literal["left", "right"] = "right",
dtype="float32",
first_n_layers: Optional[int] = None,
force_unsupported_model=False,
**from_pretrained_kwargs,
) -> T:
"""Load in a Pretrained Model.
Expand Down Expand Up @@ -1273,6 +1274,8 @@ def from_pretrained(
default_padding_side: Which side to pad on when tokenizing. Defaults to
"right".
first_n_layers: If specified, only load the first n layers of the model.
force_unsupported_model: If specified, the function will try to load the model specified by the user, even though
it may not be officially supported by TransformerLens. Use it at your own risk.
"""
if model_name.lower().startswith("t5"):
raise RuntimeError(
Expand Down Expand Up @@ -1319,8 +1322,15 @@ def from_pretrained(
) and device in ["cpu", None]:
logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)
if force_unsupported_model:
# Force the loading of an unsupported model
logging.warning(
"You may be loading an unsupported model. Please be sure you know what you are doing and that you can expect unwanted behaviour"
)
official_model_name = model_name
else:
# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)

# Load the config into an HookedTransformerConfig object. If loading from a
# checkpoint, the config object will contain the information about the
Expand All @@ -1336,6 +1346,7 @@ def from_pretrained(
default_prepend_bos=default_prepend_bos,
dtype=dtype,
first_n_layers=first_n_layers,
force_unsupported_model=force_unsupported_model,
**from_pretrained_kwargs,
)

Expand Down Expand Up @@ -1368,7 +1379,12 @@ def from_pretrained(
# Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
# match the HookedTransformer parameter names.
state_dict = loading.get_pretrained_state_dict(
official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
official_model_name,
cfg,
hf_model,
dtype=dtype,
force_unsupported_model=force_unsupported_model,
**from_pretrained_kwargs,
)

# Create the HookedTransformer object
Expand Down
30 changes: 24 additions & 6 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def get_official_model_name(model_name: str):
return official_model_name


def convert_hf_model_config(model_name: str, **kwargs: Any):
def convert_hf_model_config(model_name: str, force_unsupported_model=False, **kwargs: Any):
"""
Returns the model config for a HuggingFace model, converted to a dictionary
in the HookedTransformerConfig format.
Expand All @@ -785,7 +785,10 @@ def convert_hf_model_config(model_name: str, **kwargs: Any):
logging.info("Loading model config from local directory")
official_model_name = model_name
else:
official_model_name = get_official_model_name(model_name)
if force_unsupported_model == False:
official_model_name = get_official_model_name(model_name)
else:
official_model_name = model_name

# Load HuggingFace model config
if "llama" in official_model_name.lower():
Expand Down Expand Up @@ -1643,6 +1646,7 @@ def get_pretrained_model_config(
default_prepend_bos: Optional[bool] = None,
dtype: torch.dtype = torch.float32,
first_n_layers: Optional[int] = None,
force_unsupported_model=False,
**kwargs: Any,
):
"""Returns the pretrained model config as an HookedTransformerConfig object.
Expand Down Expand Up @@ -1681,16 +1685,24 @@ def get_pretrained_model_config(
so this empirically seems to give better results. Note that you can also locally override the default behavior
by passing in prepend_bos=True/False when you call a method that processes the input string.
dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
force_unsupported_model: If specified, the function will try to load the model specified by the user, even though
it may not be officially supported by TransformerLens. Use it at your own risk.
kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
Also given to other HuggingFace functions when compatible.

"""
if Path(model_name).exists():
# If the model_name is a path, it's a local model
cfg_dict = convert_hf_model_config(model_name, **kwargs)
cfg_dict = convert_hf_model_config(
model_name, force_unsupported_model=force_unsupported_model, **kwargs
)
official_model_name = model_name
else:
official_model_name = get_official_model_name(model_name)
if force_unsupported_model == False:
official_model_name = get_official_model_name(model_name)
else:
# Forcing an unsupported model
official_model_name = model_name
if (
official_model_name.startswith("NeelNanda")
or official_model_name.startswith("ArthurConmy")
Expand All @@ -1705,7 +1717,9 @@ def get_pretrained_model_config(
f"Loading model {official_model_name} requires setting trust_remote_code=True"
)
kwargs["trust_remote_code"] = True
cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
cfg_dict = convert_hf_model_config(
official_model_name, force_unsupported_model=force_unsupported_model, **kwargs
)
# Processing common to both model types
# Remove any prefix, saying the organization who made a model.
cfg_dict["model_name"] = official_model_name.split("/")[-1]
Expand Down Expand Up @@ -1842,6 +1856,7 @@ def get_pretrained_state_dict(
cfg: HookedTransformerConfig,
hf_model: Optional[Any] = None,
dtype: torch.dtype = torch.float32,
force_unsupported_model=False,
**kwargs: Any,
) -> dict[str, torch.Tensor]:
"""
Expand All @@ -1862,7 +1877,10 @@ def get_pretrained_state_dict(
official_model_name = str(Path(official_model_name).resolve())
logging.info(f"Loading model from local path {official_model_name}")
else:
official_model_name = get_official_model_name(official_model_name)
if force_unsupported_model == False:
official_model_name = get_official_model_name(official_model_name)
else:
official_model_name = official_model_name
if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
"trust_remote_code", False
):
Expand Down