diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh
index 78e005950b7..341475afe6d 100755
--- a/.ci/scripts/export_model_artifact.sh
+++ b/.ci/scripts/export_model_artifact.sh
@@ -26,14 +26,16 @@ Arguments:
   quant_name  Quantization type (optional, default: non-quantized)
               Options:
                 - non-quantized
-                - quantized-int4-tile-packed
-                - quantized-int4-weight-only
+                - quantized-int4-tile-packed (CUDA only)
+                - quantized-int4-weight-only (CUDA only)
+                - quantized-int4-metal (Metal only)
                 - quantized-8da4w (XNNPACK only)
 
   output_dir  Output directory for artifacts (optional, default: current directory)
 
 Examples:
   export_model_artifact.sh metal "openai/whisper-small"
+  export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal"
   export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
   export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
   export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
@@ -131,18 +133,25 @@ case "$QUANT_NAME" in
     ;;
   quantized-int4-tile-packed)
     if [ "$DEVICE" = "metal" ]; then
-      echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
+      echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
       exit 1
     fi
     EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
     ;;
   quantized-int4-weight-only)
     if [ "$DEVICE" = "metal" ]; then
-      echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
+      echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
       exit 1
     fi
     EXTRA_ARGS="--qlinear_encoder 4w"
     ;;
+  quantized-int4-metal)
+    if [ "$DEVICE" != "metal" ]; then
+      echo "Error: Quantization '$QUANT_NAME' is only supported on the Metal backend"
+      exit 1
+    fi
+    EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"
+    ;;
   quantized-8da4w)
     if [ "$DEVICE" != "xnnpack" ]; then
       echo "Error: quantized-8da4w is only supported with xnnpack device"
@@ -152,7 +161,7 @@ case "$QUANT_NAME" in
     ;;
   *)
     echo "Error: Unsupported quantization '$QUANT_NAME'"
-    echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-8da4w"
+    echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal, quantized-8da4w"
     exit 1
     ;;
 esac
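For reference, a minimal sketch of how the new option is exercised end to end. The `quantized-int4-metal` branch above only sets `EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"`, which the script presumably forwards to the model export command (the forwarding code is outside this hunk):

```bash
# Sketch: export parakeet-tdt for Metal with the new 4-bit (fpa4w) linear quantization.
# Mirrors the usage example added to the script's help text; the output directory is optional.
bash .ci/scripts/export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal" ./output
```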
diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml
index bf86b01aff8..ec15c87f737 100644
--- a/.github/workflows/metal.yml
+++ b/.github/workflows/metal.yml
@@ -41,11 +41,10 @@ jobs:
         set -eux

         echo "::group::Setup ExecuTorch"
-        PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh
+        PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
         echo "::endgroup::"

         echo "::group::Build Metal Runtime"
-        ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --update-ao
         ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build
         echo "::endgroup::"

@@ -73,6 +72,12 @@
             name: "parakeet-tdt"
         quant:
           - "non-quantized"
+        # Only test int4 quantization with parakeet-tdt
+        include:
+          - model:
+              repo: "nvidia"
+              name: "parakeet-tdt"
+            quant: "quantized-int4-metal"
     with:
       runner: macos-m2-stable
       python-version: '3.11'
@@ -123,6 +128,12 @@
             name: "parakeet-tdt"
         quant:
           - "non-quantized"
+        # Only test int4 quantization with parakeet-tdt
+        include:
+          - model:
+              repo: "nvidia"
+              name: "parakeet-tdt"
+            quant: "quantized-int4-metal"
     with:
       runner: macos-m2-stable
       python-version: '3.11'
diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index 215851f09f9..20d558ecaa7 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -44,8 +44,12 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:

     @classmethod
     def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
-        """Return Metal-specific passes (currently none)"""
-        return []
+        """Return Metal-specific passes"""
+        from executorch.backends.apple.metal.passes.decompose_linear_pass import (
+            DecomposeLinearPass,
+        )
+
+        return [DecomposeLinearPass()]

     @classmethod
     def get_aoti_compile_options(
diff --git a/backends/apple/metal/passes/__init__.py b/backends/apple/metal/passes/__init__.py
new file mode 100644
index 00000000000..2a1209f1356
--- /dev/null
+++ b/backends/apple/metal/passes/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.apple.metal.passes.decompose_linear_pass import (  # noqa: F401
+    DecomposeLinearPass,
+)
+
+__all__ = ["DecomposeLinearPass"]
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + graph = graph_module.graph + + for node in graph.nodes: + # Check if this is a linear operation + is_linear = False + + if node.op == "call_function": + # Match both edge dialect and core aten linear operators + if node.target == exir_ops.edge.aten.linear.default: + is_linear = True + elif node.target == torch.ops.aten.linear.default: + is_linear = True + + if is_linear: + # Get input, weight, and bias arguments + input_node = node.args[0] + weight_node = node.args[1] + bias_node = node.args[2] if len(node.args) > 2 else None + + with graph.inserting_before(node): + # Determine which ops to use based on the input operator + target_str = str(node.target) + + if "executorch_exir_dialects_edge" in target_str: + # Use edge dialect operators + t_op = exir_ops.edge.aten.t.default + matmul_op = exir_ops.edge.aten.matmul.default + add_op = exir_ops.edge.aten.add.Tensor + unsqueeze_op = exir_ops.edge.aten.unsqueeze.default + squeeze_op = exir_ops.edge.aten.squeeze.dims + else: + # Use core aten operators + t_op = torch.ops.aten.t.default + matmul_op = torch.ops.aten.matmul.default + add_op = torch.ops.aten.add.Tensor + unsqueeze_op = torch.ops.aten.unsqueeze.default + squeeze_op = torch.ops.aten.squeeze.dims + + # Check if input is 2D + needs_unsqueeze = False + if hasattr(input_node, "meta") and "val" in input_node.meta: + if len(input_node.meta["val"].shape) == 2: + needs_unsqueeze = True + + # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K) + current_input = input_node + if needs_unsqueeze: + current_input = graph.call_function( + unsqueeze_op, + args=(input_node, 0), + ) + + # Decompose linear: matmul(input, weight.T) + bias + weight_t = graph.call_function( + t_op, + args=(weight_node,), + ) + + matmul_result = graph.call_function( + matmul_op, + args=(current_input, weight_t), + ) + + if bias_node is not None: + result = graph.call_function( + add_op, + args=(matmul_result, bias_node), + ) + else: + result = matmul_result + + # Squeeze 3D output back to 2D: (1, M, N) -> (M, N) + if needs_unsqueeze: + result = graph.call_function( + squeeze_op, + args=(result, [0]), + ) + + # Replace all uses of the linear node with the decomposed result + node.replace_all_uses_with(result) + graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 403ce355381..4908b2e0ffc 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -208,6 +208,22 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearWithBias(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(7, 101, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_bias"] = { + "model_class": LinearWithBias, + "input_shapes": [(127, 7)], + "description": "Simple linear layer model bias", + + # ------------------------------------------------------------------------- class LinearNoBiasInt4(nn.Module): def __init__(self): diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 695e70ed472..649a2225536 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -39,10 +39,10 @@ The export script supports quantizing 
diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index 695e70ed472..649a2225536 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -39,10 +39,10 @@ The export script supports quantizing encoder and decoder linear layers using [t

 | Argument | Description |
 |----------|-------------|
-| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
+| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
 | `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
 | `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
-| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
+| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
 | `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
 | `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
 | `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
@@ -50,12 +50,13 @@ The export script supports quantizing encoder and decoder linear layers using [t

 #### Quantization Configs

-| Config | Description |
-|--------|-------------|
-| `4w` | 4-bit weight only quantization |
-| `8w` | 8-bit weight only quantization |
-| `8da4w` | 8-bit dynamic activation, 4-bit weight |
-| `8da8w` | 8-bit dynamic activation, 8-bit weight |
+| Config | Description | Backends |
+|--------|-------------|----------|
+| `4w` | 4-bit weight only quantization | CUDA |
+| `8w` | 8-bit weight only quantization | CUDA |
+| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA |
+| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA |
+| `fpa4w` | Floating point activation, 4-bit weight | Metal |

 #### Example: Dynamic Quantization for XNNPACK

@@ -86,6 +87,30 @@ python export_parakeet_tdt.py \

 **Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.

+#### Example: Metal 4-bit Quantization
+
+```bash
+python export_parakeet_tdt.py \
+    --backend metal \
+    --qlinear_encoder fpa4w \
+    --qlinear_encoder_group_size 32 \
+    --qlinear fpa4w \
+    --qlinear_group_size 32 \
+    --output-dir ./parakeet_metal_quantized
+```
+
+**Note:** Metal 4-bit quantization requires torchao built with experimental MPS (Metal) ops.
+
+You can install torchao with Metal support from the `ao` repo:
+```bash
+USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation
+```
+
+Alternatively, you can build torchao with Metal support while installing ExecuTorch:
+```bash
+EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
+```
+
 ### Metal Export (macOS)

 ```bash
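Since the README note above makes the `fpa4w` export depend on a torchao build with the experimental MPS ops, a quick check before exporting can save a long run. A small sketch; the import path is the same one the new `quantize.py` branch uses (see below):

```python
# Sanity check that torchao was built with TORCHAO_BUILD_EXPERIMENTAL_MPS=1.
# Build/ABI problems can surface as ImportError or OSError, so catch broadly here.
try:
    import torchao.experimental.ops.mps  # noqa: F401
    print("torchao experimental MPS ops are available")
except Exception as err:  # noqa: BLE001
    print(f"torchao experimental MPS ops are NOT available: {err}")
```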
diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index fa21d4e3dc1..6747880cd9e 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -622,7 +622,7 @@ def main():
     parser.add_argument(
         "--qlinear",
         type=str,
-        choices=["4w", "8w", "8da4w", "8da8w"],
+        choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
         help="Quantization config for decoder linear layers",
     )
     parser.add_argument(
@@ -642,7 +642,7 @@ def main():
     parser.add_argument(
         "--qlinear_encoder",
         type=str,
-        choices=["4w", "8w", "8da4w", "8da8w"],
+        choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
         help="Quantization config for encoder linear layers",
     )
     parser.add_argument(
@@ -678,6 +678,12 @@ def main():
     if args.dtype == "fp16":
         parser.error("fp16 is not yet supported")

+    # Validate that fpa4w quantization is only used with the Metal backend
+    if args.qlinear == "fpa4w" and args.backend != "metal":
+        parser.error("--qlinear=fpa4w can only be used with --backend=metal")
+    if args.qlinear_encoder == "fpa4w" and args.backend != "metal":
+        parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal")
+
     os.makedirs(args.output_dir, exist_ok=True)

     print("Extracting tokenizer...")
diff --git a/examples/models/parakeet/quantize.py b/examples/models/parakeet/quantize.py
index 5d602d7a3e4..a08a681fdc6 100644
--- a/examples/models/parakeet/quantize.py
+++ b/examples/models/parakeet/quantize.py
@@ -17,7 +17,7 @@ def quantize_model_(  # noqa: C901

     Args:
         module: The PyTorch module to quantize.
-        qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w").
+        qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w").
         qlinear_group_size: Group size for linear quantization (default: 32).
         qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d").
         qembedding_config: Quantization config for embedding layers ("4w", "8w").
@@ -26,12 +26,41 @@ def quantize_model_(  # noqa: C901
     if not qlinear_config and not qembedding_config:
         return

+    from torchao.quantization.quant_api import quantize_
+
+    # Metal (MPS) quantization uses a different API
+    if qlinear_config == "fpa4w":
+        # Load MPS ops
+        import torchao.experimental.ops.mps  # noqa: F401
+        from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+
+        config = UIntxWeightOnlyConfig(
+            group_size=qlinear_group_size,
+            bitwidth=4,
+        )
+
+        def linear_filter(m, fqn):
+            if isinstance(m, torch.nn.Linear):
+                if m.weight.shape[1] % qlinear_group_size != 0:
+                    raise ValueError(
+                        f"Metal int4 quantization requires the weight dimension (K) to be a multiple of group_size. "
+                        f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})"  # noqa: E501
+                    )
+                return True
+            return False
+
+        print(
+            f"  Applying {qlinear_config} linear quantization "
+            f"(group_size={qlinear_group_size})..."
+        )
+        quantize_(module, config, filter_fn=linear_filter)
+        return
+
     from torchao.quantization.granularity import PerAxis, PerGroup
     from torchao.quantization.quant_api import (
         Int4WeightOnlyConfig,
         Int8DynamicActivationIntxWeightConfig,
         IntxWeightOnlyConfig,
-        quantize_,
     )

     # Quantize embedding layers first
diff --git a/third-party/ao b/third-party/ao
index 28306f08500..1b4b6d998bf 160000
--- a/third-party/ao
+++ b/third-party/ao
@@ -1 +1 @@
-Subproject commit 28306f085003b892bc0a250c209d80f5d4a5147b
+Subproject commit 1b4b6d998bf988f059e97a10181cbc4aec269b69
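For reviewers who want to exercise the new `fpa4w` branch outside the full export flow, a minimal sketch that mirrors the code added to `quantize.py` above (it assumes torchao was built with the experimental MPS ops, e.g. `USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1`, and that `in_features` is a multiple of the group size, which is what `linear_filter` enforces):

```python
import torch
import torchao.experimental.ops.mps  # noqa: F401  # registers the Metal lowbit ops
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

# Toy module: in_features=64 is a multiple of group_size=32, as required by the filter.
model = torch.nn.Sequential(torch.nn.Linear(64, 16, bias=True))

config = UIntxWeightOnlyConfig(group_size=32, bitwidth=4)
quantize_(model, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear))

# The linear weights are now replaced with 4-bit packed buffers; the module is
# quantized in place, matching what quantize_model_ does during export.
print(model)
```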