Optimize aten::min/max.dim with TopK op #2780
danielhumanmod wants to merge 12 commits into microsoft:main from
Conversation
@microsoft-github-policy-service agree
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## main #2780 +/- ##
==========================================
+ Coverage 70.52% 72.24% +1.72%
==========================================
Files 228 241 +13
Lines 27135 29541 +2406
Branches 2727 2898 +171
==========================================
+ Hits 19137 21343 +2206
- Misses 7065 7219 +154
- Partials 933 979 +46
☔ View full report in Codecov by Sentry.
Thanks so much for the review! That is a great point; I took some time to dig into the ONNX Runtime implementation to see how it handles this.
To the best of my knowledge, TopK may bring more instruction overhead but less IO. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we do the fusion only when both outputs are used (if not, the second output will be removed by the dead-code-elimination pass).
Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
Hey @justinchuby, I’ve added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I’d appreciate your thoughts on it. Thanks!
Pull request overview
Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.
Changes:
- Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
- Adds extensive unit tests covering success and failure conditions across opsets 13 and 18.
- Validates numerical equivalence and serialized-model correctness for rewritten graphs.
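The intent of the fusion can be sketched in NumPy (hypothetical data; a stable `argsort` stands in for TopK's top-1 selection — this mirrors the pattern being rewritten, not the rule's actual code):

```python
import numpy as np

x = np.array([[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]])
axis = 1

# Original pattern: two separate passes over the input.
values = np.max(x, axis=axis, keepdims=True)      # ReduceMax
indices = np.argmax(x, axis=axis, keepdims=True)  # ArgMax

# Fused pattern: one TopK(k=1) yields both outputs in a single pass.
# A stable sort on the negated input picks the first occurrence of the
# maximum, matching ArgMax's tie-breaking behavior.
topk_indices = np.argsort(-x, axis=axis, kind="stable")[:, :1]
topk_values = np.take_along_axis(x, topk_indices, axis=axis)

assert np.array_equal(values, topk_values)
assert np.array_equal(indices, topk_indices)
```

The min case is symmetric: ReduceMin + ArgMin maps to TopK with `largest=0`.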
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | Implements the Reduce+Arg → TopK fusion rules for both max and min cases. |
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | Adds unit tests for the new fusion rules, including opset and attribute/input variants. |
# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
I wonder if we would be interested in supporting only opset 18+ here to reduce complexity? (We have a version converter.) I guess it's just a matter of whether we expect the rule to be applied standalone or not.
That makes sense to remove; I see this rule should mostly be used in the pipeline. Thanks for the suggestion!
Only opset 18+ is fine
Sorry @danielhumanmod, can you add a NOTE/comment somewhere that says the rule is only for opset 18+? Since it's not in the default rewrite rules now, it could be used standalone by other users.
# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
I wonder if symbolic shape could work on this case? @justinchuby
Skipping when the shape is None means this does not support dynamic shapes at the moment. But symbolic inference should be able to handle it.
Ohh, actually that is a very good catch. I will use shape.rank() instead to support both static and symbolic shapes. Thanks a bunch!
You will have to enable it here:
I don’t think we want to enable this by default. It is unclear if this is generally more performant. @danielhumanmod you may simply expose the rule in https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/rules/common/__init__.py
Hey team, I resolved all the pending comments. I'd appreciate it if you could take another look when you have time, thanks! cc @justinchuby @titaiwangms
titaiwangms left a comment
Thank you. An unblocking comment.
justinchuby left a comment
Unblocking, but I would need to review it more closely next week. Thanks!
Hey @justinchuby, a gentle ping on this PR. I'd appreciate it if you could take another look when you have time, thanks!
justinchuby left a comment
Thanks for the work on this PR — the base/subclass architecture mirrors _FuseBatchNormBase cleanly, the test suite is thorough with good failure-reason assertions, and the decision to expose-but-not-default is architecturally sound. Here are the remaining issues that need attention before merge:
🔴 Blocking
1. Rules not exported from common/__init__.py
@justinchuby explicitly asked to "expose the rule in onnxscript/rewriter/rules/common/__init__.py" — but the PR diff adds only two files and makes no change to any __init__.py. The rules are unreachable without importing from the private _fuse_reduce_arg_to_topk module directly, which violates the project convention that _* modules are internal.
Please add to common/__init__.py:
from onnxscript.rewriter.rules.common._fuse_reduce_arg_to_topk import (
    fuse_reduce_max_argmax_to_topk_rule,
    fuse_reduce_min_argmin_to_topk_rule,
)

(See the naming note below on the fuse_ prefix.)
2. Missing dtype guard — bfloat16 inputs produce invalid graphs
ReduceMax/ReduceMin support bfloat16 in opset 18, but TopK does not. With no dtype check, the rule can fire on a bfloat16 model and emit an invalid TopK node that fails ONNX checker validation — a silent correctness regression.
Please add to check() after obtaining the input:
_TOPK_SUPPORTED_DTYPES = frozenset({
    ir.DataType.FLOAT16, ir.DataType.FLOAT, ir.DataType.DOUBLE,
    ir.DataType.INT8, ir.DataType.INT16, ir.DataType.INT32, ir.DataType.INT64,
    ir.DataType.UINT8, ir.DataType.UINT16, ir.DataType.UINT32, ir.DataType.UINT64,
})
if x.dtype is not None and x.dtype not in _TOPK_SUPPORTED_DTYPES:
    return check_result.fail(f"Input dtype {x.dtype} is not supported by TopK")

🟠 High
3. Open maintainer thread — opset 18+ NOTE still unresolved
@titaiwangms has an open (unresolved) review thread requesting a visible NOTE that these rules are opset 18+ only, since they can be used standalone. The module docstring mentions it in the Constraints: block, but users who encounter FuseReduceMaxArgMaxToTopK or FuseReduceMinArgMinToTopK in isolation won't see it. Please add a one-liner to each public class docstring, e.g.:
# NOTE: Requires opset 18+. Apply a version converter before using on older models.
This will also allow @titaiwangms to resolve that thread.
4. No explicit opset guard — only coincidental protection
The opset 18+ constraint is enforced incidentally (pre-18 ReduceMax puts axes in an attribute rather than an input, so the pattern simply fails to match). A manually constructed pre-18 model with axes-as-input could bypass this. Consider an explicit check, or at minimum document the coincidental guard in a comment.
🟡 Medium
5. except Exception too broad
The except Exception block in check() masks real bugs (MemoryError, unexpected AttributeError, etc.). Please narrow to except (ValueError, TypeError).
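The difference can be illustrated with a toy helper (unrelated to the actual check() internals):

```python
def parse_int(value):
    return int(value)

def checked_parse(value):
    # Narrowed handler, as requested: only the anticipated failure modes
    # (bad string, wrong type) are converted into a soft failure.
    try:
        return parse_int(value)
    except (ValueError, TypeError):
        return None

assert checked_parse("42") == 42
assert checked_parse("not-a-number") is None  # ValueError swallowed
assert checked_parse(None) is None            # TypeError swallowed
# A bare `except Exception:` would additionally hide genuine bugs such as
# an AttributeError raised from a broken code path.
```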
6. _normalize_axis defined but not used in rewrite()
rewrite() contains inline axis normalization logic that duplicates the _normalize_axis() static method on the same class. This is fragile — if either copy is updated, they silently diverge. Please have rewrite() call self._normalize_axis().
7. Rule naming breaks fuse_*_rule convention
Every rule in common/__init__.py uses a verb-first name: fuse_batchnorm_into_conv_rule, cast_cast_rule, collapse_slice_rule. The new rules are noun-first (reduce_max_argmax_to_topk_rule), which breaks autocomplete and grep discoverability for users searching fuse_*. Please rename to fuse_reduce_max_argmax_to_topk_rule and fuse_reduce_min_argmin_to_topk_rule.
🔵 Minor
- `select_last_index_attr` → `select_last_index`: the `_attr` suffix implies an `ir.Attr` object, but it holds an `int`. It should match the style of adjacent variables (`reduce_keepdims`, `arg_axis`, etc.).
- Step-numbered comments: `# Step 1` … `# Step 7` (and `# Step 2b`) describe what the next line does, not why. The `_fuse_batchnorm.py` reference has zero step comments. Keep only the comments that explain non-obvious reasons (e.g., `# ONNX default: axis=0 for ArgMax/ArgMin`).
- `del context` should be `del context  # Unused` to match `_fuse_batchnorm.py` style.
- The `rng` property creates a new RNG on every access: properties are expected to return a consistent object. A module-level constant `_RNG = np.random.default_rng(20260127)` avoids the subtle behavior difference between `self.rng.shuffle(a); self.rng.shuffle(b)` and `rng = self.rng; rng.shuffle(a); rng.shuffle(b)`.
- `sorted=1` with `k=1` is redundant: a single-element result is trivially sorted, so omitting `sorted` (or using `sorted=False`) is cleaner.
- `TestFuseReduceMaxArgMaxToTopK` is missing a class docstring: the Min test class has one; the Max class doesn't.
- Linting: several RUFF/format and EDITORCONFIG issues remain open from the automated scan. Please run `lintrunner -a` and commit the result.
- Missing test coverage (8 lines per Codecov): the pre-opset-18 fallback path and the scalar-axes edge case are not tested.
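The `rng`-property pitfall above can be demonstrated concretely (seed reused from the review's suggestion; class name is illustrative):

```python
import numpy as np

_RNG = np.random.default_rng(20260127)  # module-level constant, as suggested

class FreshRngProperty:
    @property
    def rng(self):
        # The flagged pattern: every access builds a brand-new generator.
        return np.random.default_rng(20260127)

obj = FreshRngProperty()
a, b = np.arange(10), np.arange(10)
obj.rng.shuffle(a)
obj.rng.shuffle(b)
# Both calls used a fresh generator seeded identically, so the two
# "independent" shuffles are in fact the same permutation.
assert np.array_equal(a, b)

c, d = np.arange(10), np.arange(10)
_RNG.shuffle(c)
_RNG.shuffle(d)  # the persistent generator's state advanced between calls
```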
Overall: the logic is correct and the architecture is clean. The two blocking issues (missing export and bfloat16 guard) are small, targeted fixes. Happy to re-review once those are addressed.
justinchuby left a comment
Addendum — one additional medium-severity finding:
8. Potential TypeError in rewrite() — _normalize_axis not called
rewrite() computes x.shape.rank() + axis directly when axis < 0:
if axis < 0 and x.shape is not None:
    axis = x.shape.rank() + axis

If x.shape is not None but x.shape.rank() returns None (an unknown-rank tensor with a non-None shape object), this crashes with TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'. The _normalize_axis() static method on the same class already handles this case correctly by guarding on rank is not None. rewrite() should call self._normalize_axis(axis, x.shape.rank() if x.shape is not None else None) instead of duplicating the logic inline. This is both a DRY violation and a latent crash bug.
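The latent crash and the guarded fix can be reproduced in isolation (plain ints standing in for `x.shape.rank()`; both helper names here are hypothetical):

```python
def inline_normalize(axis, rank):
    # Mirrors the inline logic in rewrite(): assumes rank is always an int.
    return rank + axis if axis < 0 else axis

def normalize_axis(axis, rank):
    # Mirrors _normalize_axis(): only shift when the rank is actually known.
    if axis < 0 and rank is not None:
        return rank + axis
    return axis

try:
    inline_normalize(-1, None)  # unknown-rank tensor
    raise AssertionError("expected a TypeError")
except TypeError:
    pass  # NoneType + int: the crash described above

assert normalize_axis(-1, 3) == 2
assert normalize_axis(-1, None) == -1  # left untouched when rank is unknown
assert normalize_axis(1, None) == 1
```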
^ did a quick AI review. Suggestions seem reasonable.
Thanks for the feedback, I will take a look in the next few days!
Hey @justinchuby, I resolved the feedback above. I'd appreciate another review when you have a chance, thanks!
Fix pytorch/pytorch#76344
Context
As mentioned in the issue,
torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP, microsoft/onnxruntime#11348).

In addition, since torch.min(dim=...) has a similar pattern to max, I also applied this optimization to it.

Verification

Successfully passed the existing OpInfo consistency tests: