From 4827a886c6bd38f89d3df3134cc219debb06e903 Mon Sep 17 00:00:00 2001
From: mrinaldi97 <67000245+mrinaldi97@users.noreply.github.com>
Date: Sun, 23 Feb 2025 15:54:34 +0100
Subject: [PATCH 1/3] Force unsupported models

Add the "force_unsupported_model" feature
---
 transformer_lens/HookedTransformer.py       | 16 ++++++++++---
 transformer_lens/loading_from_pretrained.py | 26 ++++++++++++++++-----
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index cf4c369ac..a877919eb 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1124,6 +1124,7 @@ def from_pretrained(
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
         first_n_layers: Optional[int] = None,
+        force_unsupported_model=False,
         **from_pretrained_kwargs,
     ) -> T:
         """Load in a Pretrained Model.
@@ -1263,6 +1264,8 @@ def from_pretrained(
             default_padding_side: Which side to pad on when tokenizing. Defaults to
                 "right".
             first_n_layers: If specified, only load the first n layers of the model.
+            force_unsupported_model: If True, try to load the specified model even though it
+                may not be officially supported by TransformerLens. Use at your own risk.
         """
         if model_name.lower().startswith("t5"):
             raise RuntimeError(
@@ -1308,8 +1311,13 @@ def from_pretrained(
         ) and device in ["cpu", None]:
             logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
 
-        # Get the model name used in HuggingFace, rather than the alias.
-        official_model_name = loading.get_official_model_name(model_name)
+        if force_unsupported_model:
+            #Force the loading of an unsupported model
+            logging.warning("You may be loading an unsupported model. Be sure you know what you are doing, as unwanted behaviour may occur")
+            official_model_name=model_name
+        else:
+            # Get the model name used in HuggingFace, rather than the alias.
+            official_model_name = loading.get_official_model_name(model_name)
 
         # Load the config into an HookedTransformerConfig object. If loading from a
         # checkpoint, the config object will contain the information about the
@@ -1325,6 +1333,7 @@ def from_pretrained(
             default_prepend_bos=default_prepend_bos,
             dtype=dtype,
             first_n_layers=first_n_layers,
+            force_unsupported_model=force_unsupported_model,
             **from_pretrained_kwargs,
         )
 
@@ -1357,9 +1366,10 @@ def from_pretrained(
         # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
         # match the HookedTransformer parameter names.
         state_dict = loading.get_pretrained_state_dict(
-            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
+            official_model_name, cfg, hf_model, dtype=dtype, force_unsupported_model=force_unsupported_model, **from_pretrained_kwargs
         )
 
+
         # Create the HookedTransformer object
         model = cls(
             cfg,
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 952d2bf9b..f09f60051 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -738,7 +738,7 @@ def get_official_model_name(model_name: str):
     return official_model_name
 
 
-def convert_hf_model_config(model_name: str, **kwargs):
+def convert_hf_model_config(model_name: str, force_unsupported_model=False, **kwargs):
     """
     Returns the model config for a HuggingFace model, converted to a dictionary in the
     HookedTransformerConfig format.
@@ -750,7 +750,10 @@ def convert_hf_model_config(model_name: str, **kwargs):
         logging.info("Loading model config from local directory")
         official_model_name = model_name
     else:
-        official_model_name = get_official_model_name(model_name)
+        if force_unsupported_model==False:
+            official_model_name = get_official_model_name(model_name)
+        else:
+            official_model_name = model_name
 
     # Load HuggingFace model config
     if "llama" in official_model_name.lower():
@@ -1562,6 +1565,7 @@ def get_pretrained_model_config(
     default_prepend_bos: Optional[bool] = None,
     dtype: torch.dtype = torch.float32,
     first_n_layers: Optional[int] = None,
+    force_unsupported_model=False,
     **kwargs,
 ):
     """Returns the pretrained model config as an HookedTransformerConfig object.
@@ -1600,16 +1604,22 @@ def get_pretrained_model_config(
             so this empirically seems to give better results. Note that you can also locally
             override the default behavior by passing in prepend_bos=True/False when you call a
             method that processes the input string.
         dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
+        force_unsupported_model: If True, try to load the specified model even though it may
+            not be officially supported by TransformerLens. Use at your own risk.
         kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
             Also given to other HuggingFace functions when compatible.
     """
     if Path(model_name).exists():
         # If the model_name is a path, it's a local model
-        cfg_dict = convert_hf_model_config(model_name, **kwargs)
+        cfg_dict = convert_hf_model_config(model_name, force_unsupported_model=force_unsupported_model, **kwargs)
         official_model_name = model_name
     else:
-        official_model_name = get_official_model_name(model_name)
+        if force_unsupported_model==False:
+            official_model_name = get_official_model_name(model_name)
+        else:
+            #Forcing an unsupported model
+            official_model_name=model_name
     if (
         official_model_name.startswith("NeelNanda")
         or official_model_name.startswith("ArthurConmy")
@@ -1624,7 +1634,7 @@ def get_pretrained_model_config(
             f"Loading model {official_model_name} requires setting trust_remote_code=True"
         )
         kwargs["trust_remote_code"] = True
-        cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
+        cfg_dict = convert_hf_model_config(official_model_name, force_unsupported_model=force_unsupported_model, **kwargs)
     # Processing common to both model types
     # Remove any prefix, saying the organization who made a model.
cfg_dict["model_name"] = official_model_name.split("/")[-1] @@ -1759,6 +1769,7 @@ def get_pretrained_state_dict( cfg: HookedTransformerConfig, hf_model=None, dtype: torch.dtype = torch.float32, + force_unsupported_model=False, **kwargs, ) -> Dict[str, torch.Tensor]: """ @@ -1779,7 +1790,10 @@ def get_pretrained_state_dict( official_model_name = str(Path(official_model_name).resolve()) logging.info(f"Loading model from local path {official_model_name}") else: - official_model_name = get_official_model_name(official_model_name) + if force_unsupported_model==False: + official_model_name = get_official_model_name(official_model_name) + else: + official_model_name = official_model_name if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( "trust_remote_code", False ): From 46d2af38ac7bc67c443dc7d357ce1298dfd12cfc Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 19 Jun 2025 19:28:52 +0200 Subject: [PATCH 2/3] fixed typing issue --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 320cab2d5..acf5626b3 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1854,7 +1854,7 @@ def get_pretrained_state_dict( dtype: torch.dtype = torch.float32, force_unsupported_model=False, **kwargs: Any, -) -> Dict[str, torch.Tensor]: +) -> dict[str, torch.Tensor]: """ Loads in the model weights for a pretrained model, and processes them to have the HookedTransformer parameter names and shapes. Supports checkpointed From 71cfcfdb3b2bb58a16fe743c204aba6d9e7cd0a8 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 19 Jun 2025 20:34:14 +0200 Subject: [PATCH 3/3] ran format --- transformer_lens/HookedTransformer.py | 22 +++++++++++++-------- transformer_lens/loading_from_pretrained.py | 20 +++++++++++-------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 540ca23ce..e76bb6317 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1134,7 +1134,7 @@ def from_pretrained( default_padding_side: Literal["left", "right"] = "right", dtype="float32", first_n_layers: Optional[int] = None, - force_unsupported_model=False, + force_unsupported_model=False, **from_pretrained_kwargs, ) -> T: """Load in a Pretrained Model. @@ -1323,12 +1323,14 @@ def from_pretrained( logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") if force_unsupported_model: - #Force the loading of an unsupported model - logging.warning("You may be loading an unsupported model. Please be sure you know what you are doing and that you can expect unwanted behaviour") - official_model_name=model_name + # Force the loading of an unsupported model + logging.warning( + "You may be loading an unsupported model. Please be sure you know what you are doing and that you can expect unwanted behaviour" + ) + official_model_name = model_name else: - # Get the model name used in HuggingFace, rather than the alias. - official_model_name = loading.get_official_model_name(model_name) + # Get the model name used in HuggingFace, rather than the alias. + official_model_name = loading.get_official_model_name(model_name) # Load the config into an HookedTransformerConfig object. 
         # checkpoint, the config object will contain the information about the
@@ -1377,10 +1379,14 @@ def from_pretrained(
         # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
         # match the HookedTransformer parameter names.
         state_dict = loading.get_pretrained_state_dict(
-            official_model_name, cfg, hf_model, dtype=dtype, force_unsupported_model=force_unsupported_model, **from_pretrained_kwargs
+            official_model_name,
+            cfg,
+            hf_model,
+            dtype=dtype,
+            force_unsupported_model=force_unsupported_model,
+            **from_pretrained_kwargs,
         )
 
-
         # Create the HookedTransformer object
         model = cls(
             cfg,
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index acf5626b3..03e75bce0 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -785,7 +785,7 @@ def convert_hf_model_config(model_name: str, force_unsupported_model=False, **kw
         logging.info("Loading model config from local directory")
         official_model_name = model_name
     else:
-        if force_unsupported_model==False:
+        if force_unsupported_model == False:
             official_model_name = get_official_model_name(model_name)
         else:
             official_model_name = model_name
@@ -1693,14 +1693,16 @@ def get_pretrained_model_config(
     """
     if Path(model_name).exists():
         # If the model_name is a path, it's a local model
-        cfg_dict = convert_hf_model_config(model_name, force_unsupported_model=force_unsupported_model, **kwargs)
+        cfg_dict = convert_hf_model_config(
+            model_name, force_unsupported_model=force_unsupported_model, **kwargs
+        )
         official_model_name = model_name
     else:
-        if force_unsupported_model==False:
-            official_model_name = get_official_model_name(model_name)
+        if force_unsupported_model == False:
+            official_model_name = get_official_model_name(model_name)
         else:
-            #Forcing an unsupported model
-            official_model_name=model_name
+            # Forcing an unsupported model
+            official_model_name = model_name
     if (
         official_model_name.startswith("NeelNanda")
         or official_model_name.startswith("ArthurConmy")
@@ -1715,7 +1717,9 @@ def get_pretrained_model_config(
             f"Loading model {official_model_name} requires setting trust_remote_code=True"
         )
         kwargs["trust_remote_code"] = True
-        cfg_dict = convert_hf_model_config(official_model_name, force_unsupported_model=force_unsupported_model, **kwargs)
+        cfg_dict = convert_hf_model_config(
+            official_model_name, force_unsupported_model=force_unsupported_model, **kwargs
+        )
     # Processing common to both model types
     # Remove any prefix, saying the organization who made a model.
     cfg_dict["model_name"] = official_model_name.split("/")[-1]
@@ -1873,7 +1877,7 @@ def get_pretrained_state_dict(
         official_model_name = str(Path(official_model_name).resolve())
         logging.info(f"Loading model from local path {official_model_name}")
     else:
-        if force_unsupported_model == False:
+        if force_unsupported_model == False:
             official_model_name = get_official_model_name(official_model_name)
         else:
             official_model_name = official_model_name
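
A minimal usage sketch for the flag this series adds, assuming the three patches are
applied. The repo id below is hypothetical; it stands in for any HuggingFace model that
is not in TransformerLens's official model registry:

    from transformer_lens import HookedTransformer

    # With force_unsupported_model=True, the get_official_model_name() alias
    # lookup is skipped and the name is passed through unchanged, so the
    # config and weights are fetched directly from HuggingFace. A warning is
    # logged, and weight conversion can still fail for architectures that
    # TransformerLens has no conversion code for.
    model = HookedTransformer.from_pretrained(
        "my-org/my-gpt2-finetune",  # hypothetical HuggingFace repo id
        force_unsupported_model=True,
    )

The flag defaults to False at every call site, so existing callers are unaffected.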