Changes from all commits: 64 commits by manuelcandales, Jan 31 to Feb 5, 2026.
19 changes: 14 additions & 5 deletions .ci/scripts/export_model_artifact.sh
@@ -26,14 +26,16 @@ Arguments:
quant_name Quantization type (optional, default: non-quantized)
Options:
- non-quantized
- quantized-int4-tile-packed
- quantized-int4-weight-only
- quantized-int4-tile-packed (CUDA only)
- quantized-int4-weight-only (CUDA only)
- quantized-int4-metal (Metal only)
- quantized-8da4w (XNNPACK only)

output_dir Output directory for artifacts (optional, default: current directory)

Examples:
export_model_artifact.sh metal "openai/whisper-small"
export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal"
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
@@ -131,18 +133,25 @@ case "$QUANT_NAME" in
;;
quantized-int4-tile-packed)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
;;
quantized-int4-weight-only)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear_encoder 4w"
;;
quantized-int4-metal)
if [ "$DEVICE" != "metal" ]; then
echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend"
exit 1
fi
EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"
;;
quantized-8da4w)
if [ "$DEVICE" != "xnnpack" ]; then
echo "Error: quantized-8da4w is only supported with xnnpack device"
@@ -152,7 +161,7 @@ case "$QUANT_NAME" in
;;
*)
echo "Error: Unsupported quantization '$QUANT_NAME'"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-8da4w"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal, quantized-8da4w"
exit 1
;;
esac
15 changes: 13 additions & 2 deletions .github/workflows/metal.yml
@@ -41,11 +41,10 @@ jobs:
set -eux

echo "::group::Setup ExecuTorch"
PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh
PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
echo "::endgroup::"

echo "::group::Build Metal Runtime"
${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --update-ao
${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build
echo "::endgroup::"

@@ -73,6 +72,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
@@ -123,6 +128,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
8 changes: 6 additions & 2 deletions backends/apple/metal/metal_backend.py
@@ -44,8 +44,12 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:

@classmethod
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
"""Return Metal-specific passes (currently none)"""
return []
"""Return Metal-specific passes"""
from executorch.backends.apple.metal.passes.decompose_linear_pass import (
DecomposeLinearPass,
)

return [DecomposeLinearPass()]

@classmethod
def get_aoti_compile_options(
11 changes: 11 additions & 0 deletions backends/apple/metal/passes/__init__.py
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.apple.metal.passes.decompose_linear_pass import ( # noqa: F401
DecomposeLinearPass,
)

__all__ = ["DecomposeLinearPass"]
111 changes: 111 additions & 0 deletions backends/apple/metal/passes/decompose_linear_pass.py
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
"""
Decompose aten.linear into matmul + add to avoid addmm.

For 2D inputs, we unsqueeze to 3D before decomposition to force the matmul
code path instead of addmm. The C++ implementation of aten.linear directly
calls addmm for 2D inputs with bias, which would require implementing
aoti_torch_mps_addmm_out. By unsqueezing to 3D, we force the matmul path,
then squeeze back to 2D.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
Review comment (Contributor): Looks like meta['val'] is not being preserved?

Reply (manuelcandales, Feb 5, 2026): addressing in #17237

Review comment (Contributor): Can we just override call_operator(self, op, args, kwargs, meta) so that it eliminates all the boilerplate code?

Reply (manuelcandales, Feb 5, 2026): addressing in #17237
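
For context, a minimal sketch of what the suggested call_operator override might look like is shown below. This is illustrative only and not part of the PR: it assumes ExportPass dispatches each operator through call_operator(op, args, kwargs, meta) and that replacement ops are created via super().call_operator, the class name is hypothetical, and the 2D unsqueeze/squeeze handling from the pass above is omitted for brevity.

```python
# Illustrative sketch only, not part of this PR. Assumes ExportPass calls
# call_operator(op, args, kwargs, meta) per operator and that replacement ops
# are inserted via super().call_operator; meta is reused for intermediate
# nodes purely for brevity.
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeLinearViaCallOperator(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.linear.default:
            return super().call_operator(op, args, kwargs, meta)

        x, weight = args[0], args[1]
        bias = args[2] if len(args) > 2 else None

        # linear(x, W, b) -> matmul(x, W.T) (+ b); 2D handling omitted here.
        weight_t = super().call_operator(
            exir_ops.edge.aten.t.default, (weight,), {}, meta
        )
        out = super().call_operator(
            exir_ops.edge.aten.matmul.default, (x, weight_t), {}, meta
        )
        if bias is not None:
            out = super().call_operator(
                exir_ops.edge.aten.add.Tensor, (out, bias), {}, meta
            )
        return out
```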

modified = False
graph = graph_module.graph

for node in graph.nodes:
# Check if this is a linear operation
is_linear = False

if node.op == "call_function":
# Match both edge dialect and core aten linear operators
if node.target == exir_ops.edge.aten.linear.default:
is_linear = True
elif node.target == torch.ops.aten.linear.default:
is_linear = True

if is_linear:
# Get input, weight, and bias arguments
input_node = node.args[0]
weight_node = node.args[1]
bias_node = node.args[2] if len(node.args) > 2 else None

with graph.inserting_before(node):
# Determine which ops to use based on the input operator
target_str = str(node.target)

if "executorch_exir_dialects_edge" in target_str:
# Use edge dialect operators
t_op = exir_ops.edge.aten.t.default
matmul_op = exir_ops.edge.aten.matmul.default
add_op = exir_ops.edge.aten.add.Tensor
unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
squeeze_op = exir_ops.edge.aten.squeeze.dims
Copilot AI comment on lines +53 to +54 (Feb 5, 2026): The pass uses incorrect operator names for edge dialect squeeze and unsqueeze operations. It should use exir_ops.edge.aten.squeeze_copy.dims instead of exir_ops.edge.aten.squeeze.dims, and exir_ops.edge.aten.unsqueeze_copy.default instead of exir_ops.edge.aten.unsqueeze.default. The edge dialect typically uses _copy variants of view operations.

Suggested change:
- unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
- squeeze_op = exir_ops.edge.aten.squeeze.dims
+ unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+ squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
else:
# Use core aten operators
t_op = torch.ops.aten.t.default
matmul_op = torch.ops.aten.matmul.default
add_op = torch.ops.aten.add.Tensor
unsqueeze_op = torch.ops.aten.unsqueeze.default
squeeze_op = torch.ops.aten.squeeze.dims

# Check if input is 2D
needs_unsqueeze = False
if hasattr(input_node, "meta") and "val" in input_node.meta:
if len(input_node.meta["val"].shape) == 2:
needs_unsqueeze = True
Copilot AI comment on lines +65 to +67 (Feb 5, 2026): The pass assumes that all input nodes have metadata with a 'val' attribute to determine dimensionality. If metadata is missing or incomplete, the pass will not unsqueeze 2D inputs, which could result in the addmm code path being taken instead of matmul. This defeats the purpose of the pass. Consider adding a fallback or error handling when metadata is not available, or document this limitation clearly.

Suggested change:
- if hasattr(input_node, "meta") and "val" in input_node.meta:
-     if len(input_node.meta["val"].shape) == 2:
-         needs_unsqueeze = True
+ if hasattr(input_node, "meta"):
+     val_meta = input_node.meta.get("val", None)
+     if val_meta is not None and hasattr(val_meta, "shape"):
+         if len(val_meta.shape) == 2:
+             needs_unsqueeze = True
+     else:
+         raise RuntimeError(
+             "DecomposeLinearPass requires input_node.meta['val'] with a 'shape' "
+             "attribute to determine input dimensionality."
+         )

# Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
current_input = input_node
if needs_unsqueeze:
current_input = graph.call_function(
Review comment (Contributor): Use call_operator and pass in meta.

Reply (manuelcandales, Feb 5, 2026): addressing in #17237

unsqueeze_op,
args=(input_node, 0),
)

# Decompose linear: matmul(input, weight.T) + bias
weight_t = graph.call_function(
t_op,
args=(weight_node,),
)

matmul_result = graph.call_function(
matmul_op,
args=(current_input, weight_t),
)

if bias_node is not None:
result = graph.call_function(
add_op,
args=(matmul_result, bias_node),
)
else:
result = matmul_result

# Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
if needs_unsqueeze:
result = graph.call_function(
squeeze_op,
args=(result, [0]),
)

# Replace all uses of the linear node with the decomposed result
node.replace_all_uses_with(result)
graph.erase_node(node)
modified = True

if modified:
graph_module.recompile()

return PassResult(graph_module, modified)
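
As a sanity check of the rewrite this pass performs, the following standalone snippet (illustrative only, not part of the PR) verifies that the unsqueeze, matmul, add, squeeze path matches aten.linear for a 2D input; the shapes mirror the LinearWithBias test module added below.

```python
import torch

# Illustrative check, not part of this PR: the decomposition produced by
# DecomposeLinearPass should match aten.linear for a 2D input.
x = torch.randn(127, 7)
weight = torch.randn(101, 7)
bias = torch.randn(101)

reference = torch.nn.functional.linear(x, weight, bias)

# (M, K) -> (1, M, K), matmul against W.T, add bias, then squeeze back to (M, N).
decomposed = (x.unsqueeze(0) @ weight.t() + bias).squeeze(0)

torch.testing.assert_close(decomposed, reference)
```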
16 changes: 16 additions & 0 deletions backends/apple/metal/tests/test_modules.py
@@ -208,6 +208,22 @@ def forward(self, x: torch.Tensor):
}


# -------------------------------------------------------------------------
class LinearWithBias(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(7, 101, bias=True)

def forward(self, x: torch.Tensor):
return self.linear(x)


MODULE_REGISTRY["linear_bias"] = {
"model_class": LinearWithBias,
"input_shapes": [(127, 7)],
"description": "Simple linear layer model bias",
Copilot AI comment (Feb 5, 2026): The description string is incomplete and grammatically incorrect. It should be "Simple linear layer model with bias" to be consistent with the "linear_nobias" description above.

Suggested change:
- "description": "Simple linear layer model bias",
+ "description": "Simple linear layer model with bias",


Copilot AI comment on lines +225 to +226 (Feb 5, 2026): Missing closing brace in the MODULE_REGISTRY dictionary definition. The dictionary is opened but not closed, which will cause a syntax error.

Suggested change:
+ }
# -------------------------------------------------------------------------
class LinearNoBiasInt4(nn.Module):
def __init__(self):
41 changes: 33 additions & 8 deletions examples/models/parakeet/README.md
@@ -39,23 +39,24 @@ The export script supports quantizing encoder and decoder linear layers using [t

| Argument | Description |
|----------|-------------|
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) |

#### Quantization Configs

| Config | Description |
|--------|-------------|
| `4w` | 4-bit weight only quantization |
| `8w` | 8-bit weight only quantization |
| `8da4w` | 8-bit dynamic activation, 4-bit weight |
| `8da8w` | 8-bit dynamic activation, 8-bit weight |
| Config | Description | Backends |
|--------|-------------|----------|
| `4w` | 4-bit weight only quantization | CUDA |
| `8w` | 8-bit weight only quantization | CUDA |
| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA |
| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA |
| `fpa4w` | Floating point activation, 4-bit weight | Metal |

#### Example: Dynamic Quantization for XNNPACK

@@ -86,6 +87,30 @@ python export_parakeet_tdt.py \

**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.

#### Example: Metal 4-bit Quantization

```bash
python export_parakeet_tdt.py \
--backend metal \
--qlinear_encoder fpa4w \
--qlinear_encoder_group_size 32 \
--qlinear fpa4w \
--qlinear_group_size 32 \
--output-dir ./parakeet_metal_quantized
```

**Note:** Metal 4-bit quantization requires torchao built with experimental MPS (Metal) ops.

You can install torchao with Metal support from the `ao` repo:
```bash
USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation
```

Alternatively, you can build torchao with Metal support while installing ExecuTorch:
```bash
EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
```

### Metal Export (macOS)

```bash
10 changes: 8 additions & 2 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -622,7 +622,7 @@ def main():
parser.add_argument(
"--qlinear",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for decoder linear layers",
)
parser.add_argument(
@@ -642,7 +642,7 @@
parser.add_argument(
"--qlinear_encoder",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for encoder linear layers",
)
parser.add_argument(
@@ -678,6 +678,12 @@ def main():
if args.dtype == "fp16":
parser.error("fp16 is not yet supported")

# Validate fpa4w quantization requires Metal backend
if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")
if args.qlinear_encoder == "fpa4w" and args.backend != "metal":
parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal")

os.makedirs(args.output_dir, exist_ok=True)

print("Extracting tokenizer...")