diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 7093ef19c3d..9c926d91000 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -432,12 +432,39 @@ def get_args_and_kwargs_mixed_w8a32_conv( torch.ops.aten.permute.default, (other_inputs[0], [0, 2, 1]), # NCL -> NLC ) + if "val" in other_inputs[0].meta: + original_val = other_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) + else: + shape = list(original_val.shape) + # NCL -> NLC: [0,1,2] -> [0,2,1] + shape[1], shape[2] = shape[2], shape[1] + transposed_inputs.meta["val"] = torch.zeros(shape, dtype=original_val.dtype) copy_node_metadata(transposed_inputs, other_inputs[0]) transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) + if "val" in weights_inputs[0].meta: + original_val = weights_inputs[0].meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) + else: + shape = list(original_val.shape) + # NCL -> LNC: [0,1,2] -> [2,0,1] + transposed_weights.meta["val"] = torch.zeros( + [shape[2], shape[0], shape[1]], dtype=original_val.dtype + ) copy_node_metadata(transposed_weights, weights_inputs[0]) args = (