Optimize aten::min/max.dim with TopK op #2780
danielhumanmod wants to merge 6 commits into microsoft:main
Conversation
@microsoft-github-policy-service agree
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## main #2780 +/- ##
==========================================
+ Coverage 70.46% 70.72% +0.25%
==========================================
Files 228 230 +2
Lines 27258 27443 +185
Branches 2761 2757 -4
==========================================
+ Hits 19208 19409 +201
+ Misses 7100 7096 -4
+ Partials 950 938 -12
☔ View full report in Codecov by Sentry.
Thanks so much for the review! That is a great point; I took some time to dig into the ONNX Runtime implementations to see how they handle this.
So, to the best of my knowledge, TopK might bring more instruction overhead but with less IO. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is not used at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules? This way we can do the fusion only when both outputs are used (if not, the second output will be removed by the dead-code-elimination pass).
Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
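For concreteness, here is a hypothetical module (not from the PR) where only the first output of torch.max(dim=...) is consumed; in this case the index-computing branch of the exported graph is dead code that the optimizer can simply drop, so fusing into TopK would compute indices for nothing:

```python
import torch


class OnlyValues(torch.nn.Module):
    def forward(self, x):
        # Only the values output is consumed; the indices output is unused.
        values, _ = torch.max(x, dim=1)
        return values
```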
Hey @justinchuby, I’ve added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I’d appreciate your thoughts on it. Thanks!
Pull request overview
Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.
Changes:
- Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
- Adds extensive unit tests covering success and failure conditions across opsets 13 and 18.
- Validates numerical equivalence and serialized-model correctness for rewritten graphs.
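As an illustration of why the fusion is semantically safe (a minimal sketch, not part of the PR), a max-reduction plus argmax along an axis yields the same values and indices as a top-1 selection:

```python
import numpy as np

x = np.random.rand(2, 5).astype(np.float32)
axis = 1

# Pre-fusion pattern: ReduceMax + ArgMax over the same axis.
reduce_max = np.max(x, axis=axis, keepdims=True)
arg_max = np.argmax(x, axis=axis, keepdims=True)

# Post-fusion pattern: a single top-1 selection (TopK with k=1, largest=1).
top1_idx = np.argsort(-x, axis=axis)[:, :1]
top1_val = np.take_along_axis(x, top1_idx, axis=axis)

assert np.array_equal(reduce_max, top1_val)
assert np.array_equal(arg_max, top1_idx)
```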
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | Implements the Reduce+Arg → TopK fusion rules for both max and min cases. |
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | Adds unit tests for the new fusion rules, including opset and attribute/input variants. |
    # Check that the error message is the expected one
    tracer_match = tracer.best_matches_map[rule][0]
    self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)
tracer_match.status is a MatchStatus enum, but the test compares tracer_match.status.value (an int) to MatchStatus.CONDITION_FAILED (an enum). This will fail. Compare tracer_match.status directly to MatchStatus.CONDITION_FAILED, or compare .value to MatchStatus.CONDITION_FAILED.value.
Suggested change:

    - self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)
    + self.assertEqual(tracer_match.status, MatchStatus.CONDITION_FAILED)
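As a side illustration of the pitfall (assuming MatchStatus is a plain enum.Enum rather than an IntEnum, which is an assumption here), an enum member never compares equal to its raw value:

```python
import enum


class MatchStatus(enum.Enum):  # illustrative stand-in; the real members may differ
    CONDITION_FAILED = 2


# A plain Enum member never compares equal to its underlying value.
assert MatchStatus.CONDITION_FAILED.value != MatchStatus.CONDITION_FAILED
assert MatchStatus.CONDITION_FAILED == MatchStatus.CONDITION_FAILED
```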
    # Opset 18+: axes is the second input
    axes_input = reduce_node.inputs[1]
    axes_const_value = axes_input.const_value
    if axes_const_value is None:
        return check_result.fail(
            f"{self.reduce_op_type} axes input is not a constant."
        )
    try:
        axes_array = axes_const_value.numpy()
        axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)]
    except Exception:
        return check_result.fail(f"Cannot parse {self.reduce_op_type} axes input.")
For opset 18+, this checks axes_input.const_value directly. In this codebase, values produced by Constant nodes may have const_value=None unless constant-propagation/folding has been run, which would cause the rule to miss valid fusions (e.g., ReduceMax axes coming from a Constant op). Consider using ir.convenience.get_const_tensor(axes_input) / onnxscript.rewriter._ir_utils.get_numpy_value() to reliably extract constant axes, and then flatten to a 1-D list.
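A rough sketch of the suggested approach; the helper name is taken from the review comment above, and its exact signature in the current codebase is an assumption:

```python
from __future__ import annotations

from onnxscript import ir


def _get_const_axes(axes_input: ir.Value) -> list[int] | None:
    """Return the Reduce axes as a flat list of ints, or None if not constant.

    Assumes ir.convenience.get_const_tensor (named in the review comment)
    resolves both initializers and Constant-node outputs.
    """
    axes_tensor = ir.convenience.get_const_tensor(axes_input)
    if axes_tensor is None:
        return None
    return [int(a) for a in axes_tensor.numpy().reshape(-1)]
```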
    from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


    class FuseReduceArgToTopKBase(RewriteRuleClassBase):
Suggested change:

    - class FuseReduceArgToTopKBase(RewriteRuleClassBase):
    + class _FuseReduceArgToTopKBase(RewriteRuleClassBase):
    arg_node = arg_idx.producer()

    if reduce_node is None or arg_node is None:
        return check_result.fail("Cannot find producer nodes.")
Could this happen? I thought these two originate from the pattern.

We can remove this; I was being a little too cautious here.
    # ONNX default: keepdims = 1 for both Reduce and Arg operations
    reduce_keepdims = (
        reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
    )
    arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1
@justinchuby is there a cleaner way to get the default? (just curious)

I think there should be a get_int method.
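If the attribute container does provide such a helper (an assumption here, following the comment above), the defaults could be folded into the lookup:

```python
# Hypothetical one-liners, assuming Attributes.get_int(name, default) exists:
reduce_keepdims = reduce_node.attributes.get_int("keepdims", 1)
arg_keepdims = arg_node.attributes.get_int("keepdims", 1)
```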
    @property
    @abstractmethod
    def arg_op_type(self) -> str:
I wonder why we don't just use node.op_type to get this in the check? I think the Node class is pretty well-defined already.

Thanks for the feedback, yes, that would be a better and cleaner option.
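A minimal sketch of that simplification, reusing variable names from the snippets above (illustrative only, not the PR's implementation):

```python
# Derive the reduction direction from the matched node itself instead of
# an abstract per-subclass property.
is_max = arg_node.op_type == "ArgMax"
# ONNX TopK: largest=1 selects maxima, largest=0 selects minima.
largest = 1 if is_max else 0
```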
    )

    # Step 3: Get axes from Reduce operation
    # In opset 18+, axes is an input; in opset 13-17, it's an attribute
I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity (we have the version converter)? I guess it is just a matter of whether we expect the rule to be applied standalone or not.

That makes sense; I expect this rule to be used mostly in the pipeline. Thanks for the suggestion!

Only opset 18+ is fine.
    input_x = reduce_node.inputs[0]
    rank = len(input_x.shape) if input_x.shape is not None else None


    def normalize_axis(axis: int, rank: int | None) -> int:
Moving this out and making it a private function would be preferable.
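For example, a module-level private helper along these lines (a minimal sketch):

```python
def _normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly negative axis to its non-negative equivalent."""
    return axis + rank if axis < 0 else axis
```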
    # Step 7: Normalize axes if rank is known (handle negative indices)
    input_x = reduce_node.inputs[0]
    rank = len(input_x.shape) if input_x.shape is not None else None
I wonder if symbolic shape inference could work on this case? @justinchuby
You will have to enable it here:

I don’t think we want to enable this by default. It is unclear if this is generally more performant.
Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK, replacing the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP, microsoft/onnxruntime#11348). In addition, since torch.min(dim=...) follows the same pattern as max, I apply this optimization to it as well.
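For illustration, a minimal example (not from the PR) of the pattern the optimization targets, where both outputs are consumed:

```python
import torch

x = torch.randn(4, 8)
# Both outputs are used, which is the case the TopK fusion benefits.
max_values, max_indices = torch.max(x, dim=1)
min_values, min_indices = torch.min(x, dim=1)
```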
Verification

Successfully passed existing OpInfo consistency tests: