From 15545c046f40e568c3e1fbdd24148cbcfcb0d890 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 4 Feb 2026 08:49:56 +0100 Subject: [PATCH] Arm backend: Improve noop handling - Removes the identity node_visitor in favour of handling all noops in the remove_noop_pass - Adds detach_copy to the partitioner is_noop check to avoid partitioning graphs with single detach_copy ops - Make noop check functions in the partitioner private and remove public documentation. Instead handle alias_copy similar to other noops Signed-off-by: Adrian Lundell Change-Id: I018f4959865eae29b6842b88269e6be5b208509a --- backends/arm/_passes/remove_noop_pass.py | 4 +- backends/arm/operators/__init__.py | 1 - backends/arm/operators/ops_identity.py | 80 ----------------------- backends/arm/test/ops/test_detach_copy.py | 2 +- backends/arm/tosa/partitioner.py | 62 ++++-------------- 5 files changed, 15 insertions(+), 134 deletions(-) delete mode 100644 backends/arm/operators/ops_identity.py diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 8ac808809ef..ff33b273155 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -25,7 +25,9 @@ def call_operator(self, op, args, kwargs, meta): if op not in ( exir_ops.edge.dim_order_ops._clone_dim_order.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.alias_copy.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.detach_copy.default, ): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index ecac81f5761..74105580ccd 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -69,5 +69,4 @@ op_where, op_while, ops_binary, - ops_identity, ) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py deleted file mode 100644 index 105599bb392..00000000000 --- a/backends/arm/operators/ops_identity.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, List - -import torch -import torch.fx - -import tosa_serializer as ts - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - - -def identity_operator_factory(identity_target: str): - """Creates and registers NodeVisitors for operators that map directly to a - TOSA IDENTITY op. - """ - - class IdentityOperatorVisitor(NodeVisitor): - target = identity_target - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ - ts.DType.BOOL, - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ] - if self.tosa_spec.support_float(): - supported_dtypes += [ts.DType.FP32] - if self.tosa_spec.support_extension("bf16"): - supported_dtypes += [ts.DType.BF16] - if self.tosa_spec.support_extension("int16"): - supported_dtypes += [ts.DType.INT48] - if self.tosa_spec.support_extension("int4"): - supported_dtypes += [ts.DType.INT4] - validate_valid_dtype( - self.target, - [inputs[0], output], - supported_dtypes, - self.tosa_spec, - ) - - # Simply add an identityOp - attr = ts.TosaSerializerAttribute() - attr.IdentityAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.IDENTITY, - [inputs[0].name], - [output.name], - attr, - ) - - register_node_visitor(IdentityOperatorVisitor) - - -identity_operator_factory("aten.alias_copy.default") -identity_operator_factory("aten.detach_copy.default") diff --git a/backends/arm/test/ops/test_detach_copy.py b/backends/arm/test/ops/test_detach_copy.py index b5878729b44..c8715ca847a 100644 --- a/backends/arm/test/ops/test_detach_copy.py +++ b/backends/arm/test/ops/test_detach_copy.py @@ -31,7 +31,7 @@ class DetachCopy(torch.nn.Module): exir_op = exir_op def forward(self, x: torch.Tensor): - return torch.detach_copy(x) + return torch.detach_copy(x) + 1 @common.parametrize("test_data", test_data_suite) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 640dcd5761b..fa5e8e352e5 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -46,47 +46,19 @@ logger = logging.getLogger(__name__) -def is_noop_clone(node: torch.fx.node.Node) -> bool: - """Return True if the node is a no-op ``dim_order_ops._clone_dim_order``. - - Args: - node (torch.fx.Node): FX node to inspect. - - Returns: - bool: True if the node targets ``dim_order_ops._clone_dim_order.default`` - in the Edge dialect; otherwise, False. - - """ +def _is_noop_clone(node: torch.fx.node.Node) -> bool: return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default -def is_noop_alias_copy(node: torch.fx.Node) -> bool: - """Return True if the node is a no-op ``aten.alias_copy``. - - Args: - node (torch.fx.Node): FX node to inspect. - - Returns: - bool: True if the node targets ``aten.alias_copy.default``; otherwise, - False. - - """ +def _is_noop_alias_copy(node: torch.fx.Node) -> bool: return node.target == exir_ops.edge.aten.alias_copy.default -def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: - """Return True if node is a no-op ``dim_order_ops._to_dim_order_copy``. +def _is_noop_detach_copy(node: torch.fx.Node) -> bool: + return node.target == exir_ops.edge.aten.detach_copy.default - Consider the op a no-op when the output dtype equals the input's dtype. - Args: - node (torch.fx.Node): FX node to inspect. - - Returns: - bool: True if it targets ``_to_dim_order_copy.default`` and preserves - dtype; otherwise, False. - - """ +def _is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default: return False else: @@ -94,20 +66,7 @@ def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: return node.meta.get("dtype") == get_first_fake_tensor(input_node).dtype -def is_noop_expand(node: torch.fx.node.Node) -> bool: - """Return True if the node is an ``expand_copy`` with all-ones multiples. - - This corresponds to a semantic no-op, since expanding by 1 along every - dimension leaves the tensor unchanged. - - Args: - node (torch.fx.Node): FX node to inspect. - - Returns: - bool: True if the node targets ``aten.expand_copy.default`` and all - computed multiples are 1; otherwise, False. - - """ +def _is_noop_expand(node: torch.fx.node.Node) -> bool: if node.target != exir_ops.edge.aten.expand_copy.default: return False else: @@ -291,10 +250,11 @@ def _tag_module( # noqa ) is_noop_partition = all( - is_noop_clone(node) - or is_noop_alias_copy(node) - or is_noop_expand(node) - or is_noop_to_dim_order_copy(node) + _is_noop_clone(node) + or _is_noop_alias_copy(node) + or _is_noop_detach_copy(node) + or _is_noop_expand(node) + or _is_noop_to_dim_order_copy(node) or node.target in Q_OPS or node.target in DQ_OPS for node in partition.nodes