Metal backend: enable linear with bias #17115
base: main
Changes from all commits
New file (@@ -0,0 +1,11 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.apple.metal.passes.decompose_linear_pass import (  # noqa: F401
    DecomposeLinearPass,
)

__all__ = ["DecomposeLinearPass"]
```
New file defining DecomposeLinearPass (@@ -0,0 +1,111 @@); the excerpt below covers lines 1-54:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
    """
    Decompose aten.linear into matmul + add to avoid addmm.

    For 2D inputs, we unsqueeze to 3D before decomposition to force the matmul
    code path instead of addmm. The C++ implementation of aten.linear directly
    calls addmm for 2D inputs with bias, which would require implementing
    aoti_torch_mps_addmm_out. By unsqueezing to 3D, we force the matmul path,
    then squeeze back to 2D.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        graph = graph_module.graph

        for node in graph.nodes:
            # Check if this is a linear operation
            is_linear = False

            if node.op == "call_function":
                # Match both edge dialect and core aten linear operators
                if node.target == exir_ops.edge.aten.linear.default:
                    is_linear = True
                elif node.target == torch.ops.aten.linear.default:
                    is_linear = True

            if is_linear:
                # Get input, weight, and bias arguments
                input_node = node.args[0]
                weight_node = node.args[1]
                bias_node = node.args[2] if len(node.args) > 2 else None

                with graph.inserting_before(node):
                    # Determine which ops to use based on the input operator
                    target_str = str(node.target)

                    if "executorch_exir_dialects_edge" in target_str:
                        # Use edge dialect operators
                        t_op = exir_ops.edge.aten.t.default
                        matmul_op = exir_ops.edge.aten.matmul.default
                        add_op = exir_ops.edge.aten.add.Tensor
                        unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
                        squeeze_op = exir_ops.edge.aten.squeeze.dims
```
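As a sanity check on the docstring's reasoning, the decomposition is numerically the same computation as aten.linear in eager mode. The snippet below is illustrative only and is not part of the PR; the shapes are borrowed from the LinearWithBias test further down.

```python
# Illustrative eager-mode check (not code from this PR): unsqueeze to 3D,
# matmul against the transposed weight, add bias, squeeze back to 2D.
import torch

x = torch.randn(127, 7)   # 2D activation
w = torch.randn(101, 7)   # weight of nn.Linear(7, 101)
b = torch.randn(101)      # bias

reference = torch.nn.functional.linear(x, w, b)        # addmm path for 2D input
decomposed = (x.unsqueeze(0) @ w.t() + b).squeeze(0)   # matmul path forced via 3D

print(torch.allclose(reference, decomposed, atol=1e-6))  # expected: True
```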
Review thread on `def call(self, graph_module: torch.fx.GraphModule) -> PassResult:`

Contributor: can we just override
Author: addressing in #17237

Review thread on lines +53 to +54 (`unsqueeze_op` / `squeeze_op`)

Suggested change:

```diff
-unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
-squeeze_op = exir_ops.edge.aten.squeeze.dims
+unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
```

Copilot AI (Feb 5, 2026): The pass assumes that all input nodes have metadata with a `val` attribute to determine dimensionality. If metadata is missing or incomplete, the pass will not unsqueeze 2D inputs, which could result in the addmm code path being taken instead of matmul. This defeats the purpose of the pass. Consider adding a fallback or error handling when metadata is not available, or document this limitation clearly.

Suggested change:

```diff
-if hasattr(input_node, "meta") and "val" in input_node.meta:
-    if len(input_node.meta["val"].shape) == 2:
-        needs_unsqueeze = True
+if hasattr(input_node, "meta"):
+    val_meta = input_node.meta.get("val", None)
+    if val_meta is not None and hasattr(val_meta, "shape"):
+        if len(val_meta.shape) == 2:
+            needs_unsqueeze = True
+    else:
+        raise RuntimeError(
+            "DecomposeLinearPass requires input_node.meta['val'] with a 'shape' "
+            "attribute to determine input dimensionality."
+        )
```

Reviewer: Use call_operator and pass in meta.
Author: addressing in #17237
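For readers following that suggestion, here is a minimal sketch of what a call_operator-based version could look like. It is hypothetical (the actual rework lands in #17237), it matches only the edge-dialect linear op, and it omits the PR's 2D unsqueeze/squeeze handling.

```python
# Hypothetical sketch only; the real rework is in #17237.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeLinearSketch(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        # Pass every op through untouched except linear. (The PR also matches
        # torch.ops.aten.linear.default; that case is omitted here.)
        if op != exir_ops.edge.aten.linear.default:
            return super().call_operator(op, args, kwargs, meta)

        x, weight = args[0], args[1]
        bias = args[2] if len(args) > 2 else None

        # x @ weight.T + bias, emitted with the same edge ops the PR's pass uses.
        wt = super().call_operator(exir_ops.edge.aten.t.default, (weight,), {}, meta)
        out = super().call_operator(
            exir_ops.edge.aten.matmul.default, (x, wt), {}, meta
        )
        if bias is not None:
            out = super().call_operator(
                exir_ops.edge.aten.add.Tensor, (out, bias), {}, meta
            )
        return out
```

The appeal of this shape is that ExportPass retraces the graph, so shape/dtype metadata for the new nodes is populated during retracing rather than copied by hand.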
Changes to the test module registry:

```diff
@@ -208,6 +208,22 @@ def forward(self, x: torch.Tensor):
 }
 
 
+# -------------------------------------------------------------------------
+class LinearWithBias(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(7, 101, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_bias"] = {
+    "model_class": LinearWithBias,
+    "input_shapes": [(127, 7)],
+    "description": "Simple linear layer model bias",
```
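For context, a hedged sketch of how an entry like this might be exercised by a test; the harness steps and the Metal comparison are assumptions, not code from this PR.

```python
# Hypothetical usage inside the test module that defines MODULE_REGISTRY;
# the actual harness in the repository may differ.
import torch

entry = MODULE_REGISTRY["linear_bias"]
model = entry["model_class"]().eval()                       # LinearWithBias()
inputs = tuple(torch.randn(*shape) for shape in entry["input_shapes"])

eager_out = model(*inputs)                                  # shape (127, 101)
# A backend test would then export/lower the module for the Metal backend and
# compare the delegate's output against eager_out within a tolerance.
```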
| "description": "Simple linear layer model bias", | |
| "description": "Simple linear layer model with bias", |
Copilot
AI
Feb 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing closing brace in the MODULE_REGISTRY dictionary definition. The dictionary is opened but not closed, which will cause a syntax error.
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
meta['val'] is not being preserved look like?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressing in #17237