From 854e3b513ef71ea9220c224bffa6e7bb9a141903 Mon Sep 17 00:00:00 2001
From: Github Executorch
Date: Wed, 4 Feb 2026 18:22:14 -0800
Subject: [PATCH] Support multimethod in export_llama_lib

TODO: add CI test.

Differential Revision: [D92315602](https://our.internmc.facebook.com/intern/diff/D92315602/)

[ghstack-poisoned]
---
 examples/models/llama/BUCK                |   1 +
 examples/models/llama/export_llama_lib.py | 138 +++++++++++++++++++++-
 examples/models/llama/runner/targets.bzl  |   1 +
 3 files changed, 139 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/BUCK b/examples/models/llama/BUCK
index 9d9897a2819..e8b21c898fa 100644
--- a/examples/models/llama/BUCK
+++ b/examples/models/llama/BUCK
@@ -148,6 +148,7 @@ fbcode_target(_kind = runtime.python_library,
 fbcode_target(_kind = runtime.python_library,
     name = "export_library",
     srcs = [
+        "convert_weights.py",
         "export_llama.py",
         "export_llama_lib.py",
         "model.py",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 219cc71ded1..2891f0adb1a 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -20,15 +20,17 @@
 from importlib import resources as _resources
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
+from torch.export import ExportedProgram
 
 from executorch.devtools.backend_debug import print_delegation_info
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
+from executorch.exir import to_edge_transform_and_lower
 from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 from executorch.extension.llm.export.config.llm_config import LlmConfig
@@ -844,6 +846,28 @@ def _validate_args(llm_config):
             "Shared embedding is only supported with torchao quantization."
         )
 
+    if llm_config.multimethod.enabled:
+        if llm_config.base.lora is not None:
+            raise ValueError(
+                "Cannot use both base.lora and multimethod.methods. "
+                "Use multimethod.methods for all LoRA variants."
+            )
+        if llm_config.quantization.pt2e_quantize is not None:
+            raise ValueError(
+                "PT2E quantization is not supported with multimethod export."
+            )
+        if (
+            llm_config.backend.coreml.enabled
+            or llm_config.backend.vulkan.enabled
+            or llm_config.backend.qnn.enabled
+            or llm_config.backend.mps.enabled
+            or llm_config.backend.openvino.enabled
+        ):
+            raise ValueError(
+                "Multimethod export only supports the XNNPACK backend or portable ops. "
+                "Please disable other backends (coreml, vulkan, qnn, mps, openvino)."
+            )
+
 
 def _to_edge_and_lower_llama_xnnpack(
     builder_exported,
@@ -1107,9 +1131,121 @@
     return builder
 
 
+def _get_xnnpack_partitioners(llm_config: LlmConfig) -> Optional[List]:
+    """Get XNNPACK partitioners for multimethod export."""
+    partitioners = []
+
+    if llm_config.backend.xnnpack.enabled:
+        partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
+        if llm_config.backend.xnnpack.extended_ops:
+            partitioners.append(
+                get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
+            )
+
+    return partitioners if partitioners else None
+
+
+def _get_output_filename(llm_config: LlmConfig, modelname: str, output_dir: str, dtype: DType) -> str:
+    """Determine the output filename for the .pte file."""
+    if dtype == DType.fp16:
+        modelname = f"{modelname}_h"
+
+    if llm_config.export.output_name:
+        output_name = llm_config.export.output_name
+        if output_name.endswith(".pte"):
+            return output_name
+        else:
+            return f"{output_dir}/{output_name}.pte"
+    else:
+        return f"{output_dir}/{modelname}.pte"
+
+
+def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager:
+    """
+    Export multiple methods (base + LoRA variants) to a single .pte file.
+
+    For each method in llm_config.multimethod.methods:
+    - If LoraConfig is None: use the base model.
+    - If a LoraConfig is provided: create the model with LoRA weights.
+
+    Limitations:
+    - Only the XNNPACK backend is supported for multimethod export.
+    - PT2E quantization is not supported.
+    - Each method is exported separately; export time scales linearly
+      with the number of methods.
+    - The final .pte file deduplicates shared weights automatically.
+    """
+    num_methods = len(llm_config.multimethod.methods)
+    logging.info(
+        f"Multimethod export: exporting {num_methods} method(s). "
+        "Each method requires separate model instantiation and export."
+    )
+
+    additional_passes = []
+    if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # Build dict of exported programs
+    method_to_program: Dict[str, ExportedProgram] = {}
+    first_builder = None
+
+    for method_name, lora_config in llm_config.multimethod.methods.items():
+        logging.info(f"Exporting method: {method_name}")
+
+        # Create a copy of config with this method's LoRA setting
+        method_config = copy.deepcopy(llm_config)
+        method_config.base.lora = lora_config
+        # Disable multimethod to avoid infinite recursion
+        method_config.multimethod.methods = {}
+
+        # Load and prepare model for this method
+        builder = _prepare_for_llama_export(method_config)
+        builder = builder.export()
+        builder.run_canonical_optimizations()
+
+        # Get the exported program
+        exported_program = builder._export(builder.pre_autograd_graph_module)
+        method_to_program[method_name] = exported_program
+
+        if first_builder is None:
+            first_builder = builder
+
+    assert first_builder is not None, "No methods to export"
+
+    # Get partitioners based on backend config
+    partitioners = _get_xnnpack_partitioners(llm_config)
+
+    # Lower all methods together using multimethod API
+    edge_config = first_builder._get_edge_config()
+    edge_manager = to_edge_transform_and_lower(
+        method_to_program,
+        partitioner=partitioners,
+        compile_config=edge_config,
+        constant_methods=first_builder.metadata,
+    )
+
+    # Convert to executorch and save
+    first_builder.edge_manager = edge_manager
+    first_builder = first_builder.to_executorch(passes=additional_passes)
+
+    output_file = _get_output_filename(
+        llm_config,
+        first_builder.modelname,
+        first_builder.output_dir,
+        first_builder.dtype,
+    )
+    first_builder.save_to_pte(output_file)
+
+    return first_builder
+
+
 def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
+    # Check for multimethod export
+    if llm_config.multimethod.enabled:
+        return _export_llama_multimethod(llm_config)
+
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
         llm_config
     )
diff --git a/examples/models/llama/runner/targets.bzl b/examples/models/llama/runner/targets.bzl
index f2020126acc..f5f30f73054 100644
--- a/examples/models/llama/runner/targets.bzl
+++ b/examples/models/llama/runner/targets.bzl
@@ -47,6 +47,7 @@ def define_common_targets():
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//pytorch/tokenizers:llama2c_tokenizer",
             "//pytorch/tokenizers:hf_tokenizer",
+            "//pytorch/tokenizers:regex_lookahead",
         ] + (_get_operator_lib(aten)) + ([
             # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
             # Therefore enable it explicitly for now to avoid failing tests
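
Note (outside the patch): the core of `_export_llama_multimethod` is handing a `Dict[str, ExportedProgram]` to `exir.to_edge_transform_and_lower`, which emits one named method per dict entry. Below is a minimal sketch of that dict-of-methods flow on a toy module; the `Toy` class, the method names, and the output path are illustrative and not part of the patch:

```python
import torch

from executorch.exir import to_edge_transform_and_lower


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


toy = Toy()
example_inputs = (torch.randn(1, 4),)

# One ExportedProgram per method name. Exporting the same module twice
# means both methods reference the same weight tensors, which the patch
# notes are deduplicated in the final .pte.
methods = {
    "base": torch.export.export(toy, example_inputs),
    "variant": torch.export.export(toy, example_inputs),
}

# No partitioner here (portable ops only); the patch adds XNNPACK
# partitioners via _get_xnnpack_partitioners when that backend is enabled.
et_program = to_edge_transform_and_lower(methods).to_executorch()

with open("toy_multimethod.pte", "wb") as f:
    f.write(et_program.buffer)
```

At runtime each dict key is addressable as a separate method of the loaded program, which is what lets a single .pte carry a base model plus its LoRA variants.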