From 09daf81327a34adc1efd87f328e8bbeac71f3c10 Mon Sep 17 00:00:00 2001
From: Ethan Ng
Date: Fri, 6 Feb 2026 16:59:49 -0800
Subject: [PATCH] Fix quantized_w8a32_gru output shape to preserve batch
 dimensions

Summary:
The meta function for `quantized_w8a32_gru` returned shape `(2, hidden_dim)`,
dropping the batch dimensions of the hidden state. After the fusion pass
replaces `aten.gru.input` with the custom op, `getitem[0]` on the output
produced shape `(hidden_dim,)` instead of the original `(batch, ..., hidden_dim)`,
causing a shape mismatch with the reference output and NaN in the comparison.

Fix the meta function to return `(2, *hidden.shape)` to preserve all
dimensions. Also update the reference implementation to handle 3D hidden
states and preserve the original shape in the output, consistent with the
shape fix applied to `quantized_w8a32_conv` in D84745967.

Differential Revision: D92570517
---
 backends/cadence/aot/ops_registrations.py         |  2 +-
 backends/cadence/aot/ref_implementations.py       | 14 ++++++++------
 .../cadence/aot/tests/test_ref_implementations.py |  6 +++---
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 49732da4ce8..9b4568e008d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -2858,7 +2858,7 @@ def quantized_w8a32_gru_meta(
     bias_hidden: torch.Tensor,
     b_h_scale: float,
 ) -> torch.Tensor:
-    return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
+    return hidden.new_empty((2, *hidden.shape), dtype=torch.float32)


 @register_fake("cadence::slice_scatter_")
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 8a533c80db1..dd04d563028 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1019,11 +1019,13 @@ def quantized_w8a32_gru(
     assert inputs.dtype == torch.float32
     assert hidden.dtype == torch.float32

-    if len(hidden.shape) > 2:
-        raise ValueError("Hidden state must be 2D or 1D")
-
-    if len(hidden.shape) == 2 and hidden.shape[0] != 1:
-        raise ValueError("Leading dimension of hidden state must be 1")
+    # Hidden state can be 1D, 2D (1, hidden_dim), or 3D (1, 1, hidden_dim).
+    # All leading dimensions must be 1.
+    for d in range(len(hidden.shape) - 1):
+        if hidden.shape[d] != 1:
+            raise ValueError(
+                f"Leading dimension {d} of hidden state must be 1, got {hidden.shape[d]}"
+            )

     original_hidden_shape = hidden.shape
     hidden = hidden.view(-1)
@@ -1059,7 +1061,7 @@ def quantized_w8a32_gru(

     assert new_hidden.shape == original_hidden_shape

-    new_hidden = new_hidden.view(-1)
+    new_hidden = new_hidden.view(original_hidden_shape)

     return torch.stack([new_hidden, new_hidden], dim=0)

diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index e0960522c32..e5c16fa22ea 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2952,7 +2952,7 @@ def test_quantized_w8a32_gru(
                 b_h_scale,
             )
             self.assertIn(
-                "Leading dimension of hidden state must be 1", str(context.exception)
+                "Leading dimension 0 of hidden state must be 1", str(context.exception)
             )
             return

@@ -2977,8 +2977,8 @@ def test_quantized_w8a32_gru(
         )
         self.assertEqual(
             output.shape,
-            (2, hidden.shape[-1]),
-            f"Output shape should match {(2, hidden.shape[-1])} in {name}",
+            (2, *hidden.shape),
+            f"Output shape should match {(2, *hidden.shape)} in {name}",
         )
         assert isinstance(output, torch.Tensor)
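
Illustration (not part of the patch above): a minimal sketch of what the
meta-function shape fix changes, assuming a hypothetical 3D hidden state of
shape (1, 1, 32):

    import torch

    hidden = torch.zeros(1, 1, 32)  # hypothetical 3D hidden state

    # Old meta: only the last dimension was kept, so getitem[0] lost the batch dims.
    old_out = hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)
    assert old_out[0].shape == (32,)

    # New meta: every hidden dimension is preserved, so getitem[0] matches hidden.shape.
    new_out = hidden.new_empty((2, *hidden.shape), dtype=torch.float32)
    assert new_out[0].shape == (1, 1, 32)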