diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 49732da4ce8..9b4568e008d 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -2858,7 +2858,7 @@ def quantized_w8a32_gru_meta( bias_hidden: torch.Tensor, b_h_scale: float, ) -> torch.Tensor: - return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32) + return hidden.new_empty((2, *hidden.shape), dtype=torch.float32) @register_fake("cadence::slice_scatter_") diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 8a533c80db1..dd04d563028 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1019,11 +1019,13 @@ def quantized_w8a32_gru( assert inputs.dtype == torch.float32 assert hidden.dtype == torch.float32 - if len(hidden.shape) > 2: - raise ValueError("Hidden state must be 2D or 1D") - - if len(hidden.shape) == 2 and hidden.shape[0] != 1: - raise ValueError("Leading dimension of hidden state must be 1") + # Hidden state may have any rank, e.g. (hidden_dim,), (1, hidden_dim), + # or (1, 1, hidden_dim); every leading (non-last) dimension must be 1.
+ for d in range(len(hidden.shape) - 1): + if hidden.shape[d] != 1: + raise ValueError( + f"Leading dimension {d} of hidden state must be 1, got {hidden.shape[d]}" + ) original_hidden_shape = hidden.shape hidden = hidden.view(-1) @@ -1059,7 +1061,7 @@ def quantized_w8a32_gru( assert new_hidden.shape == original_hidden_shape - new_hidden = new_hidden.view(-1) + new_hidden = new_hidden.view(original_hidden_shape) return torch.stack([new_hidden, new_hidden], dim=0) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index e0960522c32..e5c16fa22ea 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -2952,7 +2952,7 @@ def test_quantized_w8a32_gru( b_h_scale, ) self.assertIn( - "Leading dimension of hidden state must be 1", str(context.exception) + "Leading dimension 0 of hidden state must be 1", str(context.exception) ) return @@ -2977,8 +2977,8 @@ def test_quantized_w8a32_gru( ) self.assertEqual( output.shape, - (2, hidden.shape[-1]), - f"Output shape should match {(2, hidden.shape[-1])} in {name}", + (2, *hidden.shape), + f"Output shape should match {(2, *hidden.shape)} in {name}", ) assert isinstance(output, torch.Tensor)