diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 4c7485d29b0..816ce97cf90 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -23,7 +23,7 @@ import re from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, List, Optional +from typing import ClassVar, Dict, List, Optional ################################################################################ @@ -320,6 +320,37 @@ class DebugConfig: verbose: bool = False +################################################################################ +############################## MultimethodLoraConfig ########################### +################################################################################ + + +@dataclass +class MultimethodLoraConfig: + """Configuration for exporting multiple methods to a single .pte file. + + Maps method names to optional LoRA configurations. A None value means + the method uses base model weights. + + Attributes: + methods: Dict mapping method names to optional LoRA configs. + Empty dict disables multimethod export. + + Example: + MultimethodLoraConfig(methods={ + "forward": None, # base model + "lora_forward": lora_config, # LoRA variant + }) + """ + + methods: Dict[str, Optional[LoraConfig]] = field(default_factory=dict) + + @property + def enabled(self) -> bool: + """Returns True if multimethod export is configured.""" + return len(self.methods) > 0 + + ################################################################################ ############################# QuantizationConfig ############################### ################################################################################ @@ -564,6 +595,7 @@ class LlmConfig: model: ModelConfig = field(default_factory=ModelConfig) export: ExportConfig = field(default_factory=ExportConfig) debug: DebugConfig = field(default_factory=DebugConfig) + multimethod: MultimethodLoraConfig = field(default_factory=MultimethodLoraConfig) quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig)