From 354a9249900c5938333f4c4a873fa65e5a2db8ab Mon Sep 17 00:00:00 2001 From: Ethan Ng Date: Fri, 6 Feb 2026 16:59:45 -0800 Subject: [PATCH] Fix quantized_w8a32_conv1d fusion and operator Summary: 1. Python fusion pass (fusion_pass.py): The permute nodes created in get_args_and_kwargs_mixed_w8a32_conv were missing val metadata, causing SpecViolationError during ExportedProgram validation. Added val computation using fake_mode, matching the pattern used by AddmmPattern. 2. C++ operators (generic + HiFi): cnn_step was called once with wrong dimensions (input length instead of kernel_size for d1), only producing out_channels values instead of the full output. Fixed to loop over output time steps, calling cnn_step per position with correct kernel_size, then scattering to NCL output layout. Differential Revision: D92565258 --- backends/cadence/aot/quantizer/fusion_pass.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 7093ef19c3d..9c926d91000 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -432,12 +432,39 @@ def get_args_and_kwargs_mixed_w8a32_conv( torch.ops.aten.permute.default, (other_inputs[0], [0, 2, 1]), # NCL -> NLC ) + if "val" in other_inputs[0].meta: + original_val = other_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) + else: + shape = list(original_val.shape) + # NCL -> NLC: [0,1,2] -> [0,2,1] + shape[1], shape[2] = shape[2], shape[1] + transposed_inputs.meta["val"] = torch.zeros(shape, dtype=original_val.dtype) copy_node_metadata(transposed_inputs, other_inputs[0]) transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) + if "val" in weights_inputs[0].meta: + original_val = weights_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) + else: + shape = list(original_val.shape) + # NCL -> LNC: [0,1,2] -> [2,0,1] + transposed_weights.meta["val"] = torch.zeros( + [shape[2], shape[0], shape[1]], dtype=original_val.dtype + ) copy_node_metadata(transposed_weights, weights_inputs[0]) args = (