
Optimize aten::min/max.dim with TopK op #2780

Open
danielhumanmod wants to merge 12 commits into microsoft:main from danielhumanmod:optimize-max-dim

Conversation

@danielhumanmod

Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., ONNX Runtime CUDA EP microsoft/onnxruntime#11348).

In addition, since torch.min(dim=...) follows a similar pattern to max, I applied this optimization to it as well.
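Conceptually, the rewrite replaces two full reductions with a single op. A minimal NumPy sketch of the before/after semantics (illustrative only; the functions here mimic the ONNX ops, they are not the PR's implementation):

```python
import numpy as np

def reduce_then_arg(x: np.ndarray, dim: int):
    # Current lowering: two separate passes over the input.
    values = np.max(x, axis=dim)      # ReduceMax
    indices = np.argmax(x, axis=dim)  # ArgMax
    return values, indices

def topk_k1(x: np.ndarray, dim: int):
    # Proposed lowering: TopK(k=1) yields both outputs from one op,
    # followed by a Squeeze to drop the reduced dimension (keepdim=False).
    idx = np.expand_dims(np.argmax(x, axis=dim), axis=dim)
    values = np.take_along_axis(x, idx, axis=dim)
    return np.squeeze(values, axis=dim), np.squeeze(idx, axis=dim)

x = np.array([[3.0, 7.0, 5.0], [9.0, 1.0, 2.0]])
v1, i1 = reduce_then_arg(x, 1)
v2, i2 = topk_k1(x, 1)
assert np.array_equal(v1, v2) and np.array_equal(i1, i2)
```

Both lowerings agree on values and indices; the difference is how many times the input is scanned.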

Verification

Successfully passed existing OpInfo consistency tests:

  • pytest tests/function_libs/torch_lib/ops_test.py
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py

@danielhumanmod
Author

@danielhumanmod please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@codecov

codecov bot commented Jan 25, 2026

Codecov Report

❌ Patch coverage is 95.76271% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.24%. Comparing base (e6f79e1) to head (e84dc89).
⚠️ Report is 29 commits behind head on main.

Files with missing lines Patch % Lines
.../rewriter/rules/common/_fuse_reduce_arg_to_topk.py 90.36% 5 Missing and 3 partials ⚠️
...iter/rules/common/_fuse_reduce_arg_to_topk_test.py 98.68% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.52%   72.24%   +1.72%     
==========================================
  Files         228      241      +13     
  Lines       27135    29541    +2406     
  Branches     2727     2898     +171     
==========================================
+ Hits        19137    21343    +2206     
- Misses       7065     7219     +154     
- Partials      933      979      +46     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Collaborator

@justinchuby justinchuby left a comment


Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

@github-project-automation github-project-automation bot moved this from Todo to In Progress in ONNX Script Review Board Jan 25, 2026
@danielhumanmod
Author

Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

Thanks so much for the review! That is a great point, I took some time to dig into the ONNX Runtime implementations to see how they handle this.

  1. From the ONNX Runtime perspective:

    1. The CPU EP provides a fast path when k = 1, which performs a simple linear scan. So on CPU, it seems to behave identically to a fused max+argmax.
    2. The CUDA EP walks through the full Bitonic/Radix sort process, which can involve more complex instructions. The upside is that these operations happen primarily in shared memory.
  2. PyTorch Inductor (as a reference) adopts a similar approach of splitting into reduce_max/arg_max in its IR, but leaves it to the runtime (Scheduler) to fuse them. However, when I checked ONNX Runtime, it didn't seem to have an optimization rule that automatically fuses ReduceMax and ArgMax, which implies the split approach effectively incurs one more IO pass compared to TopK.

So to the best of my knowledge, TopK may bring more instruction overhead but less IO. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am flexible to pivot to other tasks if we prefer to keep the original implementation.
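For reference, the k = 1 fast path described above amounts to a single fused scan that tracks value and index together. A rough Python sketch of the idea (not ONNX Runtime's actual code):

```python
def fused_max_argmax(row):
    # One pass over the data: track the running max and its index together,
    # instead of scanning once for ReduceMax and again for ArgMax.
    best_val, best_idx = row[0], 0
    for i, v in enumerate(row[1:], start=1):
        if v > best_val:  # strict > keeps the first occurrence on ties,
            best_val, best_idx = v, i  # matching ArgMax's default behavior
    return best_val, best_idx

assert fused_max_argmax([3, 9, 1, 9]) == (9, 1)
```

The split ReduceMax + ArgMax lowering performs this traversal twice unless the runtime fuses the two nodes.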

@justinchuby
Collaborator

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

@danielhumanmod
Author

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!

@danielhumanmod
Author

Hey @justinchuby, I've added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I'd appreciate your thoughts on it. Thanks!

Contributor

Copilot AI left a comment


Pull request overview

Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.

Changes:

  • Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
  • Adds extensive unit tests covering success and failure conditions across opset 13 and 18.
  • Validates numerical equivalence and serialized-model correctness for rewritten graphs.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py Implements the Reduce+Arg → TopK fusion rules for both max and min cases.
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py Adds unit tests for the new fusion rules, including opset and attribute/input variants.

)

# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
Contributor


I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity? (we have version converter) It's just the matter whether we see the rule will be applied standalone or not I guess?

Author


That makes sense to remove; I see this rule will mostly be used in the pipeline. Thanks for the suggestion!

Contributor


@justinchuby What do you think?

Collaborator


Only opset 18+ is fine

Contributor


Sorry @danielhumanmod, can you add a NOTE/comment somewhere saying that the rule is only for opset 18+? Since it's now not in the default rewrite rules, it could be used standalone by other users.


# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
Contributor


I wonder if symbolic shape could work on this case? @justinchuby

Collaborator


Can you elaborate?

Contributor


Skipping when the shape is None means this does not support dynamic shapes at the moment. But symbolic inference should be able to handle the equality check.

Author


Ohh, actually that is a very good catch. I will use shape.rank() instead to support both static and symbolic shapes, thanks a bunch!

@titaiwangms
Contributor

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

@justinchuby
Collaborator

justinchuby commented Feb 7, 2026

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

I don’t think we want to enable this by default. It is unclear if this is generally more performant. @danielhumanmod you may simply expose the rule in https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/rules/common/__init__.py

@danielhumanmod
Author

Hey team, I've resolved all the pending comments. I'd appreciate it if you could take another look when you have time, thanks! cc @justinchuby @titaiwangms

Contributor

@titaiwangms titaiwangms left a comment


Thank you. An unblocking comment.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Collaborator

@justinchuby justinchuby left a comment


Unblocking but I would need to review it more closely next week. Thanks

@danielhumanmod
Author

Unblocking but I would need to review it more closely next week. Thanks

Hey @justinchuby, gently pinging on this PR. I'd appreciate it if you could take another look when you have time, thanks!

Collaborator

@justinchuby justinchuby left a comment


Thanks for the work on this PR — the base/subclass architecture mirrors _FuseBatchNormBase cleanly, the test suite is thorough with good failure-reason assertions, and the decision to expose-but-not-default is architecturally sound. Here are the remaining issues that need attention before merge:

🔴 Blocking

1. Rules not exported from common/__init__.py

@justinchuby explicitly asked to "expose the rule in onnxscript/rewriter/rules/common/__init__.py" — but the PR diff adds only two files and makes no change to any __init__.py. The rules are unreachable without importing from the private _fuse_reduce_arg_to_topk module directly, which violates the project convention that _* modules are internal.

Please add to common/__init__.py:

from onnxscript.rewriter.rules.common._fuse_reduce_arg_to_topk import (
    fuse_reduce_max_argmax_to_topk_rule,
    fuse_reduce_min_argmin_to_topk_rule,
)

(See naming note below on the fuse_ prefix.)

2. Missing dtype guard — bfloat16 inputs produce invalid graphs

ReduceMax/ReduceMin support bfloat16 in opset 18, but TopK does not. With no dtype check, the rule can fire on a bfloat16 model and emit an invalid TopK node that fails ONNX checker validation — a silent correctness regression.

Please add to check() after obtaining the input:

_TOPK_SUPPORTED_DTYPES = frozenset({
    ir.DataType.FLOAT16, ir.DataType.FLOAT, ir.DataType.DOUBLE,
    ir.DataType.INT8, ir.DataType.INT16, ir.DataType.INT32, ir.DataType.INT64,
    ir.DataType.UINT8, ir.DataType.UINT16, ir.DataType.UINT32, ir.DataType.UINT64,
})
if x.dtype is not None and x.dtype not in _TOPK_SUPPORTED_DTYPES:
    return check_result.fail(f"Input dtype {x.dtype} is not supported by TopK")

🟠 High

3. Open maintainer thread — opset 18+ NOTE still unresolved

@titaiwangms has an open (unresolved) review thread requesting a visible NOTE that these rules are opset 18+ only, since they can be used standalone. The module docstring mentions it in the Constraints: block, but users who encounter FuseReduceMaxArgMaxToTopK or FuseReduceMinArgMinToTopK in isolation won't see it. Please add a one-liner to each public class docstring, e.g.:

# NOTE: Requires opset 18+. Apply a version converter before using on older models.

This will also allow @titaiwangms to resolve that thread.

4. No explicit opset guard — only coincidental protection

The opset 18+ constraint is enforced incidentally (pre-18 ReduceMax puts axes in an attribute rather than an input, so the pattern simply fails to match). A manually constructed pre-18 model with axes-as-input could bypass this. Consider an explicit check, or at minimum document the coincidental guard in a comment.
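An explicit guard could be a short helper along these lines. This is a sketch only: it assumes an IR model object exposing an `opset_imports` mapping keyed by domain (with `""` for the default ONNX domain), which is how the onnxscript IR represents opsets; the helper name is made up for illustration.

```python
def is_opset_supported(model, minimum: int = 18) -> bool:
    # The empty-string key is the default ONNX domain in opset_imports.
    # Missing key -> 0, so an unknown opset conservatively fails the check.
    return model.opset_imports.get("", 0) >= minimum

class _FakeModel:  # stand-in for an ir.Model-like object
    opset_imports = {"": 18}

assert is_opset_supported(_FakeModel())
```

Calling this from `check()` would make the opset requirement explicit instead of relying on the axes-as-attribute pattern mismatch.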

🟡 Medium

5. except Exception too broad

The except Exception block in check() masks real bugs (MemoryError, unexpected AttributeError, etc.). Please narrow to except (ValueError, TypeError).

6. _normalize_axis defined but not used in rewrite()

rewrite() contains inline axis normalization logic that duplicates the _normalize_axis() static method on the same class. This is fragile — if either copy is updated, they silently diverge. Please have rewrite() call self._normalize_axis().

7. Rule naming breaks fuse_*_rule convention

Every rule in common/__init__.py uses a verb-first name: fuse_batchnorm_into_conv_rule, cast_cast_rule, collapse_slice_rule. The new rules are noun-first (reduce_max_argmax_to_topk_rule), which breaks autocomplete and grep discoverability for users searching fuse_*. Please rename to fuse_reduce_max_argmax_to_topk_rule and fuse_reduce_min_argmin_to_topk_rule.

🔵 Minor

  • select_last_index_attr should be select_last_index: The _attr suffix implies an ir.Attr object; it holds an int. Should match the style of adjacent variables (reduce_keepdims, arg_axis, etc.).
  • Step-numbered comments: # Step 1 through # Step 7 (and # Step 2b) describe what the next line does, not why. The _fuse_batchnorm.py reference has zero step comments. Keep only the comments that explain non-obvious reasons (e.g., # ONNX default: axis=0 for ArgMax/ArgMin).
  • del context should be del context # Unused — matches _fuse_batchnorm.py style.
  • rng property creates a new RNG on every access: Properties are expected to return a consistent object. A module-level constant _RNG = np.random.default_rng(20260127) avoids the subtle behavior difference between self.rng.shuffle(a); self.rng.shuffle(b) vs rng = self.rng; rng.shuffle(a); rng.shuffle(b).
  • sorted=1 with k=1 is redundant: A single-element result is trivially sorted; omitting sorted (or using sorted=False) is cleaner.
  • TestFuseReduceMaxArgMaxToTopK missing class docstring — the Min test class has one; the Max class doesn't.
  • Linting: Several RUFF/format and EDITORCONFIG issues remain open from the automated scan. Please run lintrunner -a and commit the result.
  • Missing test coverage (8 lines per Codecov): the pre-opset-18 fallback path and scalar-axes edge case are not tested.

Overall: the logic is correct and the architecture is clean. The two blocking issues (missing export and bfloat16 guard) are small, targeted fixes. Happy to re-review once those are addressed.

Collaborator

@justinchuby justinchuby left a comment


Addendum — one additional medium-severity finding:

8. Potential TypeError in rewrite(): _normalize_axis not called

rewrite() computes x.shape.rank() + axis directly when axis < 0:

if axis < 0 and x.shape is not None:
    axis = x.shape.rank() + axis

If x.shape is not None but x.shape.rank() returns None (unknown-rank tensor with a non-None shape object), this crashes with a TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'. The _normalize_axis() static method on the same class already handles this case correctly by guarding on rank is not None. rewrite() should call self._normalize_axis(axis, x.shape.rank() if x.shape is not None else None) instead of duplicating the logic inline. This is both a DRY violation and a latent crash bug.
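The guarded helper the review describes could look roughly like this. This is an illustrative sketch based on the behavior described above, not the PR's exact code; the standalone function name and `Optional[int]` return contract are assumptions.

```python
from typing import Optional

def normalize_axis(axis: int, rank: Optional[int]) -> Optional[int]:
    # Negative axes can only be normalized when the rank is known.
    # Returning None (instead of computing None + axis, which raises
    # TypeError) lets the caller bail out of the rewrite cleanly.
    if axis >= 0:
        return axis
    if rank is None:
        return None
    return rank + axis

assert normalize_axis(-1, 4) == 3      # negative axis, known rank
assert normalize_axis(2, None) == 2    # non-negative axis needs no rank
assert normalize_axis(-1, None) is None  # unknown rank: no crash
```

Having `rewrite()` delegate to this single helper removes both the DRY violation and the latent crash.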

@justinchuby
Collaborator

^ did a quick AI review. Suggestions seem reasonable.

@danielhumanmod
Author

^ did a quick AI review. Suggestions seem reasonable.

Thanks for the feedback, I will take a look in the next few days!

@danielhumanmod
Author

Hey @justinchuby, I've addressed the feedback above. I'd appreciate another review when you have a chance, thanks!

@justinchuby justinchuby added this to the 0.6.0 milestone Apr 3, 2026

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

[ONNX] Use topk to export max(dim,keepdim) to onnx

5 participants