diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 1ec85936f7a..63a489ec8c9 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         checkpoint_path = self.llm_config.base.checkpoint
         params_path = self.llm_config.base.params
 
-        # Adapter checkpoint and config.
-        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
-        adapter_config_path = self.llm_config.base.adapter_config
-        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
-            adapter_checkpoint_path is not None and adapter_config_path is not None
-        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+        # LoRA adapter configuration.
+        lora_config = self.llm_config.base.lora
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
@@ -69,10 +65,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
-        # Get adapter checkpoint and config.
+        # Get adapter checkpoint.
         adapter_checkpoint = {}
-        adapter_config = {}
-        if adapter_checkpoint_path:
+        if lora_config:
+            adapter_checkpoint_path = lora_config.adapter_checkpoint
             if adapter_checkpoint_path.endswith(".pt"):
                 adapter_checkpoint = torch.load(
                     adapter_checkpoint_path, map_location=device, mmap=True
@@ -92,22 +88,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
                 raise ValueError(
                     f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
                 )
-
-            with open(adapter_config_path, "r") as f:
-                adapter_config_full = json.loads(f.read())
-                if (
-                    "r" not in adapter_config_full
-                    or "lora_alpha" not in adapter_config_full
-                    or "target_modules" not in adapter_config_full
-                ):
-                    raise ValueError(
-                        "Adapter config must contain r, lora_alpha, and target_modules."
-                    )
-                adapter_config = {
-                    "r": adapter_config_full["r"],
-                    "lora_alpha": adapter_config_full["lora_alpha"],
-                    "target_modules": adapter_config_full["target_modules"],
-                }
         checkpoint.update(adapter_checkpoint)
 
         output_prune_map = None
@@ -133,8 +113,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
+            r=lora_config.r if lora_config else None,
+            lora_alpha=lora_config.lora_alpha if lora_config else None,
+            target_modules=lora_config.target_modules if lora_config else None,
             **params,
-            **adapter_config,
         )
 
         if model_args.use_scaled_rope:
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index b40fad88a9c..4c7485d29b0 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -18,6 +18,8 @@
 """
 
 import argparse
+import json
+import os
 import re
 from dataclasses import dataclass, field
 from enum import Enum
@@ -61,6 +63,68 @@ class PreqMode(str, Enum):
     preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
+@dataclass
+class LoraConfig:
+    """LoRA adapter configuration.
+
+    Can be created in two ways:
+
+    1. Direct specification (all fields required):
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            r=16,
+            lora_alpha=32,
+            target_modules=["q_proj", "v_proj"],
+        )
+
+    2. Parse from adapter_config.json (r, lora_alpha, target_modules auto-loaded):
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            adapter_config="/path/to/adapter_config.json",
+        )
+    """
+
+    adapter_checkpoint: str  # Path to adapter weights (.safetensors or .pt)
+    r: Optional[int] = None  # LoRA rank
+    lora_alpha: Optional[int] = None  # LoRA alpha scaling
+    # Modules to apply LoRA to:
+    # attention: ["q_proj", "v_proj", "k_proj", "output_proj"/"o_proj"]
+    # feed-forward: ["gate_proj", "up_proj", "down_proj"]
+    target_modules: Optional[List[str]] = None
+    adapter_config: Optional[str] = None  # Path to adapter_config.json
+
+    def __post_init__(self):
+        """Parse adapter_config.json if provided and validate required fields."""
+        if self.adapter_config is not None:
+            self._parse_and_fill_from_config(self.adapter_config)
+
+        # Validate required fields
+        required = ["r", "lora_alpha", "target_modules"]
+        missing = [f for f in required if getattr(self, f) is None]
+        if missing:
+            raise ValueError(
+                f"LoraConfig missing required fields: {missing}. "
+                "Provide them directly or via adapter_config."
+            )
+
+    def _parse_and_fill_from_config(self, config_path: str) -> None:
+        """Parse adapter_config.json and fill missing fields."""
+        import json
+        import os
+
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"LoRA config not found: {config_path}")
+
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        for field in ["r", "lora_alpha", "target_modules"]:
+            if getattr(self, field) is None:
+                if field not in config:
+                    raise ValueError(f"adapter_config.json must contain '{field}'")
+                setattr(self, field, config[field])
+
+
 @dataclass
 class BaseConfig:
     """
@@ -76,11 +140,7 @@ class BaseConfig:
             If left empty, the model will either be initialized with random weights
             if it is a Llama model or the weights will be downloaded from
             HuggingFace if it is a non-Llama model.
-        adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
-            the model has trained LoRA adapters. Must provide
-            adapter_config.json.
-        adapter_config: Path to the adapter_config.json file from torchtune.
-            Used if the model has trained LoRA adapters. Must provide adapter.pt.
+        lora: LoRA adapter configuration.
         tokenizer_path: Path to the tokenizer file.
         metadata: Json string containing metadata information.
'"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' @@ -97,8 +157,7 @@ class BaseConfig: model_class: ModelType = ModelType.llama3 params: Optional[str] = None checkpoint: Optional[str] = None - adapter_checkpoint: Optional[str] = None - adapter_config: Optional[str] = None + lora: Optional[LoraConfig] = None tokenizer_path: Optional[str] = None metadata: Optional[str] = None use_lora: int = 0 @@ -523,10 +582,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.base.params = args.params if hasattr(args, "checkpoint"): llm_config.base.checkpoint = args.checkpoint - if hasattr(args, "adapter_checkpoint"): - llm_config.base.adapter_checkpoint = args.adapter_checkpoint - if hasattr(args, "adapter_config"): - llm_config.base.adapter_config = args.adapter_config + if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint: + if not hasattr(args, "adapter_config") or not args.adapter_config: + raise ValueError("--adapter_checkpoint requires --adapter_config") + llm_config.base.lora = LoraConfig( + adapter_checkpoint=args.adapter_checkpoint, + adapter_config=args.adapter_config, + ) if hasattr(args, "tokenizer_path"): llm_config.base.tokenizer_path = args.tokenizer_path if hasattr(args, "metadata"):