34 changes: 8 additions & 26 deletions examples/models/llama/model.py
@@ -31,12 +31,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
        checkpoint_path = self.llm_config.base.checkpoint
        params_path = self.llm_config.base.params

-        # Adapter checkpoint and config.
-        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
-        adapter_config_path = self.llm_config.base.adapter_config
-        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
-            adapter_checkpoint_path is not None and adapter_config_path is not None
-        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+        # LoRA adapter configuration.
+        lora_config = self.llm_config.base.lora

        self.use_kv_cache = self.llm_config.model.use_kv_cache
        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
@@ -69,10 +65,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
        with open(params_path, "r") as f:
            params = json.loads(f.read())

-        # Get adapter checkpoint and config.
+        # Get adapter checkpoint.
        adapter_checkpoint = {}
-        adapter_config = {}
-        if adapter_checkpoint_path:
+        if lora_config:
+            adapter_checkpoint_path = lora_config.adapter_checkpoint
            if adapter_checkpoint_path.endswith(".pt"):
                adapter_checkpoint = torch.load(
                    adapter_checkpoint_path, map_location=device, mmap=True
@@ -92,22 +88,6 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
                raise ValueError(
                    f"Unsupported adapter checkpoint format: {adapter_checkpoint_path}"
                )
-
-            with open(adapter_config_path, "r") as f:
-                adapter_config_full = json.loads(f.read())
-            if (
-                "r" not in adapter_config_full
-                or "lora_alpha" not in adapter_config_full
-                or "target_modules" not in adapter_config_full
-            ):
-                raise ValueError(
-                    "Adapter config must contain r, lora_alpha, and target_modules."
-                )
-            adapter_config = {
-                "r": adapter_config_full["r"],
-                "lora_alpha": adapter_config_full["lora_alpha"],
-                "target_modules": adapter_config_full["target_modules"],
-            }
            checkpoint.update(adapter_checkpoint)

        output_prune_map = None
@@ -133,8 +113,10 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
            input_prune_map=input_prune_map,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
+            r=lora_config.r if lora_config else None,
+            lora_alpha=lora_config.lora_alpha if lora_config else None,
+            target_modules=lora_config.target_modules if lora_config else None,
            **params,
-            **adapter_config,
        )

        if model_args.use_scaled_rope:
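For context, a minimal sketch (not part of this PR) of how the new llm_config.base.lora field might be populated before this constructor runs. The import path and the default-constructed LlmConfig are assumptions based on the file locations in this diff, and all paths are placeholders:

# Hypothetical usage sketch; import path and LlmConfig() defaults are assumptions.
from executorch.extension.llm.export.config.llm_config import LlmConfig, LoraConfig

llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/consolidated.00.pth"  # base model weights
llm_config.base.params = "/path/to/params.json"              # model hyperparameters
# r, lora_alpha, and target_modules are parsed from adapter_config.json by LoraConfig,
# then forwarded into ModelArgs by the constructor shown above.
llm_config.base.lora = LoraConfig(
    adapter_checkpoint="/path/to/adapter.safetensors",
    adapter_config="/path/to/adapter_config.json",
)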
84 changes: 73 additions & 11 deletions extension/llm/export/config/llm_config.py
@@ -18,6 +18,8 @@
"""

import argparse
+import json
Copilot AI Feb 5, 2026
Import of 'json' is not used.
Suggested change:
-import json
+import os
Copilot AI Feb 5, 2026
Import of 'os' is not used.
Suggested change:
-import os
import re
from dataclasses import dataclass, field
from enum import Enum
@@ -61,6 +63,68 @@ class PreqMode(str, Enum):
    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


+@dataclass
+class LoraConfig:
+    """LoRA adapter configuration.
+
+    Can be created in two ways:
+
+    1. Direct specification (all fields required):
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            r=16,
+            lora_alpha=32,
+            target_modules=["q_proj", "v_proj"],
+        )
+
+    2. Parse from adapter_config.json (r, lora_alpha, target_modules auto-loaded):
+        LoraConfig(
+            adapter_checkpoint="/path/to/adapter.safetensors",
+            adapter_config="/path/to/adapter_config.json",
+        )
+    """
+
+    adapter_checkpoint: str  # Path to adapter weights (.safetensors or .pt)
+    r: Optional[int] = None  # LoRA rank
+    lora_alpha: Optional[int] = None  # LoRA alpha scaling
+    # Modules to apply LoRA to:
+    #   attention: ["q_proj", "v_proj", "k_proj", "output_proj"/"o_proj"]
+    #   feed-forward: ["gate_proj", "up_proj", "down_proj"]
+    target_modules: Optional[List[str]] = None
+    adapter_config: Optional[str] = None  # Path to adapter_config.json
+
+    def __post_init__(self):
+        """Parse adapter_config.json if provided and validate required fields."""
+        if self.adapter_config is not None:
+            self._parse_and_fill_from_config(self.adapter_config)
+
+        # Validate required fields
+        required = ["r", "lora_alpha", "target_modules"]
+        missing = [f for f in required if getattr(self, f) is None]
+        if missing:
+            raise ValueError(
+                f"LoraConfig missing required fields: {missing}. "
+                "Provide them directly or via adapter_config."
+            )
+
+    def _parse_and_fill_from_config(self, config_path: str) -> None:
+        """Parse adapter_config.json and fill missing fields."""
+        import json
+        import os
+
Comment on lines +112 to +114
Copilot AI Feb 5, 2026
Remove these duplicate imports. The modules json and os are already imported at the top of the file (lines 21-22), so importing them again inside this method is unnecessary.
Suggested change:
-        import json
-        import os
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"LoRA config not found: {config_path}")
+
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        for field in ["r", "lora_alpha", "target_modules"]:
+            if getattr(self, field) is None:
+                if field not in config:
+                    raise ValueError(f"adapter_config.json must contain '{field}'")
+                setattr(self, field, config[field])


@dataclass
class BaseConfig:
"""
@@ -76,11 +140,7 @@ class BaseConfig:
            If left empty, the model will either be initialized with random weights
            if it is a Llama model or the weights will be downloaded from HuggingFace
            if it is a non-Llama model.
-        adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
-            the model has trained LoRA adapters. Must provide
-            adapter_config.json.
-        adapter_config: Path to the adapter_config.json file from torchtune.
-            Used if the model has trained LoRA adapters. Must provide adapter.pt.
+        lora: LoRA adapter configuration.
        tokenizer_path: Path to the tokenizer file.
        metadata: Json string containing metadata information.
            e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
@@ -97,8 +157,7 @@ class BaseConfig:
    model_class: ModelType = ModelType.llama3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
-    adapter_checkpoint: Optional[str] = None
-    adapter_config: Optional[str] = None
+    lora: Optional[LoraConfig] = None
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: int = 0
@@ -523,10 +582,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
            llm_config.base.params = args.params
        if hasattr(args, "checkpoint"):
            llm_config.base.checkpoint = args.checkpoint
-        if hasattr(args, "adapter_checkpoint"):
-            llm_config.base.adapter_checkpoint = args.adapter_checkpoint
-        if hasattr(args, "adapter_config"):
-            llm_config.base.adapter_config = args.adapter_config
+        if hasattr(args, "adapter_checkpoint") and args.adapter_checkpoint:
+            if not hasattr(args, "adapter_config") or not args.adapter_config:
+                raise ValueError("--adapter_checkpoint requires --adapter_config")
+            llm_config.base.lora = LoraConfig(
+                adapter_checkpoint=args.adapter_checkpoint,
+                adapter_config=args.adapter_config,
+            )
        if hasattr(args, "tokenizer_path"):
            llm_config.base.tokenizer_path = args.tokenizer_path
        if hasattr(args, "metadata"):
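A hypothetical sketch of how this from_args path might be exercised. The attribute names mirror the hasattr checks above; any other export arguments the real parser defines are omitted here, so treat this purely as an illustration:

import argparse

# Only the attributes referenced in the snippet above are set; everything else
# relies on the hasattr guards in from_args (an assumption about the full method).
args = argparse.Namespace(
    checkpoint="/path/to/consolidated.00.pth",
    params="/path/to/params.json",
    adapter_checkpoint="/path/to/adapter.safetensors",
    adapter_config="/path/to/adapter_config.json",
    tokenizer_path="/path/to/tokenizer.model",
)
llm_config = LlmConfig.from_args(args)
# llm_config.base.lora is now a LoraConfig; r, lora_alpha, and target_modules
# are parsed from adapter_config.json by LoraConfig.__post_init__.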