diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 5575e636027..b83995dd10b 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -39,6 +39,7 @@ exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405 + exir_ops.edge.aten.neg.default: NegConverter, # noqa F405 exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 exir_ops.edge.aten.slice_copy.Tensor: SliceTensorConverter, # noqa F405 diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py index 3b8b9bf9b3f..0d8c74cc183 100755 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py @@ -40,6 +40,9 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mul_tensor_converter import ( MulTensorConverter, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.neg_converter import ( + NegConverter, +) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.permute_copy_converter import ( PermuteCopyConverter, ) @@ -87,6 +90,7 @@ "MeanDimConverter", "MMConverter", "MulTensorConverter", + "NegConverter", "PermuteCopyConverter", "QDQPerChannelDequantizeConverter", "QDQPerTensorDequantizeConverter", diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/neg_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/neg_converter.py new file mode 100644 index 00000000000..96f56c1764d --- /dev/null +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/neg_converter.py @@ -0,0 +1,69 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np + +from executorch.backends.nxp.backend import edge_helper +from executorch.backends.nxp.backend.ir.converter.node_converter import ( + CustomDelegationOptions, + NodeConverter, +) +from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( + sub_options, +) +from executorch.backends.nxp.backend.ir.tflite_generator.tflite_model import ( + Quantization, + Scale, + ZeroPoint, +) +from torch.fx import Node +from torch.nn import Parameter + + +class NegConverter(NodeConverter): + + @staticmethod + def _is_supported_in_IR( + node: Node, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if len(node.args) != 1: + # Should never happen + return False + + # The conversion code below expects a per tensor quantized operator. + scale, zp = edge_helper.get_quantization_parameters_for(node.args[0]) + match scale, zp: + case [float(), int()]: + pass # Atomic quantization parameters -> per tensor quantization. + case _: + return False # Everything else is unexpected. + + return True + + def convert(self, node: Node): + """Convert 'aten.neg.default' operator to NeutronIR 0 - 'Sub'. + + The ExecuTorch schema is 'aten::neg(Tensor self) -> Tensor' + """ + self.assert_convertible(node) + + t_op = self._create_tflite_op_with_io_tensors(node) + + x = t_op.tmp_inputs[0] + + # Extract the zero_point, to use as the first input of the `Sub`. + scale = x.quantization.scale.vector + zp = x.quantization.zero_point.vector + zero_tensor = self.builder.create_tensor_for_data(np.array(zp, "int8"), "zero") + zero_tensor.quantization = Quantization( + scale=Scale(list(scale)), zero_point=ZeroPoint(list(zp)) + ) + + # Assign the NeutronIR operator its builtin options and inputs. + t_op.builtin_options = sub_options.Sub() + t_op.tmp_inputs = [zero_tensor, x] + + self.builder.append_operators([t_op]) diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index eba96fc6c48..eb393fc0261 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -211,6 +211,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 exir_ops.edge.aten.mul.Tensor: MulTensorConverter, # noqa F405 + exir_ops.edge.aten.neg.default: NegConverter, # noqa F405 exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index 62774fcb51d..738a816337e 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -32,6 +32,7 @@ MeanDimPattern, MmPattern, MulTensorPattern, + NegPattern, NodeArgsIdx, PadPattern, PermutePattern, @@ -261,6 +262,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False) OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig), OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig), OpQuantizer(MulTensorPattern(is_qat=is_qat), static_qconfig), + OpQuantizer(NegPattern(is_qat=is_qat), static_qconfig), OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig), OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig), OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig), diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 2412fd1ea53..94b2a5bf285 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -799,6 +799,15 @@ def get_anchors( ) +class NegPattern(SharedSpecPattern): + """ + Quantizer for the `aten.neg.default` operator. + """ + + def partition_types(self): + return [torch.ops.aten.neg.default] + + class PadPattern(SharedSpecPattern): """ Quantizer for Pad operator. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_neg_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_neg_converter.py new file mode 100644 index 00000000000..d116d0e2364 --- /dev/null +++ b/backends/nxp/tests/ir/converter/node_converter/test_neg_converter.py @@ -0,0 +1,123 @@ +# Copyright 2026 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pytest +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@pytest.fixture(autouse=True) +def reseed_model_per_test_run(): + torch.manual_seed(42) + np.random.seed(23) + + +# noinspection PyProtectedMember +ExecutorchDelegateCall = torch._higher_order_ops.executorch_call_delegate +Neg = exir_ops.edge.aten.neg.default + + +class NegModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + # noinspection PyMethodMayBeStatic + def forward(self, x): + return -x + + +class ConvNegModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 1) + + # noinspection PyMethodMayBeStatic + def forward(self, x): + x = self.conv(x) + return -x + + +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((8,), id="1D"), + pytest.param((4, 2), id="2D"), + pytest.param((1, 2, 3), id="3D"), + pytest.param((1, 2, 3, 4), id="4D"), + ], +) +def test_convert_neg(mocker, input_shape): + model = NegModule() + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + delegated_ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `neg` was delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert not graph_contains_any_of_ops(delegated_ep.graph, [Neg]) + + # Verify correct behavior of the converted NeutronIR model. + intermediate_ep = converter_spy.call_args.args[1] + neutron_ir_model, _ = converter_spy.spy_return + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 256.0 - 128.0 + ).astype(np.int8) + + # Make sure the tested program contains the `neg`. + assert graph_contains_any_of_ops(intermediate_ep.graph, [Neg]) + + convert_run_compare( + intermediate_ep, + tfl_model=neutron_ir_model, + input_data=input_data, + ) + + +def test_convert_neg__channels_last(mocker): + model = ConvNegModule() + input_shape = (1, 3, 4, 5) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + delegated_ep = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() + + # Make sure the `neg` was delegated. + assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert not graph_contains_any_of_ops(delegated_ep.graph, [Neg]) + + # Verify correct behavior of the converted NeutronIR model. + intermediate_ep = converter_spy.call_args.args[1] + neutron_ir_model, _ = converter_spy.spy_return + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 256.0 - 128.0 + ).astype(np.int8) + + # Make sure the tested program contains the `neg`. + assert graph_contains_any_of_ops(intermediate_ep.graph, [Neg]) + + convert_run_compare( + intermediate_ep, + tfl_model=neutron_ir_model, + input_data=input_data, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) diff --git a/docs/source/backends/nxp/op-support.csv b/docs/source/backends/nxp/op-support.csv index 581ec3ffb94..7d493c4f681 100644 --- a/docs/source/backends/nxp/op-support.csv +++ b/docs/source/backends/nxp/op-support.csv @@ -14,6 +14,7 @@ aten.max_pool2d_with_indices.default,int8,static int8,"dilation=1, ceil_mode=Fal aten.mean.dim,int8,static int8,"4D tensor only, dims = [-1,-2] or [-2,-1]" aten.mul.Tensor, int8, static int8, "tensor-size % 8 = 0" aten.mm.default,int8,static int8,"2D tensor only" +aten.neg.default,int8,static int8, aten.relu.default,int8,static int8, aten.tanh.default,int8,static int8, aten.view_copy.default,int8,static int8,