Conversation
Summary:

#### Summary

This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor within the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`.

Reviewed By: mcremon-meta

Differential Revision: D89863750
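A minimal sketch of what that metadata propagation might look like in a torch.fx-style fusion pass. The names `other_inputs`, `transposed_inputs`, and `fake_mode` come from the summary above; the helper name and the `(0, 2, 1)` permute order are illustrative assumptions, not the PR's exact code.

```python
from torch.fx import Node
from torch._subclasses.fake_tensor import FakeTensorMode


def propagate_transposed_val(other_input: Node, transposed_input: Node,
                             fake_mode: FakeTensorMode) -> None:
    """Copy the fake-tensor `val` metadata from `other_input` to `transposed_input`,
    applying the same permute the fusion pass applies to the real tensor."""
    original_val = other_input.meta.get("val")
    if original_val is None or fake_mode is None:
        return
    with fake_mode:
        # Assumed Conv1d layout transpose (e.g. NLC -> NCL); mirrors the
        # graph-level permute so shape/stride metadata stays consistent.
        transposed_val = original_val.permute(0, 2, 1)
    transposed_input.meta["val"] = transposed_val
```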
Summary:

# Context

This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.

# Mitigation

The reference implementation now has the correct output dimension, and pattern matching now uses a safer check for the operator parameters.

Reviewed By: hsharma35

Differential Revision: D90437262
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17226

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Awaiting Approval, 4 New Failures

As of commit 6f6ff6a with merge base 267a59d:

AWAITING APPROVAL - The following workflows need approval before CI can run:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This pull request fixes the GRU w8a32 operator by correcting the output shape in both the reference implementation and meta kernel, and enhancing pattern matching with safer parameter checks.
Changes:
- Fixed GRU w8a32 operator output shape from `(2, hidden_dim)` to `(2, batch, input_dim, hidden_dim)` to properly reflect the expected dimensions
- Enhanced pattern matching safety by using the `.get()` method instead of direct dictionary access for tensor metadata (see the sketch after this list)
- Added `SharedQuantizationSpec` for GRU biases to ensure consistent quantization scales
- Added metadata propagation for transposed tensors in fusion pass
- Added input shape validation for conv operator
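As a rough illustration of the safer metadata access and the input shape check described above, a pattern match can bail out instead of raising when the `val` metadata is missing. The function names, the metadata key layout, and the 3-D shape expectation are illustrative assumptions, not the PR's exact code.

```python
from typing import Optional

import torch
from torch.fx import Node


def get_input_shape(node: Node) -> Optional[torch.Size]:
    """Return the shape recorded in the node's `val` metadata, or None if absent."""
    # .get() avoids a KeyError when the exporter did not attach `val` metadata.
    val = node.meta.get("val")
    if not isinstance(val, torch.Tensor):
        return None
    return val.shape


def conv_input_is_valid(conv_input: Node) -> bool:
    """Skip the pattern rather than crash when metadata or rank is unexpected."""
    shape = get_input_shape(conv_input)
    # Assumed check: a Conv1d activation is expected to be 3-D [batch, channels, length].
    return shape is not None and len(shape) == 3
```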
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| backends/cadence/aot/tests/test_ref_implementations.py | Updated test expectations to match corrected output shape |
| backends/cadence/aot/ref_implementations.py | Fixed output shape calculation by expanding hidden state instead of flattening |
| backends/cadence/aot/quantizer/patterns.py | Added safer metadata access, input validation, and shared bias quantization spec |
| backends/cadence/aot/quantizer/fusion_pass.py | Added val metadata propagation for transposed inputs and weights |
| backends/cadence/aot/ops_registrations.py | Updated meta kernel to return correct output shape with improved documentation |
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```
The assertion `assert seq_len == 1` will fail with the 2D inputs used in the test cases. Looking at `test_ref_implementations.py`, the tests use 2D inputs like `torch.tensor([[1.0, 2.0]], dtype=torch.float32)` with shape `[1, 2]`, where `inputs.shape[1]` is 2, not 1. Either the assertion logic needs to check if `inputs` is 3D before asserting `seq_len == 1`, or the comments describing the expected shapes are incorrect and should indicate 2D inputs `[batch, input_size]` instead of 3D `[batch, seq_len, input_size]`.
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
-# inputs comes in shape [batch, seq_len, input_size]
-# hidden comes in shape [batch, seq_len, hidden_size]
-# weights_inputs comes in shape [3 * hidden_size, input_size]
-# weights_hidden comes in shape [3 * hidden_size, hidden_size]
-# output comes in empty with shape [2, batch, seq_len, hidden_size]
-# The first dimension stacks the output and the new hidden state
-return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
+# inputs may come in shape [batch, seq_len, input_size] or [batch, input_size] (for seq_len == 1)
+if inputs.dim() == 3:
+    batch = inputs.shape[0]
+    seq_len = inputs.shape[1]
+    assert seq_len == 1
+elif inputs.dim() == 2:
+    batch = inputs.shape[0]
+    seq_len = 1
+else:
+    raise AssertionError(f"Unsupported inputs.dim() for quantized_w8a32_gru_meta: {inputs.dim()}")
+# hidden is expected to have hidden_size in its last dimension
+hidden_size = hidden.shape[-1]
+# weights_inputs comes in shape [3 * hidden_size, input_size]
+# weights_hidden comes in shape [3 * hidden_size, hidden_size]
+# output comes in empty with shape [2, batch, seq_len, hidden_size]
+# The first dimension stacks the output and the new hidden state
+return hidden.new_empty((2, batch, seq_len, hidden_size), dtype=torch.float32)
```
```python
seq_len = inputs.shape[1]
assert seq_len == 1
# inputs comes in shape [batch, seq_len, input_size]
# hidden comes in shape [batch, seq_len, hidden_size]
# weights_inputs comes in shape [3 * hidden_size, input_size]
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
# output comes in empty with shape [2, batch, seq_len, hidden_size]
# The first dimension stacks the output and the new hidden state
return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
```
The comments describe `inputs` as 3D `[batch, seq_len, input_size]` and `hidden` as 3D `[batch, seq_len, hidden_size]`, but the test cases in `test_ref_implementations.py` use 2D tensors `[batch, input_size]` and `[batch, hidden_size]`. The comments should be updated to reflect the actual expected shapes, or the implementation should be modified to match the documented shapes.
Suggested change:

```diff
-seq_len = inputs.shape[1]
-assert seq_len == 1
-# inputs comes in shape [batch, seq_len, input_size]
-# hidden comes in shape [batch, seq_len, hidden_size]
-# weights_inputs comes in shape [3 * hidden_size, input_size]
-# weights_hidden comes in shape [3 * hidden_size, hidden_size]
-# output comes in empty with shape [2, batch, seq_len, hidden_size]
-# The first dimension stacks the output and the new hidden state
-return hidden.new_empty((2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32)
+# inputs comes in shape [batch, input_size]
+# hidden comes in shape [batch, hidden_size]
+# weights_inputs comes in shape [3 * hidden_size, input_size]
+# weights_hidden comes in shape [3 * hidden_size, hidden_size]
+# output comes in empty with shape [2, batch, hidden_size]
+# The first dimension stacks the output and the new hidden state
+assert len(inputs.shape) == 2
+assert len(hidden.shape) == 2
+assert inputs.shape[0] == hidden.shape[0]
+return hidden.new_empty((2, inputs.shape[0], hidden.shape[-1]), dtype=torch.float32)
```
Summary:
# Context

This diff fixes the reference implementation of the w8a32 GRU operator and enhances the operator's pattern matching.

# Mitigation

The reference implementation now has the correct output dimension, and pattern matching now uses a safer check for the operator parameters.
Reviewed By: hsharma35
Differential Revision: D90437262