    auto und = mm->add_instruction(migraphx::make_op("undefined"));
    mm->add_instruction(
        migraphx::make_op(
            "gru",
This was a mistake: the test is for lstm, but we were using gru.
Pull request overview
This PR migrates RNN-family ops (RNN/LSTM/GRU) from dedicated operators + rewrite_rnn pass to OpBuilder-based graph construction, updating ONNX parsing and verification tests accordingly and removing the legacy ops/pass.
Changes:
- Add OpBuilder implementations for `rnn`, `lstm`, and `gru` (plus shared utilities) that directly expand into lower-level ops.
- Update ONNX parsers and many verify tests to use `migraphx::op::builder::add(...)` results instead of `make_op("rnn"/"lstm"/"gru")` + `rnn_last_*` ops.
- Remove `rewrite_rnn` from target pass pipelines and remove legacy RNN-family operators/headers/registrations; add new OpBuilder-focused tests.
Reviewed changes
Copilot reviewed 96 out of 97 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/verify/test_var_sl_gru_forward.cpp | Switch GRU construction to OpBuilder and adjust return ordering to match builder outputs. |
| test/verify/test_var_sl_gru_bidirct.cpp | Switch bidirectional GRU construction to OpBuilder. |
| test/verify/test_rnn_sql_2.cpp | Switch RNN construction to OpBuilder; remove unused serialize include. |
| test/verify/test_rnn_sql_1_layout.cpp | Switch RNN construction to OpBuilder and adapt layout handling to builder outputs. |
| test/verify/test_rnn_sql_1.cpp | Switch RNN construction to OpBuilder; remove unused serialize include. |
| test/verify/test_rnn_reverse_layout.cpp | Switch reverse RNN construction to OpBuilder and reuse builder output for layout ops. |
| test/verify/test_rnn_reverse2.cpp | Switch reverse RNN construction to OpBuilder (no explicit return). |
| test/verify/test_rnn_reverse.cpp | Switch reverse RNN construction to OpBuilder (no explicit return). |
| test/verify/test_rnn_forward_layout.cpp | Switch forward RNN construction to OpBuilder and adapt layout handling. |
| test/verify/test_rnn_forward10.cpp | Switch forward RNN construction to OpBuilder. |
| test/verify/test_rnn_forward.cpp | Switch forward RNN construction to OpBuilder. |
| test/verify/test_rnn_bidirectional_layout.cpp | Switch bidirectional RNN construction to OpBuilder and adapt last-output usage. |
| test/verify/test_rnn_bidirectional10.cpp | Switch bidirectional RNN construction to OpBuilder (no explicit return). |
| test/verify/test_rnn_bidirectional.cpp | Switch bidirectional RNN construction to OpBuilder (no explicit return). |
| test/verify/test_rnn_bi_3args_layout.cpp | Switch 3-arg bidirectional RNN construction to OpBuilder and adapt last-output usage. |
| test/verify/test_rnn_bi_3args.cpp | Switch 3-arg bidirectional RNN construction to OpBuilder. |
| test/verify/test_rnn_5args.cpp | Switch 5-arg RNN construction to OpBuilder. |
| test/verify/test_rnn_4args_layout.cpp | Switch 4-arg RNN construction to OpBuilder and adapt layout handling. |
| test/verify/test_rnn_4args.cpp | Switch 4-arg RNN construction to OpBuilder. |
| test/verify/test_rnn_3args.cpp | Switch 3-arg RNN construction to OpBuilder. |
| test/verify/test_lstm_two_outputs.cpp | Switch LSTM construction to OpBuilder and return builder outputs. |
| test/verify/test_lstm_three_outputs_layout.cpp | Switch LSTM construction to OpBuilder and adapt layout handling for 3 outputs. |
| test/verify/test_lstm_three_outputs.cpp | Switch LSTM construction to OpBuilder and return 3 builder outputs. |
| test/verify/test_lstm_reverse_last.cpp | Switch reverse LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_reverse_3args_layout.cpp | Switch reverse LSTM construction to OpBuilder and adapt layout handling. |
| test/verify/test_lstm_reverse_3args_cell_output_layout.cpp | Switch reverse LSTM construction to OpBuilder and adapt cell-output usage. |
| test/verify/test_lstm_reverse_3args_cell_output.cpp | Switch reverse LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_reverse_3args.cpp | Switch reverse LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_forward_seq1.cpp | Switch forward LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_forward_last_layout.cpp | Switch forward LSTM construction to OpBuilder and adapt layout handling for last output. |
| test/verify/test_lstm_forward_last.cpp | Switch forward LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_forward_hs_layout.cpp | Switch forward LSTM construction to OpBuilder and adapt layout handling for hidden states. |
| test/verify/test_lstm_forward_hs.cpp | Switch forward LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_forward_default_actv1.cpp | Switch forward LSTM construction to OpBuilder with explicit default activation list. |
| test/verify/test_lstm_forward_default_actv.cpp | Switch forward LSTM construction to OpBuilder with empty activation list. |
| test/verify/test_lstm_forward_3args_und.cpp | Switch forward LSTM construction to OpBuilder with undefined placeholders. |
| test/verify/test_lstm_forward_3args.cpp | Switch forward LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_bidirct_seq1.cpp | Switch bidirectional LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_bidirct_last_layout.cpp | Switch bidirectional LSTM construction to OpBuilder and adapt layout handling for last output. |
| test/verify/test_lstm_bidirct_last.cpp | Switch bidirectional LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_bidirct_hs.cpp | Switch bidirectional LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_lstm_bidirct_default_actv2.cpp | Switch bidirectional LSTM construction to OpBuilder with 2 default activations. |
| test/verify/test_lstm_bidirct_default_actv1.cpp | Switch bidirectional LSTM construction to OpBuilder with 1 default activation. |
| test/verify/test_lstm_bidirct_default_actv.cpp | Switch bidirectional LSTM construction to OpBuilder with empty activation list. |
| test/verify/test_lstm_bidirct_3args_und.cpp | Switch GRU (file name indicates LSTM) to OpBuilder with undefined placeholders. |
| test/verify/test_lstm_bidirct_3args_layout.cpp | Switch bidirectional LSTM construction to OpBuilder and adapt layout handling. |
| test/verify/test_lstm_bidirct_3args.cpp | Switch bidirectional LSTM construction to OpBuilder (no explicit return). |
| test/verify/test_gru_two_outputs.cpp | Switch GRU construction to OpBuilder and return builder outputs. |
| test/verify/test_gru_reverse_last_layout.cpp | Switch reverse GRU construction to OpBuilder and adapt layout handling for last output. |
| test/verify/test_gru_reverse_last.cpp | Switch reverse GRU construction to OpBuilder and return last output directly. |
| test/verify/test_gru_reverse_3args_layout.cpp | Switch reverse GRU construction to OpBuilder and adapt layout handling. |
| test/verify/test_gru_reverse_3args.cpp | Switch reverse GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_forward_seq1.cpp | Switch forward GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_forward_layout.cpp | Switch forward GRU construction to OpBuilder and adapt layout handling. |
| test/verify/test_gru_forward_default_actv1.cpp | Switch forward GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_forward_default_actv.cpp | Switch forward GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_forward_3args_und.cpp | Switch forward GRU construction to OpBuilder with undefined placeholders. |
| test/verify/test_gru_forward_3args_layout.cpp | Switch forward GRU construction to OpBuilder and adapt layout handling. |
| test/verify/test_gru_forward_3args.cpp | Switch forward GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_forward.cpp | Switch forward GRU construction to OpBuilder and preserve return ordering. |
| test/verify/test_gru_bidirct_seq1.cpp | Switch bidirectional GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_bidirct_layout.cpp | Switch bidirectional GRU construction to OpBuilder and adapt layout handling. |
| test/verify/test_gru_bidirct_default_actv1.cpp | Switch bidirectional GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_bidirct_default_actv.cpp | Switch bidirectional GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_bidirct_3args_und.cpp | Switch bidirectional GRU construction to OpBuilder with undefined placeholders. |
| test/verify/test_gru_bidirct_3args_layout.cpp | Switch bidirectional GRU construction to OpBuilder and adapt layout handling. |
| test/verify/test_gru_bidirct_3args.cpp | Switch bidirectional GRU construction to OpBuilder and return hidden states. |
| test/verify/test_gru_bidirct.cpp | Switch bidirectional GRU construction to OpBuilder and return both outputs. |
| test/operators.cpp | Remove legacy rnn operator test case (operator no longer present). |
| test/op/builder/rnn_builder_test.cpp | Add OpBuilder tests for RNN output arity, shapes, and basic eval. |
| test/op/builder/lstm_builder_test.cpp | Add OpBuilder tests for LSTM output arity, shapes, and basic eval. |
| test/op/builder/gru_builder_test.cpp | Add OpBuilder tests for GRU output arity, shapes, and basic eval. |
| src/targets/ref/target.cpp | Remove rewrite_rnn pass from ref target pipeline. |
| src/targets/gpu/target.cpp | Remove rewrite_rnn pass from GPU target pipeline. |
| src/targets/fpga/target.cpp | Remove rewrite_rnn pass from FPGA target pipeline. |
| src/targets/cpu/target.cpp | Remove rewrite_rnn pass from CPU target pipeline. |
| src/quantization.cpp | Remove rewrite_rnn from quantization pre-pass pipeline. |
| src/op/builder/rnn.cpp | Add OpBuilder expansion for vanilla RNN into lower-level ops + var-seq-len handling. |
| src/op/builder/lstm.cpp | Add OpBuilder expansion for LSTM into lower-level ops + var-seq-len handling. |
| src/op/builder/gru.cpp | Add OpBuilder expansion for GRU into lower-level ops + var-seq-len handling. |
| src/op/builder/include/migraphx/op/builder/rnn_utils.hpp | Add shared utilities for seq-len handling, padding, and var-seq-len shifting. |
| src/onnx/parse_rnn.cpp | Parse ONNX RNN via OpBuilder and consume returned outputs. |
| src/onnx/parse_lstm.cpp | Parse ONNX LSTM via OpBuilder and consume returned outputs. |
| src/onnx/parse_gru.cpp | Parse ONNX GRU via OpBuilder and consume returned outputs. |
| src/include/migraphx/rewrite_rnn.hpp | Remove legacy rewrite pass API header. |
| src/include/migraphx/operators.hpp | Stop including legacy rnn/gru/lstm + last-output operator headers. |
| src/include/migraphx/op/rnn_last_hs_output.hpp | Remove legacy rnn_last_hs_output operator. |
| src/include/migraphx/op/rnn_last_cell_output.hpp | Remove legacy rnn_last_cell_output operator. |
| src/include/migraphx/op/rnn.hpp | Remove legacy rnn operator. |
| src/include/migraphx/op/lstm.hpp | Remove legacy lstm operator. |
| src/include/migraphx/op/gru.hpp | Remove legacy gru operator. |
| src/driver/passes.cpp | Remove rewrite_rnn from driver pass lookup. |
| src/CMakeLists.txt | Stop building/registering legacy rnn/gru/lstm operators and rewrite_rnn. |
    struct lstm_builder : op_builder<lstm_builder>
    {
        static std::vector<std::string> names() { return {"lstm"}; }

        std::size_t hidden_size = 1;
        std::vector<operation> actv_funcs{};
        op::rnn_direction direction = op::rnn_direction::forward;
        float clip = 0.0f;
        int input_forget = 0;

        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.hidden_size, "hidden_size"),
                        f(self.actv_funcs, "actv_func"),
                        f(self.direction, "direction"),
                        f(self.clip, "clip"),
                        f(self.input_forget, "input_forget"));
        }
hidden_size and clip are reflected on lstm_builder, but hidden_size is not validated/used (the implementation derives hs from R), and clip is also unused. This can mask invalid inputs (attribute/weight mismatch) and makes the public builder API misleading. Consider validating these attributes against the provided tensors and either applying clip/hidden_size semantics or throwing when they are inconsistent/non-default.
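The validation suggested in this comment could look roughly like the sketch below. It is not code from the PR: `check_hidden_size` is a hypothetical helper, tensor dimensions are passed as a plain `std::vector<std::size_t>` rather than a MIGraphX shape, and it assumes the ONNX layout for the recurrence weights, `R = [num_directions, num_gates * hidden_size, hidden_size]`, with four gates for LSTM.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper, not part of MIGraphX: compares a user-supplied
// hidden_size attribute with the hidden size implied by the R tensor.
// Assumes R = [num_directions, num_gates * hidden_size, hidden_size].
inline std::size_t check_hidden_size(const std::vector<std::size_t>& r_dims,
                                     std::size_t hidden_size_attr,
                                     std::size_t num_gates)
{
    if(r_dims.size() != 3)
        throw std::invalid_argument("R must be a rank-3 tensor");

    // The last dimension of R is the hidden size the weights were built for.
    const std::size_t hs = r_dims[2];

    if(r_dims[1] != num_gates * hs)
        throw std::invalid_argument("R dim 1 must equal num_gates * hidden_size");

    if(hidden_size_attr != hs)
        throw std::invalid_argument("hidden_size attribute (" +
                                    std::to_string(hidden_size_attr) +
                                    ") does not match R (" + std::to_string(hs) + ")");
    return hs;
}
```

For an LSTM the call would be `check_hidden_size(r_dims, hidden_size, 4)`. Because the builder defaults `hidden_size` to 1, a real check would probably also need to distinguish "attribute not set" from "attribute set to 1"; this sketch does not attempt that.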
src/op/builder/lstm.cpp
Outdated
        std::size_t hidden_size = 1;
        std::vector<operation> actv_funcs{};
        op::rnn_direction direction = op::rnn_direction::forward;
        float clip = 0.0f;
        int input_forget = 0;

        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.hidden_size, "hidden_size"),
                        f(self.actv_funcs, "actv_func"),
                        f(self.direction, "direction"),
                        f(self.clip, "clip"),
                        f(self.input_forget, "input_forget"));
        }
input_forget is parsed and reflected on lstm_builder, but it is never used in the LSTM computation. As a result, ONNX models that set input_forget=1 will behave the same as 0. Consider either implementing the input-forget coupling in the cell logic, or explicitly rejecting/ignoring non-zero values with a clear error to avoid silent misbehavior.
    struct rnn_builder : op_builder<rnn_builder>
    {
        static std::vector<std::string> names() { return {"rnn"}; }

        std::size_t hidden_size = 1;
        std::vector<operation> actv_funcs{};
        op::rnn_direction direction = op::rnn_direction::forward;
        float clip = 0.0f;

        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.hidden_size, "hidden_size"),
                        f(self.actv_funcs, "actv_func"),
                        f(self.direction, "direction"),
                        f(self.clip, "clip"));
        }
hidden_size (and clip) are accepted as attributes on rnn_builder, but the implementation never uses them or validates them against the provided weight shapes. This can hide malformed inputs (e.g., ONNX attribute/weight mismatch) that previously would have been caught. Consider validating hidden_size and direction against W/R shapes and either using hidden_size for shape decisions or throwing on mismatch.
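A check of `direction` and `hidden_size` against both weight tensors might be sketched as follows. The names `rnn_dir` and `check_rnn_weights` are hypothetical stand-ins (the PR uses `op::rnn_direction`), and the shapes assume the ONNX layouts `W = [num_directions, hidden_size, input_size]` and `R = [num_directions, hidden_size, hidden_size]` for a vanilla RNN.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for op::rnn_direction.
enum class rnn_dir { forward, reverse, bidirectional };

// Hypothetical helper, not part of MIGraphX: throws if the direction or
// hidden_size attributes disagree with the W and R dimensions.
// Assumes W = [num_directions, hidden_size, input_size]
//     and R = [num_directions, hidden_size, hidden_size].
inline void check_rnn_weights(const std::vector<std::size_t>& w_dims,
                              const std::vector<std::size_t>& r_dims,
                              rnn_dir dir,
                              std::size_t hidden_size)
{
    if(w_dims.size() != 3 || r_dims.size() != 3)
        throw std::invalid_argument("W and R must be rank-3 tensors");

    // A bidirectional RNN carries one set of weights per direction.
    const std::size_t num_dirs = (dir == rnn_dir::bidirectional) ? 2 : 1;
    if(w_dims[0] != num_dirs || r_dims[0] != num_dirs)
        throw std::invalid_argument("direction does not match num_directions of W/R");

    if(w_dims[1] != hidden_size || r_dims[1] != hidden_size || r_dims[2] != hidden_size)
        throw std::invalid_argument("hidden_size does not match W/R dimensions");
}
```

Throwing early keeps the error close to the malformed attribute instead of surfacing later as a shape mismatch inside the expanded graph.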
    struct gru_builder : op_builder<gru_builder>
    {
        static std::vector<std::string> names() { return {"gru"}; }

        std::size_t hidden_size = 1;
        std::vector<operation> actv_funcs{};
        op::rnn_direction direction = op::rnn_direction::forward;
        float clip = 0.0f;
        int linear_before_reset = 0;

        template <class Self, class F>
        static auto reflect(Self& self, F f)
        {
            return pack(f(self.hidden_size, "hidden_size"),
                        f(self.actv_funcs, "actv_func"),
                        f(self.direction, "direction"),
                        f(self.clip, "clip"),
                        f(self.linear_before_reset, "linear_before_reset"));
        }
hidden_size (and clip) are reflected on gru_builder, but the builder derives hs from the input shapes and does not validate/consume the hidden_size attribute. This makes invalid graphs harder to diagnose (attribute/weight mismatch) and can lead to silent shape inconsistencies. Consider validating hidden_size against R/W and erroring on mismatch (or removing the unused attribute).
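For the unused `clip` attribute mentioned in this comment, one possible reading is sketched below. `apply_clip` is a hypothetical scalar helper, not code from the PR. It assumes ONNX-style semantics in which values fed to the activations are bounded to `[-clip, +clip]`, and it treats the builder default of `0.0f` as "no clipping", which is an assumption about intent rather than something the PR states.

```cpp
#include <algorithm>

// Hypothetical helper: bounds a pre-activation value to [-clip, +clip].
// Assumption: clip <= 0 (including the builder default 0.0f) means "disabled".
inline float apply_clip(float x, float clip)
{
    if(clip <= 0.0f)
        return x;
    return std::max(-clip, std::min(x, clip));
}
```

If the builder does not intend to support clipping at all, rejecting a non-zero `clip` with an error would be the simpler alternative.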
Codecov Report
❌ Patch coverage is
Additional details and impacted files

Coverage Diff

| | develop | #4606 | +/- |
|---|---|---|---|
| Coverage | 92.21% | 92.19% | -0.02% |
| Files | 567 | 564 | -3 |
| Lines | 27897 | 27807 | -90 |
| Hits | 25724 | 25636 | -88 |
| Misses | 2173 | 2171 | -2 |
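The coverage percentages in this report are consistent with hits divided by hits plus misses. The sketch below reproduces that arithmetic; `coverage_percent` is a hypothetical helper, and the formula is an assumption that holds here because hits and misses sum to the reported line counts.

```cpp
#include <cmath>

// Hypothetical helper: line coverage as a percentage, rounded to two decimals.
// Assumes coverage = hits / (hits + misses), with no partial lines.
inline double coverage_percent(long hits, long misses)
{
    const double pct = 100.0 * static_cast<double>(hits) / static_cast<double>(hits + misses);
    return std::round(pct * 100.0) / 100.0;
}
```

Calling it with the `develop` counts (25724 hits, 2173 misses) and the PR counts (25636 hits, 2171 misses) gives 92.21 and 92.19, matching the reported change of -0.02%.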
Motivation
Technical Details
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`