-
Notifications
You must be signed in to change notification settings - Fork 971
Fix GRU w8a32 operator (#17226) #17226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -689,11 +689,11 @@ def register_fake( | |||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| lib.define( | ||||||||||||||||||||||||||||||||||||||
| "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor" | ||||||||||||||||||||||||||||||||||||||
| "quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden) -> Tensor" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| lib.define( | ||||||||||||||||||||||||||||||||||||||
| "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)" | ||||||||||||||||||||||||||||||||||||||
| "quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden, *, Tensor(a!) out) -> Tensor(a!)" | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| lib.define( | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -3060,11 +3060,20 @@ def quantized_w8a32_gru_meta( | |||||||||||||||||||||||||||||||||||||
| weights_hidden: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||
| w_h_scale: float, | ||||||||||||||||||||||||||||||||||||||
| bias_inputs: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||
| b_i_scale: float, | ||||||||||||||||||||||||||||||||||||||
| b_scale: float, | ||||||||||||||||||||||||||||||||||||||
| bias_hidden: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||
| b_h_scale: float, | ||||||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| return hidden.new_empty((2, *hidden.shape), dtype=torch.float32) | ||||||||||||||||||||||||||||||||||||||
| seq_len = inputs.shape[1] | ||||||||||||||||||||||||||||||||||||||
| assert seq_len == 1 | ||||||||||||||||||||||||||||||||||||||
| # inputs comes in shape [batch, seq_len, input_size] | ||||||||||||||||||||||||||||||||||||||
| # hidden comes in shape [batch, seq_len, hidden_size] | ||||||||||||||||||||||||||||||||||||||
| # weights_inputs comes in shape [3 * hidden_size, input_size] | ||||||||||||||||||||||||||||||||||||||
| # weights_hidden comes in shape [3 * hidden_size, hidden_size] | ||||||||||||||||||||||||||||||||||||||
| # output comes in empty with shape [2, batch, seq_len, hidden_size] | ||||||||||||||||||||||||||||||||||||||
| # The first dimension stacks the output and the new hidden state | ||||||||||||||||||||||||||||||||||||||
| return hidden.new_empty( | ||||||||||||||||||||||||||||||||||||||
| (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32 | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+3066
to
+3075
|
||||||||||||||||||||||||||||||||||||||
| seq_len = inputs.shape[1] | |
| assert seq_len == 1 | |
| # inputs comes in shape [batch, seq_len, input_size] | |
| # hidden comes in shape [batch, seq_len, hidden_size] | |
| # weights_inputs comes in shape [3 * hidden_size, input_size] | |
| # weights_hidden comes in shape [3 * hidden_size, hidden_size] | |
| # output comes in empty with shape [2, batch, seq_len, hidden_size] | |
| # The first dimension stacks the output and the new hidden state | |
| return hidden.new_empty( | |
| (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32 | |
| # inputs comes in shape [batch, input_size] | |
| # hidden comes in shape [batch, hidden_size] | |
| # weights_inputs comes in shape [3 * hidden_size, input_size] | |
| # weights_hidden comes in shape [3 * hidden_size, hidden_size] | |
| # output comes in empty with shape [2, batch, hidden_size] | |
| # The first dimension stacks the output and the new hidden state | |
| return hidden.new_empty( | |
| (2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv( | |
| torch.ops.aten.permute.default, | ||
| (other_inputs[0], [0, 2, 1]), # NCL -> NLC | ||
| ) | ||
| assert "val" in other_inputs[0].meta, "Missing val metadata on input node" | ||
| original_val = other_inputs[0].meta["val"] | ||
| assert original_val.fake_mode is not None, "fake_mode is None on input node" | ||
| with original_val.fake_mode: | ||
| transposed_inputs.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [0, 2, 1] | ||
| ) | ||
| # Propagate val metadata for transposed_inputs | ||
| if "val" in other_inputs[0].meta: | ||
| original_val = other_inputs[0].meta["val"] | ||
| fake_mode = original_val.fake_mode | ||
| if fake_mode is not None: | ||
| with fake_mode: | ||
| transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1]) | ||
| transposed_inputs.meta["val"] = transposed_val | ||
| else: | ||
| transposed_inputs.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [0, 2, 1] | ||
| ) | ||
| copy_node_metadata(transposed_inputs, other_inputs[0]) | ||
|
|
||
| transposed_weights = graph_module.graph.call_function( | ||
| torch.ops.aten.permute.default, | ||
| (weights_inputs[0], [2, 0, 1]), # NCL -> LNC | ||
| ) | ||
| assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node" | ||
| original_val = weights_inputs[0].meta["val"] | ||
| assert original_val.fake_mode is not None, "fake_mode is None on weight node" | ||
| with original_val.fake_mode: | ||
| transposed_weights.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [2, 0, 1] | ||
| ) | ||
| # Propagate val metadata for transposed_weights | ||
| if "val" in weights_inputs[0].meta: | ||
| original_val = weights_inputs[0].meta["val"] | ||
| fake_mode = original_val.fake_mode | ||
| if fake_mode is not None: | ||
| with fake_mode: | ||
| transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1]) | ||
| transposed_weights.meta["val"] = transposed_val | ||
| else: | ||
| transposed_weights.meta["val"] = torch.ops.aten.permute.default( | ||
| original_val, [2, 0, 1] | ||
| ) | ||
| copy_node_metadata(transposed_weights, weights_inputs[0]) | ||
|
|
||
| args = ( | ||
|
|
@@ -511,12 +521,10 @@ def get_args_and_kwargs_mixed_w8a32_gru( | |
| ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: | ||
| # Stride, padding, dilation, groups not supported yet | ||
|
|
||
| assert len(dequants_weights) == 2 | ||
| assert len(dequants_biases) == 2 | ||
| w_i_scale = dequants_weights[0].args[1] | ||
| w_h_scale = dequants_weights[1].args[1] | ||
| b_i_scale = dequants_biases[0].args[1] | ||
| b_h_scale = dequants_biases[1].args[1] | ||
| b_scale = dequants_biases[0].args[1] | ||
|
Comment on lines
524
to
+527
|
||
|
|
||
| args = ( | ||
| other_inputs[0], | ||
|
|
@@ -526,9 +534,8 @@ def get_args_and_kwargs_mixed_w8a32_gru( | |
| weights_inputs[1], | ||
| w_h_scale, | ||
| bias_inputs[0], | ||
| b_i_scale, | ||
| b_scale, | ||
| bias_inputs[1], | ||
| b_h_scale, | ||
| ) | ||
| kwargs = {} | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -718,7 +718,7 @@ def get_anchors( | |||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| cnn_weights = conv_layer.args[1] | ||||||||||||||||||||
| if hasattr(cnn_weights.meta, "tensor_meta"): | ||||||||||||||||||||
| if "tensor_meta" in cnn_weights.meta: | ||||||||||||||||||||
| cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape | ||||||||||||||||||||
| # Bail if the channels are not multiple of 4 (SIMD) | ||||||||||||||||||||
| if cnn_weights_shape[0] % 4 != 0: | ||||||||||||||||||||
|
|
@@ -744,6 +744,18 @@ def get_anchors( | |||||||||||||||||||
| conv_layer, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inputs = conv_layer.args[0] | ||||||||||||||||||||
| if "tensor_meta" in inputs.meta: | ||||||||||||||||||||
| inputs_shape = inputs.meta["tensor_meta"].shape | ||||||||||||||||||||
| # Bail if length != kernel size - Not yet supported | ||||||||||||||||||||
| if inputs_shape[-1] != cnn_weights_shape[2]: | ||||||||||||||||||||
|
Comment on lines
+750
to
+751
|
||||||||||||||||||||
| # Bail if length != kernel size - Not yet supported | |
| if inputs_shape[-1] != cnn_weights_shape[2]: | |
| # Bail if the input channels do not match the weight's in_channels | |
| # or if the input length is smaller than the kernel size. | |
| if ( | |
| len(inputs_shape) < 3 | |
| or inputs_shape[1] != cnn_weights_shape[1] | |
| or inputs_shape[-1] < cnn_weights_shape[2] | |
| ): |
Copilot
AI
Apr 14, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code assumes the GRU params tuple contains both biases and immediately indexes wrapper.args[2]/[3]. For aten.gru.input, has_biases can be false (e.g. nn.GRU(..., bias=False)), in which case the params list may not contain bias entries and this will raise an IndexError during pattern matching. Consider explicitly checking the has_biases argument (and/or len(wrapper.args) >= 4) and returning empty=True anchors when biases are not present, since the Cadence quantized_w8a32_gru replacement requires bias tensors.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1257,9 +1257,8 @@ def quantized_w8a32_gru( | |||||||||||||||||||||||||||||||||||||||
| weights_hidden: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||
| w_h_scale: float, | ||||||||||||||||||||||||||||||||||||||||
| bias_inputs: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||
| b_i_scale: float, | ||||||||||||||||||||||||||||||||||||||||
| b_scale: float, | ||||||||||||||||||||||||||||||||||||||||
| bias_hidden: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||
| b_h_scale: float, | ||||||||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||
| assert weights_inputs.dtype == torch.int8 | ||||||||||||||||||||||||||||||||||||||||
| assert weights_hidden.dtype == torch.int8 | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1288,10 +1287,8 @@ def quantized_w8a32_gru( | |||||||||||||||||||||||||||||||||||||||
| dequant_weights_inputs = weights_inputs.float() * w_i_scale | ||||||||||||||||||||||||||||||||||||||||
| dequant_weights_hidden = weights_hidden.float() * w_h_scale | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # C++ implementation averages the two bias scales | ||||||||||||||||||||||||||||||||||||||||
| avg_bias_scale = (b_i_scale + b_h_scale) / 2 | ||||||||||||||||||||||||||||||||||||||||
| dequant_bias_inputs = bias_inputs.float() * avg_bias_scale | ||||||||||||||||||||||||||||||||||||||||
| dequant_bias_hidden = bias_hidden.float() * avg_bias_scale | ||||||||||||||||||||||||||||||||||||||||
| dequant_bias_inputs = bias_inputs.float() * b_scale | ||||||||||||||||||||||||||||||||||||||||
| dequant_bias_hidden = bias_hidden.float() * b_scale | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs) | ||||||||||||||||||||||||||||||||||||||||
| gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1310,8 +1307,14 @@ def quantized_w8a32_gru( | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| assert new_hidden.shape == original_hidden_shape | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| new_hidden = new_hidden.view(original_hidden_shape) | ||||||||||||||||||||||||||||||||||||||||
| return torch.stack([new_hidden, new_hidden], dim=0) | ||||||||||||||||||||||||||||||||||||||||
| batch_size = inputs.shape[0] | ||||||||||||||||||||||||||||||||||||||||
| input_dim = inputs.shape[1] | ||||||||||||||||||||||||||||||||||||||||
| hidden_dim = hidden.shape[-1] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| new_hidden_expanded = new_hidden.unsqueeze(1).expand( | ||||||||||||||||||||||||||||||||||||||||
| batch_size, input_dim, hidden_dim | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1310
to
1319
|
||||||||||||||||||||||||||||||||||||||||
| batch_size = inputs.shape[0] | |
| input_dim = inputs.shape[1] | |
| hidden_dim = hidden.shape[-1] | |
| new_hidden_expanded = new_hidden.unsqueeze(1).expand(batch_size, input_dim, hidden_dim) | |
| return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0) | |
| if new_hidden.dim() == 1: | |
| new_hidden_normalized = new_hidden.unsqueeze(0).unsqueeze(0) | |
| elif new_hidden.dim() == 2: | |
| new_hidden_normalized = new_hidden.unsqueeze(1) | |
| elif new_hidden.dim() == 3: | |
| new_hidden_normalized = new_hidden | |
| else: | |
| raise ValueError( | |
| f"Hidden state must be 1D, 2D, or 3D, got shape {tuple(new_hidden.shape)}" | |
| ) | |
| return torch.stack([new_hidden_normalized, new_hidden_normalized], dim=0) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2901,9 +2901,8 @@ def test_softmax_f32_f32(self) -> None: | |
| torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 (3*4 x 4) | ||
| 0.1, # w_h_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 | ||
| 0.1, # b_i_scale | ||
| 0.1, # b_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 | ||
| 0.1, # b_h_scale | ||
| ), | ||
| ( | ||
| "invalid_batch_size_2", | ||
|
|
@@ -2918,9 +2917,8 @@ def test_softmax_f32_f32(self) -> None: | |
| torch.ones((12, 4), dtype=torch.int8), # weights_hidden: 12x4 | ||
| 0.1, # w_h_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 | ||
| 0.1, # b_i_scale | ||
| 0.1, # b_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 | ||
| 0.1, # b_h_scale | ||
| ), | ||
| ( | ||
| "non_zero_biases", | ||
|
|
@@ -2933,11 +2931,10 @@ def test_softmax_f32_f32(self) -> None: | |
| torch.tensor( | ||
| [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 | ||
| ), # bias_inputs: 12 | ||
| 0.1, # b_i_scale | ||
| 0.1, # b_scale | ||
| torch.tensor( | ||
| [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8 | ||
| ), # bias_hidden: 12 | ||
| 0.1, # b_h_scale | ||
| ), | ||
| ( | ||
| "negative_weights", | ||
|
|
@@ -2954,9 +2951,8 @@ def test_softmax_f32_f32(self) -> None: | |
| ), # weights_hidden: 12x4 (alternating pattern) | ||
| 0.1, # w_h_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_inputs: 12 | ||
| 0.1, # b_i_scale | ||
| 0.1, # b_scale | ||
| torch.zeros(12, dtype=torch.int8), # bias_hidden: 12 | ||
| 0.1, # b_h_scale | ||
| ), | ||
| ( | ||
| "hidden_dim_8", | ||
|
|
@@ -2969,9 +2965,8 @@ def test_softmax_f32_f32(self) -> None: | |
| torch.ones((24, 8), dtype=torch.int8), # weights_hidden: 24x8 (3*8 x 8) | ||
| 0.1, # w_h_scale | ||
| torch.zeros(24, dtype=torch.int8), # bias_inputs: 24 | ||
| 0.1, # b_i_scale | ||
| 0.1, # b_scale | ||
| torch.zeros(24, dtype=torch.int8), # bias_hidden: 24 | ||
| 0.1, # b_h_scale | ||
| ), | ||
| ] | ||
| ) | ||
|
|
@@ -2985,9 +2980,8 @@ def test_quantized_w8a32_gru( | |
| weights_hidden: torch.Tensor, | ||
| w_h_scale: float, | ||
| bias_inputs: torch.Tensor, | ||
| b_i_scale: float, | ||
| b_scale: float, | ||
| bias_hidden: torch.Tensor, | ||
| b_h_scale: float, | ||
| ) -> None: | ||
|
|
||
| if name == "invalid_batch_size_2": | ||
|
|
@@ -3000,9 +2994,8 @@ def test_quantized_w8a32_gru( | |
| weights_hidden, | ||
| w_h_scale, | ||
| bias_inputs, | ||
| b_i_scale, | ||
| b_scale, | ||
| bias_hidden, | ||
| b_h_scale, | ||
| ) | ||
| self.assertIn( | ||
| "Leading dimension 0 of hidden state must be 1", str(context.exception) | ||
|
|
@@ -3017,9 +3010,8 @@ def test_quantized_w8a32_gru( | |
| weights_hidden, | ||
| w_h_scale, | ||
| bias_inputs, | ||
| b_i_scale, | ||
| b_scale, | ||
| bias_hidden, | ||
| b_h_scale, | ||
| ) | ||
|
|
||
| # Verify output properties | ||
|
|
@@ -3028,10 +3020,11 @@ def test_quantized_w8a32_gru( | |
| torch.float32, | ||
| f"Output dtype should be float32 in {name}", | ||
| ) | ||
| expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]) | ||
| self.assertEqual( | ||
| output.shape, | ||
| (2, *hidden.shape), | ||
| f"Output shape should match {(2, *hidden.shape)} in {name}", | ||
| expected_shape, | ||
| f"Output shape should match {expected_shape} in {name}", | ||
|
Comment on lines
+3023
to
+3027
|
||
| ) | ||
|
Comment on lines
+3023
to
3028
|
||
| assert isinstance(output, torch.Tensor) | ||
|
|
||
|
|
@@ -3064,7 +3057,6 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None: | |
| bias_inputs, | ||
| 0.1, | ||
| bias_hidden, | ||
| 0.1, | ||
| ) | ||
|
|
||
| self.assertIn( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
quantized_w8a32_gru_metaindexesinputs.shape[1]without first validatinginputsrank, which can raise anIndexError(and produce a confusing failure) if the op is ever invoked with a 1D/2D input. Add an explicit shape/rank check (e.g.,inputs.dim() == 3/hidden.dim() == 3) before readingshape[1], and then assertseq_len == 1with a clear message.