
Optimize aten::min/max.dim with TopK op #2780

Open
danielhumanmod wants to merge 6 commits into microsoft:main from danielhumanmod:optimize-max-dim

Conversation

@danielhumanmod

Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., ONNX Runtime CUDA EP microsoft/onnxruntime#11348).

In addition, since torch.min(dim=...) follows a similar pattern to max, I also apply this optimization to it.
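
As a quick sanity check of the intended equivalence (a standalone numpy illustration, not the exporter code), TopK with k=1 along the reduced dim yields the same values and indices as the ReduceMax/ArgMax pair; the min case is symmetric with largest=0:

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)

# Current export: two separate reductions, i.e. two scans over the input.
values_ref = np.max(x, axis=1, keepdims=True)
indices_ref = np.expand_dims(np.argmax(x, axis=1), axis=1)

# numpy stand-in for TopK(k=1, largest=1) along the same axis: a stable
# descending sort picks the first occurrence of the maximum, matching
# ArgMax's default tie-breaking.
order = np.argsort(-x, axis=1, kind="stable")
indices_topk = order[:, :1, :]
values_topk = np.take_along_axis(x, indices_topk, axis=1)

assert np.array_equal(values_ref, values_topk)
assert np.array_equal(indices_ref, indices_topk)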

Verification

Successfully passed existing OpInfo consistency tests:

  • pytest tests/function_libs/torch_lib/ops_test.py
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py

@danielhumanmod
Author

@danielhumanmod please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@codecov

codecov bot commented Jan 25, 2026

Codecov Report

❌ Patch coverage is 96.24060% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.72%. Comparing base (e06dd92) to head (301635e).
⚠️ Report is 7 commits behind head on main.

Files with missing lines | Patch % | Lines
.../rewriter/rules/common/_fuse_reduce_arg_to_topk.py | 91.30% | 6 Missing and 2 partials ⚠️
...iter/rules/common/_fuse_reduce_arg_to_topk_test.py | 98.85% | 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.46%   70.72%   +0.25%     
==========================================
  Files         228      230       +2     
  Lines       27258    27443     +185     
  Branches     2761     2757       -4     
==========================================
+ Hits        19208    19409     +201     
+ Misses       7100     7096       -4     
+ Partials      950      938      -12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Collaborator

@justinchuby justinchuby left a comment


Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

@github-project-automation github-project-automation bot moved this from Todo to In Progress in ONNX Script Review Board Jan 25, 2026
@danielhumanmod
Author

Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

Thanks so much for the review! That is a great point, I took some time to dig into the ONNX Runtime implementations to see how they handle this.

  1. From the ONNX Runtime perspective:

    1. The CPU EP provides a fast path when k = 1, which performs a simple linear scan. So on CPU, it seems to behave identically to a fused max+argmax.
    2. The CUDA EP walks through the full bitonic/radix sort process, which can involve more complex instructions, but the upside is that these operations happen primarily in shared memory.
  2. PyTorch Inductor (as a reference): it adopts a similar approach, splitting into reduce_max/arg_max in IR but leaving it to the runtime (Scheduler) to fuse them. However, when I checked ONNX Runtime, it didn't seem to have an optimization rule to automatically fuse ReduceMax and ArgMax, which implies the split approach effectively incurs one more IO pass compared to TopK.

So to the best of my knowledge, TopK might bring more instruction overhead but with less IO. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am happy to pivot to other tasks if we want to keep the original implementation.

@justinchuby
Collaborator

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

@danielhumanmod
Author

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!

@danielhumanmod
Author

Hey @justinchuby, I’ve added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I’d appreciate your thoughts on it. Thanks!

Contributor

Copilot AI left a comment


Pull request overview

Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.

Changes:

  • Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
  • Adds extensive unit tests covering success and failure conditions across opset 13 and 18.
  • Validates numerical equivalence and serialized-model correctness for rewritten graphs.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File | Description
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | Implements the Reduce+Arg → TopK fusion rules for both max and min cases.
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | Adds unit tests for the new fusion rules, including opset and attribute/input variants.


# Check that the error message is the expected one
tracer_match = tracer.best_matches_map[rule][0]
self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)

Copilot AI Feb 6, 2026


tracer_match.status is a MatchStatus enum, but the test compares tracer_match.status.value (an int) to MatchStatus.CONDITION_FAILED (an enum). This will fail. Compare tracer_match.status directly to MatchStatus.CONDITION_FAILED, or compare .value to MatchStatus.CONDITION_FAILED.value.

Suggested change
self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED)
self.assertEqual(tracer_match.status, MatchStatus.CONDITION_FAILED)
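
For reference, a minimal standalone illustration of the comparison issue (using a stand-in enum; the member value is hypothetical):

from enum import Enum


class MatchStatus(Enum):  # stand-in for the real MatchStatus; value is hypothetical
    CONDITION_FAILED = 2


status = MatchStatus.CONDITION_FAILED
print(status.value == MatchStatus.CONDITION_FAILED)  # False: int compared to enum member
print(status == MatchStatus.CONDITION_FAILED)        # True: enum member compared to enum member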

Comment on lines 118 to 129
# Opset 18+: axes is the second input
axes_input = reduce_node.inputs[1]
axes_const_value = axes_input.const_value
if axes_const_value is None:
    return check_result.fail(
        f"{self.reduce_op_type} axes input is not a constant."
    )
try:
    axes_array = axes_const_value.numpy()
    axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)]
except Exception:
    return check_result.fail(f"Cannot parse {self.reduce_op_type} axes input.")

Copilot AI Feb 6, 2026


For opset 18+, this checks axes_input.const_value directly. In this codebase, values produced by Constant nodes may have const_value=None unless constant-propagation/folding has been run, which would cause the rule to miss valid fusions (e.g., ReduceMax axes coming from a Constant op). Consider using ir.convenience.get_const_tensor(axes_input) / onnxscript.rewriter._ir_utils.get_numpy_value() to reliably extract constant axes, and then flatten to a 1-D list.
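
A rough sketch of that suggestion, assuming the get_numpy_value helper mentioned above returns a numpy array for constant values and None otherwise; the helper name _constant_axes is illustrative only:

from __future__ import annotations

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import _ir_utils


def _constant_axes(axes_input: ir.Value) -> list[int] | None:
    """Return the Reduce axes as a flat list of ints, or None when not constant."""
    axes_array = _ir_utils.get_numpy_value(axes_input)
    if axes_array is None:
        return None
    return np.atleast_1d(axes_array).astype(int).tolist()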

from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class FuseReduceArgToTopKBase(RewriteRuleClassBase):
Contributor


Suggested change
class FuseReduceArgToTopKBase(RewriteRuleClassBase):
class _FuseReduceArgToTopKBase(RewriteRuleClassBase):

arg_node = arg_idx.producer()

if reduce_node is None or arg_node is None:
    return check_result.fail("Cannot find producer nodes.")
Contributor


Could this happen? I thought these two originate from the pattern?

Author


We can remove this; I was a little too cautious here.

reduce_keepdims = (
    reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
)
arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1
Contributor


Formatting


# ONNX default: keepdims = 1 for both Reduce and Arg operations
reduce_keepdims = (
    reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
Contributor


@justinchuby is there a cleaner way to get the default? (just curious)

Collaborator


I think there should be a get_int method


@property
@abstractmethod
def arg_op_type(self) -> str:
Contributor


I wonder why we don't just use node.op_type to get this property in the check? I think the Node class is pretty well-defined already.

Author


Thanks for the feedback, yeah that would be a better and cleaner option.

)

# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
Contributor


I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity? (We have the version converter.) I guess it's just a matter of whether we expect the rule to be applied standalone or not.

Author


That makes sense to remove; I see this rule being used mostly in the pipeline. Thanks for the suggestion!

Contributor


@justinchuby What do you think?

Collaborator


Only opset 18+ is fine

input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None

def normalize_axis(axis: int, rank: int | None) -> int:
Contributor


Moving this out and making it a private function would be preferable.
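
For illustration, a minimal sketch of the suggested module-level helper (the name _normalize_axis is hypothetical):

def _normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly negative axis index to its non-negative equivalent."""
    return axis + rank if axis < 0 else axis


assert _normalize_axis(-1, 3) == 2
assert _normalize_axis(1, 3) == 1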


# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
Contributor


I wonder if symbolic shapes could work for this case? @justinchuby

Collaborator


Can you elaborate?

@titaiwangms
Contributor

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

@justinchuby
Collaborator

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

I don’t think we want to enable this by default. It is unclear if this is generally more performant


Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

[ONNX] Use topk to export max(dim,keepdim) to onnx

3 participants