Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@
from .remove_getitem_pass import RemoveGetItemPass # noqa
from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
from .remove_noop_pass import RemoveNoopPass # noqa
from .remove_permutes_around_elementwise_tosa_ops import ( # noqa
RemovePermutesAroundElementwiseTosaOps,
)
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorByProfilePass,
)
Expand Down
6 changes: 2 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
RemoveGetItemPass,
RemoveGraphAssertsPass,
RemoveNoopPass,
RemovePermutesAroundElementwiseTosaOps,
ReplaceInfAndLimitValuesPass,
ReplaceScalarWithTensorByProfilePass,
RewriteAvgPool2dPass,
Expand Down Expand Up @@ -164,9 +165,6 @@
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)

from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
RemovePermutesAroundElementwiseOps,
)
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
Expand Down Expand Up @@ -538,7 +536,7 @@ def _tosa_pipeline(
RewriteMatmulPass(),
RewritePadPass(),
FuseViewCopyTransformPass(),
RemovePermutesAroundElementwiseOps(),
Comment thread
oscarandersson8218 marked this conversation as resolved.
RemovePermutesAroundElementwiseTosaOps(),
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
FuseCascadedTransposeOrPermuteOps(),
ConvertPermuteSingletonToViewPass(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
RemovePermutesAroundElementwiseOps,
)
from executorch.exir.dialects._ops import ops as exir_ops


class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
permutable_ops = {
*RemovePermutesAroundElementwiseOps.permutable_ops,
exir_ops.backend.tosa.RESCALE.default,
exir_ops.backend.tosa.TABLE.default,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import (
RemovePermutesAroundElementwiseTosaOps,
)
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops

TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT")
PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default
RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default
TABLE_TARGET = exir_ops.backend.tosa.TABLE.default


def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int:
return sum(
1
for node in graph_module.graph.nodes
if node.op == "call_function" and node.target == target
)


def test_remove_permutes_around_rescale_tosa_INT() -> None:
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(1, 3, 4, 5)

permute_in = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(x, [0, 2, 3, 1]),
)
rescale = graph.create_node(
"call_function",
RESCALE_TARGET,
args=(permute_in, torch.int8, [1.0], 0, 0),
)
permute_out = graph.create_node(
"call_function",
PERMUTE_TARGET,
args=(rescale, [0, 3, 1, 2]),
)
graph.output(permute_out)

graph_module = torch.fx.GraphModule({}, graph)

with TosaLoweringContext(TOSA_INT_SPEC):
result = RemovePermutesAroundElementwiseTosaOps().call(graph_module)

assert result.modified
assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0
assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1
Loading