Skip to content

[JAX] Add warning if using BSHD and max_segments_per_seq > 1#2796

Merged
jberchtold-nvidia merged 7 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/te-max-segments-per-seq-warning-bshd
Mar 30, 2026
Merged

[JAX] Add warning if using BSHD and max_segments_per_seq > 1#2796
jberchtold-nvidia merged 7 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/te-max-segments-per-seq-warning-bshd

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

Description

Adds a small warning if the user tries to use BSHD with max_segments_per_seq > 1

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Adds a warning if the user tries to use BSHD with max_segments_per_seq > 1
  • Adds a new test to validate this warning is shown correctly

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR adds a UserWarning to fused_attn in transformer_engine/jax/attention.py when the caller sets max_segments_per_seq > 1 with a non-THD (BSHD) layout, since sequence packing is only meaningful for THD formats. The change is small, self-contained, and the warning logic — guarded by not qkv_layout.is_thd() — correctly covers all BSHD variants (BS3HD, BSHD_BS2HD, BSHD_BSHD_BSHD). The two issues flagged in earlier review rounds (stacklevel omission and the duplicate if causing IndentationError) are both resolved in the current state.

Key observations:

  • Warning is correctly placed in the non-legacy code path (after the early-return for deprecated mask usage), so it won't fire spuriously for callers using the old mask-based API.
  • stacklevel=2 is present, directing the warning to the caller's site rather than library internals.
  • The PR description and checklist claim a test class (TestMaxSegmentsPerSeqWarning) was added to tests/jax/test_fused_attn.py, but it was removed in a subsequent commit (6d7ad99) with no stated reason, leaving the new warning path without automated test coverage.

Confidence Score: 5/5

Safe to merge — the one-line logic change is correct and non-breaking; all remaining feedback is P2 quality/documentation concerns.

All identified issues are P2 (missing test coverage, inaccurate PR checklist). No runtime defects, security issues, or logic errors remain. Previous P0/P1 issues (duplicate if / missing stacklevel) are resolved.

No files require special attention for merge safety, but tests/jax/test_fused_attn.py should ideally have the removed test reinstated.

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Adds a correctly placed UserWarning (with stacklevel=2) when max_segments_per_seq > 1 is used with a non-THD layout; prior review issues (duplicate if, missing stacklevel) are resolved.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_attn(qkv, ..., max_segments_per_seq, qkv_layout)"]
    B{"sequence_descriptor is None\nor isinstance(jnp.ndarray)?"}
    C["warnings.warn(DeprecationWarning)\n+ raise ValueError if max_segments_per_seq != 1"]
    D["_legacy_fused_attn(...)  return"]
    E{"max_segments_per_seq > 1\nAND NOT qkv_layout.is_thd()?"}
    F["warnings.warn(UserWarning, stacklevel=2)\n'max_segments_per_seq only applies to THD'"]
    G["_fused_attn(...) → output  return"]

    A --> B
    B -- yes --> C --> D
    B -- no --> E
    E -- yes --> F --> G
    E -- no --> G
Loading

Reviews (5): Last reviewed commit: "Merge branch 'main' into jberchtold/te-m..." | Re-trigger Greptile

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

)


class TestMaxSegmentsPerSeqWarning:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for testing this warning out, Jeremy !
I think negative testing is crucial in such cases..
However, I'm thinking that merging a test case for just checking a warning might be overkill.
We've got quite a few warnings in TE but we do not test them (from at least what I'm aware of) so I'd suggest we drop the test and just keep the warning change in fused attn.
Local negative testing to ensure that the warning gets triggered, followed by a passing CI for fused attn tests should be enough, I think . Thoughts ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good to me, I'll remove the new test

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

jberchtold-nvidia and others added 3 commits March 30, 2026 10:41
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@KshitijLakhani KshitijLakhani self-requested a review March 30, 2026 17:46
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

Copy link
Copy Markdown
Collaborator

@KshitijLakhani KshitijLakhani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM !
Good to merge post successful CI runs for fused attn
Thanks !

@jberchtold-nvidia jberchtold-nvidia merged commit f4debf6 into NVIDIA:main Mar 30, 2026
11 of 14 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/te-max-segments-per-seq-warning-bshd branch March 30, 2026 23:36
pstjohn pushed a commit to pstjohn/TransformerEngine that referenced this pull request Mar 31, 2026
…2796)

* Add warning if using BSHD and max_segments_per_seq > 1

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/jax/attention.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update transformer_engine/jax/attention.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Remove warning test

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants