[JAX] Add warning if using BSHD and max_segments_per_seq > 1 #2796
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci jax
Greptile Summary
This PR adds a Key observations:
Confidence Score: 5/5
Safe to merge: the one-line logic change is correct and non-breaking; all remaining feedback is P2 quality/documentation concerns. All identified issues are P2 (missing test coverage, inaccurate PR checklist). No runtime defects, security issues, or logic errors remain. Previous P0/P1 issues (duplicate if / missing stacklevel) are resolved. No files require special attention for merge safety, but tests/jax/test_fused_attn.py should ideally have the removed test reinstated.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_attn(qkv, ..., max_segments_per_seq, qkv_layout)"]
B{"sequence_descriptor is None\nor isinstance(jnp.ndarray)?"}
C["warnings.warn(DeprecationWarning)\n+ raise ValueError if max_segments_per_seq != 1"]
D["_legacy_fused_attn(...) return"]
E{"max_segments_per_seq > 1\nAND NOT qkv_layout.is_thd()?"}
F["warnings.warn(UserWarning, stacklevel=2)\n'max_segments_per_seq only applies to THD'"]
G["_fused_attn(...) → output return"]
A --> B
B -- yes --> C --> D
B -- no --> E
E -- yes --> F --> G
E -- no --> G
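The non-legacy branch of the flow above can be sketched in plain Python. This is only an illustration under stated assumptions: `Layout` and `check_max_segments` are hypothetical stand-ins (not the names used in transformer_engine/jax/attention.py), and the message text is paraphrased from the flowchart rather than copied from the source.

```python
import warnings
from enum import Enum


class Layout(Enum):
    """Hypothetical stand-in for the real qkv_layout enum."""

    BSHD = "bshd"
    THD = "thd"

    def is_thd(self) -> bool:
        # Only the THD (packed sequence) layout supports multiple segments.
        return self is Layout.THD


def check_max_segments(qkv_layout: Layout, max_segments_per_seq: int) -> None:
    """Warn when max_segments_per_seq > 1 is combined with a non-THD layout."""
    if max_segments_per_seq > 1 and not qkv_layout.is_thd():
        warnings.warn(
            "max_segments_per_seq only applies to THD; "
            f"got max_segments_per_seq={max_segments_per_seq} with {qkv_layout.name}.",
            UserWarning,
            # stacklevel=2 attributes the warning to the caller's line.
            stacklevel=2,
        )
```

The check warns rather than raises, so a call that was accepted before this change keeps running; `stacklevel=2` makes the reported location the user's call site instead of the helper itself.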
Reviews (5): Last reviewed commit: "Merge branch 'main' into jberchtold/te-m..." | Re-trigger Greptile |
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
/te-ci jax
tests/jax/test_fused_attn.py
Outdated
class TestMaxSegmentsPerSeqWarning:
Thanks for testing this warning out, Jeremy!
I think negative testing is crucial in such cases.
However, I think merging a test case just to check a warning might be overkill.
We have quite a few warnings in TE that we do not test (at least as far as I'm aware), so I'd suggest we drop the test and keep only the warning change in fused attn.
Local negative testing to ensure that the warning gets triggered, followed by a passing CI run for the fused attn tests, should be enough, I think. Thoughts?
That sounds good to me. I'll remove the new test.
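For reference, the kind of local negative check discussed above could look roughly like the following. This is a sketch only: `attn_stub` is a hypothetical stand-in for the real attention entry point, and `warning_count` is an illustrative helper, not something from the TE test suite.

```python
import warnings


def attn_stub(max_segments_per_seq: int, is_thd: bool) -> None:
    """Hypothetical stand-in for the real attention entry point."""
    if max_segments_per_seq > 1 and not is_thd:
        warnings.warn(
            "max_segments_per_seq only applies to THD", UserWarning, stacklevel=2
        )


def warning_count(max_segments_per_seq: int, is_thd: bool) -> int:
    """Return how many UserWarnings a single call emits."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones Python would normally deduplicate.
        warnings.simplefilter("always")
        attn_stub(max_segments_per_seq, is_thd)
    return sum(issubclass(w.category, UserWarning) for w in caught)
```

Running `warning_count` over the three combinations (non-THD with more than one segment, THD with more than one segment, non-THD with a single segment) covers both the positive case and the two negative cases.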
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci jax
KshitijLakhani left a comment
LGTM!
Good to merge after successful CI runs for fused attn.
Thanks!
…2796)
* Add warning if using BSHD and max_segments_per_seq > 1
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update transformer_engine/jax/attention.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
* Update transformer_engine/jax/attention.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
* Remove warning test
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Description
Adds a small warning if the user tries to use BSHD with max_segments_per_seq > 1
Type of change
Changes
max_segments_per_seq > 1
Checklist: