diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 025b43793..e76bb6317 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1134,6 +1134,7 @@ def from_pretrained( default_padding_side: Literal["left", "right"] = "right", dtype="float32", first_n_layers: Optional[int] = None, + force_unsupported_model=False, **from_pretrained_kwargs, ) -> T: """Load in a Pretrained Model. @@ -1273,6 +1274,8 @@ def from_pretrained( default_padding_side: Which side to pad on when tokenizing. Defaults to "right". first_n_layers: If specified, only load the first n layers of the model. + force_unsupported_model: If specified, the function will try to load the model specified by the user, even though + it may not be officially supported by TransformerLens. Use it at your own risk. """ if model_name.lower().startswith("t5"): raise RuntimeError( @@ -1319,8 +1322,15 @@ def from_pretrained( ) and device in ["cpu", None]: logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") - # Get the model name used in HuggingFace, rather than the alias. - official_model_name = loading.get_official_model_name(model_name) + if force_unsupported_model: + # Force the loading of an unsupported model + logging.warning( + "You may be loading an unsupported model. Please be sure you know what you are doing, as you may encounter unwanted behaviour" + ) + official_model_name = model_name + else: + # Get the model name used in HuggingFace, rather than the alias. + official_model_name = loading.get_official_model_name(model_name) # Load the config into an HookedTransformerConfig object. 
If loading from a # checkpoint, the config object will contain the information about the @@ -1336,6 +1346,7 @@ def from_pretrained( default_prepend_bos=default_prepend_bos, dtype=dtype, first_n_layers=first_n_layers, + force_unsupported_model=force_unsupported_model, **from_pretrained_kwargs, ) @@ -1368,7 +1379,12 @@ def from_pretrained( # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to # match the HookedTransformer parameter names. state_dict = loading.get_pretrained_state_dict( - official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs + official_model_name, + cfg, + hf_model, + dtype=dtype, + force_unsupported_model=force_unsupported_model, + **from_pretrained_kwargs, ) # Create the HookedTransformer object diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..03e75bce0 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -773,7 +773,7 @@ def get_official_model_name(model_name: str): return official_model_name -def convert_hf_model_config(model_name: str, **kwargs: Any): +def convert_hf_model_config(model_name: str, force_unsupported_model=False, **kwargs: Any): """ Returns the model config for a HuggingFace model, converted to a dictionary in the HookedTransformerConfig format. 
@@ -785,7 +785,10 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): logging.info("Loading model config from local directory") official_model_name = model_name else: - official_model_name = get_official_model_name(model_name) + if not force_unsupported_model: + official_model_name = get_official_model_name(model_name) + else: + official_model_name = model_name # Load HuggingFace model config if "llama" in official_model_name.lower(): @@ -1643,6 +1646,7 @@ def get_pretrained_model_config( default_prepend_bos: Optional[bool] = None, dtype: torch.dtype = torch.float32, first_n_layers: Optional[int] = None, + force_unsupported_model=False, **kwargs: Any, ): """Returns the pretrained model config as an HookedTransformerConfig object. @@ -1681,16 +1685,24 @@ so this empirically seems to give better results. Note that you can also locally override the default behavior by passing in prepend_bos=True/False when you call a method that processes the input string. dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. + force_unsupported_model: If specified, the function will try to load the model specified by the user, even though + it may not be officially supported by TransformerLens. Use it at your own risk. kwargs: Other optional arguments passed to HuggingFace's from_pretrained. Also given to other HuggingFace functions when compatible. 
""" if Path(model_name).exists(): # If the model_name is a path, it's a local model - cfg_dict = convert_hf_model_config(model_name, **kwargs) + cfg_dict = convert_hf_model_config( + model_name, force_unsupported_model=force_unsupported_model, **kwargs + ) official_model_name = model_name else: - official_model_name = get_official_model_name(model_name) + if not force_unsupported_model: + official_model_name = get_official_model_name(model_name) + else: + # Forcing an unsupported model + official_model_name = model_name if ( official_model_name.startswith("NeelNanda") or official_model_name.startswith("ArthurConmy") @@ -1705,7 +1717,9 @@ f"Loading model {official_model_name} requires setting trust_remote_code=True" ) kwargs["trust_remote_code"] = True - cfg_dict = convert_hf_model_config(official_model_name, **kwargs) + cfg_dict = convert_hf_model_config( + official_model_name, force_unsupported_model=force_unsupported_model, **kwargs + ) # Processing common to both model types # Remove any prefix, saying the organization who made a model. cfg_dict["model_name"] = official_model_name.split("/")[-1] @@ -1842,6 +1856,7 @@ def get_pretrained_state_dict( cfg: HookedTransformerConfig, hf_model: Optional[Any] = None, dtype: torch.dtype = torch.float32, + force_unsupported_model=False, **kwargs: Any, ) -> dict[str, torch.Tensor]: """ @@ -1862,7 +1877,8 @@ official_model_name = str(Path(official_model_name).resolve()) logging.info(f"Loading model from local path {official_model_name}") else: - official_model_name = get_official_model_name(official_model_name) + if not force_unsupported_model: + official_model_name = get_official_model_name(official_model_name) if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( "trust_remote_code", False ):