Refactor rnn ops to op builders #4606

Draft
pfultz2 wants to merge 8 commits into develop from rnn-op-builder

pfultz2 (Collaborator) commented Feb 12, 2026

Motivation

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
    migraphx::make_op(
        "gru",

pfultz2 (Collaborator, Author) commented:

This was a mistake: the test is for lstm, but we were using gru.

Copilot AI left a comment

Pull request overview

This PR migrates the RNN-family ops (RNN/LSTM/GRU) from dedicated operators plus the rewrite_rnn pass to OpBuilder-based graph construction. It updates ONNX parsing and the verification tests accordingly and removes the legacy ops and pass.

Changes:

  • Add OpBuilder implementations for rnn, lstm, and gru (plus shared utilities) that directly expand into lower-level ops.
  • Update ONNX parsers and many verify tests to use migraphx::op::builder::add(...) results instead of make_op("rnn"/"lstm"/"gru") + rnn_last_* ops.
  • Remove rewrite_rnn from target pass pipelines and remove legacy RNN-family operators/headers/registrations; add new OpBuilder-focused tests.

Reviewed changes

Copilot reviewed 96 out of 97 changed files in this pull request and generated 4 comments.

File Description
test/verify/test_var_sl_gru_forward.cpp Switch GRU construction to OpBuilder and adjust return ordering to match builder outputs.
test/verify/test_var_sl_gru_bidirct.cpp Switch bidirectional GRU construction to OpBuilder.
test/verify/test_rnn_sql_2.cpp Switch RNN construction to OpBuilder; remove unused serialize include.
test/verify/test_rnn_sql_1_layout.cpp Switch RNN construction to OpBuilder and adapt layout handling to builder outputs.
test/verify/test_rnn_sql_1.cpp Switch RNN construction to OpBuilder; remove unused serialize include.
test/verify/test_rnn_reverse_layout.cpp Switch reverse RNN construction to OpBuilder and reuse builder output for layout ops.
test/verify/test_rnn_reverse2.cpp Switch reverse RNN construction to OpBuilder (no explicit return).
test/verify/test_rnn_reverse.cpp Switch reverse RNN construction to OpBuilder (no explicit return).
test/verify/test_rnn_forward_layout.cpp Switch forward RNN construction to OpBuilder and adapt layout handling.
test/verify/test_rnn_forward10.cpp Switch forward RNN construction to OpBuilder.
test/verify/test_rnn_forward.cpp Switch forward RNN construction to OpBuilder.
test/verify/test_rnn_bidirectional_layout.cpp Switch bidirectional RNN construction to OpBuilder and adapt last-output usage.
test/verify/test_rnn_bidirectional10.cpp Switch bidirectional RNN construction to OpBuilder (no explicit return).
test/verify/test_rnn_bidirectional.cpp Switch bidirectional RNN construction to OpBuilder (no explicit return).
test/verify/test_rnn_bi_3args_layout.cpp Switch 3-arg bidirectional RNN construction to OpBuilder and adapt last-output usage.
test/verify/test_rnn_bi_3args.cpp Switch 3-arg bidirectional RNN construction to OpBuilder.
test/verify/test_rnn_5args.cpp Switch 5-arg RNN construction to OpBuilder.
test/verify/test_rnn_4args_layout.cpp Switch 4-arg RNN construction to OpBuilder and adapt layout handling.
test/verify/test_rnn_4args.cpp Switch 4-arg RNN construction to OpBuilder.
test/verify/test_rnn_3args.cpp Switch 3-arg RNN construction to OpBuilder.
test/verify/test_lstm_two_outputs.cpp Switch LSTM construction to OpBuilder and return builder outputs.
test/verify/test_lstm_three_outputs_layout.cpp Switch LSTM construction to OpBuilder and adapt layout handling for 3 outputs.
test/verify/test_lstm_three_outputs.cpp Switch LSTM construction to OpBuilder and return 3 builder outputs.
test/verify/test_lstm_reverse_last.cpp Switch reverse LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_reverse_3args_layout.cpp Switch reverse LSTM construction to OpBuilder and adapt layout handling.
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp Switch reverse LSTM construction to OpBuilder and adapt cell-output usage.
test/verify/test_lstm_reverse_3args_cell_output.cpp Switch reverse LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_reverse_3args.cpp Switch reverse LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_forward_seq1.cpp Switch forward LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_forward_last_layout.cpp Switch forward LSTM construction to OpBuilder and adapt layout handling for last output.
test/verify/test_lstm_forward_last.cpp Switch forward LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_forward_hs_layout.cpp Switch forward LSTM construction to OpBuilder and adapt layout handling for hidden states.
test/verify/test_lstm_forward_hs.cpp Switch forward LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_forward_default_actv1.cpp Switch forward LSTM construction to OpBuilder with explicit default activation list.
test/verify/test_lstm_forward_default_actv.cpp Switch forward LSTM construction to OpBuilder with empty activation list.
test/verify/test_lstm_forward_3args_und.cpp Switch forward LSTM construction to OpBuilder with undefined placeholders.
test/verify/test_lstm_forward_3args.cpp Switch forward LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_bidirct_seq1.cpp Switch bidirectional LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_bidirct_last_layout.cpp Switch bidirectional LSTM construction to OpBuilder and adapt layout handling for last output.
test/verify/test_lstm_bidirct_last.cpp Switch bidirectional LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_bidirct_hs.cpp Switch bidirectional LSTM construction to OpBuilder (no explicit return).
test/verify/test_lstm_bidirct_default_actv2.cpp Switch bidirectional LSTM construction to OpBuilder with 2 default activations.
test/verify/test_lstm_bidirct_default_actv1.cpp Switch bidirectional LSTM construction to OpBuilder with 1 default activation.
test/verify/test_lstm_bidirct_default_actv.cpp Switch bidirectional LSTM construction to OpBuilder with empty activation list.
test/verify/test_lstm_bidirct_3args_und.cpp Switch GRU (file name indicates LSTM) to OpBuilder with undefined placeholders.
test/verify/test_lstm_bidirct_3args_layout.cpp Switch bidirectional LSTM construction to OpBuilder and adapt layout handling.
test/verify/test_lstm_bidirct_3args.cpp Switch bidirectional LSTM construction to OpBuilder (no explicit return).
test/verify/test_gru_two_outputs.cpp Switch GRU construction to OpBuilder and return builder outputs.
test/verify/test_gru_reverse_last_layout.cpp Switch reverse GRU construction to OpBuilder and adapt layout handling for last output.
test/verify/test_gru_reverse_last.cpp Switch reverse GRU construction to OpBuilder and return last output directly.
test/verify/test_gru_reverse_3args_layout.cpp Switch reverse GRU construction to OpBuilder and adapt layout handling.
test/verify/test_gru_reverse_3args.cpp Switch reverse GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_forward_seq1.cpp Switch forward GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_forward_layout.cpp Switch forward GRU construction to OpBuilder and adapt layout handling.
test/verify/test_gru_forward_default_actv1.cpp Switch forward GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_forward_default_actv.cpp Switch forward GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_forward_3args_und.cpp Switch forward GRU construction to OpBuilder with undefined placeholders.
test/verify/test_gru_forward_3args_layout.cpp Switch forward GRU construction to OpBuilder and adapt layout handling.
test/verify/test_gru_forward_3args.cpp Switch forward GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_forward.cpp Switch forward GRU construction to OpBuilder and preserve return ordering.
test/verify/test_gru_bidirct_seq1.cpp Switch bidirectional GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_bidirct_layout.cpp Switch bidirectional GRU construction to OpBuilder and adapt layout handling.
test/verify/test_gru_bidirct_default_actv1.cpp Switch bidirectional GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_bidirct_default_actv.cpp Switch bidirectional GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_bidirct_3args_und.cpp Switch bidirectional GRU construction to OpBuilder with undefined placeholders.
test/verify/test_gru_bidirct_3args_layout.cpp Switch bidirectional GRU construction to OpBuilder and adapt layout handling.
test/verify/test_gru_bidirct_3args.cpp Switch bidirectional GRU construction to OpBuilder and return hidden states.
test/verify/test_gru_bidirct.cpp Switch bidirectional GRU construction to OpBuilder and return both outputs.
test/operators.cpp Remove legacy rnn operator test case (operator no longer present).
test/op/builder/rnn_builder_test.cpp Add OpBuilder tests for RNN output arity, shapes, and basic eval.
test/op/builder/lstm_builder_test.cpp Add OpBuilder tests for LSTM output arity, shapes, and basic eval.
test/op/builder/gru_builder_test.cpp Add OpBuilder tests for GRU output arity, shapes, and basic eval.
src/targets/ref/target.cpp Remove rewrite_rnn pass from ref target pipeline.
src/targets/gpu/target.cpp Remove rewrite_rnn pass from GPU target pipeline.
src/targets/fpga/target.cpp Remove rewrite_rnn pass from FPGA target pipeline.
src/targets/cpu/target.cpp Remove rewrite_rnn pass from CPU target pipeline.
src/quantization.cpp Remove rewrite_rnn from quantization pre-pass pipeline.
src/op/builder/rnn.cpp Add OpBuilder expansion for vanilla RNN into lower-level ops + var-seq-len handling.
src/op/builder/lstm.cpp Add OpBuilder expansion for LSTM into lower-level ops + var-seq-len handling.
src/op/builder/gru.cpp Add OpBuilder expansion for GRU into lower-level ops + var-seq-len handling.
src/op/builder/include/migraphx/op/builder/rnn_utils.hpp Add shared utilities for seq-len handling, padding, and var-seq-len shifting.
src/onnx/parse_rnn.cpp Parse ONNX RNN via OpBuilder and consume returned outputs.
src/onnx/parse_lstm.cpp Parse ONNX LSTM via OpBuilder and consume returned outputs.
src/onnx/parse_gru.cpp Parse ONNX GRU via OpBuilder and consume returned outputs.
src/include/migraphx/rewrite_rnn.hpp Remove legacy rewrite pass API header.
src/include/migraphx/operators.hpp Stop including legacy rnn/gru/lstm + last-output operator headers.
src/include/migraphx/op/rnn_last_hs_output.hpp Remove legacy rnn_last_hs_output operator.
src/include/migraphx/op/rnn_last_cell_output.hpp Remove legacy rnn_last_cell_output operator.
src/include/migraphx/op/rnn.hpp Remove legacy rnn operator.
src/include/migraphx/op/lstm.hpp Remove legacy lstm operator.
src/include/migraphx/op/gru.hpp Remove legacy gru operator.
src/driver/passes.cpp Remove rewrite_rnn from driver pass lookup.
src/CMakeLists.txt Stop building/registering legacy rnn/gru/lstm operators and rewrite_rnn.

Comment on lines 293 to 311
struct lstm_builder : op_builder<lstm_builder>
{
    static std::vector<std::string> names() { return {"lstm"}; }

    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{};
    op::rnn_direction direction = op::rnn_direction::forward;
    float clip = 0.0f;
    int input_forget = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.hidden_size, "hidden_size"),
                    f(self.actv_funcs, "actv_func"),
                    f(self.direction, "direction"),
                    f(self.clip, "clip"),
                    f(self.input_forget, "input_forget"));
    }

Copilot AI Feb 12, 2026

hidden_size and clip are reflected on lstm_builder, but hidden_size is not validated/used (the implementation derives hs from R), and clip is also unused. This can mask invalid inputs (attribute/weight mismatch) and makes the public builder API misleading. Consider validating these attributes against the provided tensors and either applying clip/hidden_size semantics or throwing when they are inconsistent/non-default.
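
A minimal sketch of the kind of check suggested above, assuming the ONNX layout for the recurrence weights R, `[num_directions, gates * hidden_size, hidden_size]`, with 4 gates for LSTM, 3 for GRU, and 1 for a vanilla RNN. The helper `check_hidden_size` is hypothetical and works on plain dimension vectors; it is not part of MIGraphX or of this PR.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not a MIGraphX API): compare the hidden_size
// attribute with the dimensions of the recurrence weight tensor R.
// Assumed layout (ONNX): [num_directions, gates * hidden_size, hidden_size].
inline void check_hidden_size(std::size_t hidden_size,
                              const std::vector<std::size_t>& r_dims,
                              std::size_t gates)
{
    if(r_dims.size() != 3)
        throw std::invalid_argument("R must be a rank-3 tensor");

    // The last dimension of R is hidden_size; the middle one stacks the gates.
    if(r_dims[2] != hidden_size || r_dims[1] != gates * hidden_size)
    {
        throw std::invalid_argument("hidden_size attribute (" + std::to_string(hidden_size) +
                                    ") does not match R dims [" + std::to_string(r_dims[0]) +
                                    ", " + std::to_string(r_dims[1]) + ", " +
                                    std::to_string(r_dims[2]) + "]");
    }
}
```

With this sketch, `check_hidden_size(4, {1, 16, 4}, 4)` passes for an LSTM, while leaving `hidden_size` at its default of 1 for the same R throws instead of being silently ignored.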

Comment on lines 297 to 311
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{};
    op::rnn_direction direction = op::rnn_direction::forward;
    float clip = 0.0f;
    int input_forget = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.hidden_size, "hidden_size"),
                    f(self.actv_funcs, "actv_func"),
                    f(self.direction, "direction"),
                    f(self.clip, "clip"),
                    f(self.input_forget, "input_forget"));
    }

Copilot AI Feb 12, 2026

input_forget is parsed and reflected on lstm_builder, but it is never used in the LSTM computation. As a result, ONNX models that set input_forget=1 will behave the same as 0. Consider either implementing the input-forget coupling in the cell logic, or explicitly rejecting/ignoring non-zero values with a clear error to avoid silent misbehavior.
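
To illustrate what the first option could look like, here is a scalar sketch of one cell-state update. It assumes the common "coupled input-forget gate" reading of `input_forget=1`, where the forget gate is replaced by `1 - i`; whether that matches the intended semantics here is an assumption, and `next_cell_state` is a hypothetical name rather than code from this PR.

```cpp
#include <stdexcept>

// Hypothetical scalar view of one LSTM cell-state update.
// i, f, g are the already-activated input gate, forget gate and
// candidate value; c_prev is the previous cell state.
// Assumption: input_forget == 1 couples the gates as f = 1 - i.
inline float next_cell_state(float c_prev, float i, float f, float g, int input_forget)
{
    if(input_forget != 0 && input_forget != 1)
        throw std::invalid_argument("input_forget must be 0 or 1");

    // With coupling enabled the separately computed forget gate is ignored.
    const float forget = (input_forget == 1) ? 1.0f - i : f;
    return forget * c_prev + i * g;
}
```

The second option from the comment is simpler still: keep the computation unchanged and throw whenever `input_forget` is non-zero, so a model that relies on coupling fails loudly instead of producing different results.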

Comment on lines 152 to 168
struct rnn_builder : op_builder<rnn_builder>
{
    static std::vector<std::string> names() { return {"rnn"}; }

    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{};
    op::rnn_direction direction = op::rnn_direction::forward;
    float clip = 0.0f;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.hidden_size, "hidden_size"),
                    f(self.actv_funcs, "actv_func"),
                    f(self.direction, "direction"),
                    f(self.clip, "clip"));
    }

Copilot AI Feb 12, 2026

hidden_size (and clip) are accepted as attributes on rnn_builder, but the implementation never uses them or validates them against the provided weight shapes. This can hide malformed inputs (e.g., ONNX attribute/weight mismatch) that previously would have been caught. Consider validating hidden_size and direction against W/R shapes and either using hidden_size for shape decisions or throwing on mismatch.
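
The direction half of that check can be sketched as follows, assuming the ONNX convention that the leading dimension of both W and R is `num_directions` (2 for bidirectional, 1 otherwise). The `direction` enum below is a stand-in for `op::rnn_direction`, and `check_direction` is a hypothetical helper, not an existing MIGraphX function.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Stand-in for op::rnn_direction, defined locally so the sketch is self-contained.
enum class direction
{
    forward,
    reverse,
    bidirectional
};

// Number of weight slices the direction attribute implies.
inline std::size_t expected_num_directions(direction d)
{
    return d == direction::bidirectional ? 2 : 1;
}

// Hypothetical helper: the leading dimension of W and R must agree
// with the direction attribute.
inline void check_direction(direction d,
                            const std::vector<std::size_t>& w_dims,
                            const std::vector<std::size_t>& r_dims)
{
    if(w_dims.empty() || r_dims.empty())
        throw std::invalid_argument("W and R must have at least one dimension");

    const std::size_t n = expected_num_directions(d);
    if(w_dims[0] != n || r_dims[0] != n)
        throw std::invalid_argument("direction attribute does not match W/R leading dimension");
}
```

A forward RNN given weights with a leading dimension of 2 would then be rejected up front rather than producing a graph with inconsistent shapes.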

Comment on lines 235 to 253
struct gru_builder : op_builder<gru_builder>
{
    static std::vector<std::string> names() { return {"gru"}; }

    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{};
    op::rnn_direction direction = op::rnn_direction::forward;
    float clip = 0.0f;
    int linear_before_reset = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.hidden_size, "hidden_size"),
                    f(self.actv_funcs, "actv_func"),
                    f(self.direction, "direction"),
                    f(self.clip, "clip"),
                    f(self.linear_before_reset, "linear_before_reset"));
    }

Copilot AI Feb 12, 2026

hidden_size (and clip) are reflected on gru_builder, but the builder derives hs from the input shapes and does not validate/consume the hidden_size attribute. This makes invalid graphs harder to diagnose (attribute/weight mismatch) and can lead to silent shape inconsistencies. Consider validating hidden_size against R/W and erroring on mismatch (or removing the unused attribute).
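
For the `clip` attribute, applying it rather than dropping it might look like the sketch below. It assumes the ONNX description, where clipping bounds the activation inputs to `[-clip, +clip]`, and it further assumes that the default `clip = 0.0f` means "no clipping". `apply_clip` is a hypothetical helper operating on a plain vector of pre-activation values.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper: bound pre-activation values to [-clip, +clip].
// Assumption: clip <= 0 (including the default 0.0f) disables clipping.
inline std::vector<float> apply_clip(std::vector<float> x, float clip)
{
    if(clip > 0.0f)
    {
        for(auto& v : x)
            v = std::max(-clip, std::min(v, clip));
    }
    return x;
}
```

If the builder is not meant to support clipping, the alternative is the same as for the other unused attributes: remove it or throw on a non-default value.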

codecov bot commented Feb 12, 2026

Codecov Report

❌ Patch coverage is 99.71751% with 2 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/op/builder/lstm.cpp | 99.25% | 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4606      +/-   ##
===========================================
- Coverage    92.21%   92.19%   -0.02%     
===========================================
  Files          567      564       -3     
  Lines        27897    27807      -90     
===========================================
- Hits         25724    25636      -88     
+ Misses        2173     2171       -2     
Files with missing lines | Coverage Δ
src/include/migraphx/serialize.hpp | 92.45% <100.00%> (+0.15%) ⬆️
src/onnx/parse_gru.cpp | 95.45% <100.00%> (-0.26%) ⬇️
src/onnx/parse_lstm.cpp | 94.74% <100.00%> (-0.20%) ⬇️
src/onnx/parse_rnn.cpp | 95.24% <100.00%> (-0.15%) ⬇️
src/op/builder/gru.cpp | 100.00% <100.00%> (ø)
.../builder/include/migraphx/op/builder/rnn_utils.hpp | 100.00% <100.00%> (ø)
src/op/builder/rnn.cpp | 100.00% <100.00%> (ø)
src/quantization.cpp | 85.33% <100.00%> (ø)
src/targets/ref/target.cpp | 100.00% <ø> (ø)
src/op/builder/lstm.cpp | 99.25% <99.25%> (ø)