Fix quantized_w8a32_gru output shape to preserve batch dimensions#17291
Fix quantized_w8a32_gru output shape to preserve batch dimensions#17291ethansfng wants to merge 1 commit intopytorch:mainfrom
Conversation
Summary: The meta function for `quantized_w8a32_gru` returned shape `(2, hidden_dim)`, dropping the batch dimensions of the hidden state. After the fusion pass replaces `aten.gru.input` with the custom op, `getitem[0]` on the output produced shape `(hidden_dim,)` instead of the original `(batch, ..., hidden_dim)`, causing a shape mismatch with the reference output and NaN in the comparison. Fix the meta function to return `(2, *hidden.shape)` to preserve all dimensions. Also update the reference implementation to handle 3D hidden states and preserve the original shape in the output, consistent with the shape fix applied to `quantized_w8a32_conv` in D84745967. Differential Revision: D92570517
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17291
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Unrelated FailureAs of commit 09daf81 with merge base 7823792 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@ethansfng has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92570517. |
This PR needs a
|
Summary:
The meta function for
quantized_w8a32_grureturned shape(2, hidden_dim),dropping the batch dimensions of the hidden state. After the fusion pass replaces
aten.gru.inputwith the custom op,getitem[0]on the output produced shape(hidden_dim,)instead of the original(batch, ..., hidden_dim), causing ashape mismatch with the reference output and NaN in the comparison.
Fix the meta function to return
(2, *hidden.shape)to preserve all dimensions.Also update the reference implementation to handle 3D hidden states and preserve
the original shape in the output, consistent with the shape fix applied to
quantized_w8a32_convin D84745967.Differential Revision: D92570517