
[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures #2836

Open
pedramr wants to merge 1 commit into NVIDIA:main from pedramr:fix/sm103-flash-attn-allowlist

Conversation


@pedramr pedramr commented Apr 4, 2026

Description

The head_dim > 192 gate for FlashAttention 2 in get_attention_backend used an exact-match
compute capability allowlist: (8,0), (9,0), (10,0), (12,0). This excluded sm103 (B300/GB300),
sm89 (L40S/RTX 4090), sm86 (A40/RTX 3090), and other valid architectures where flash-attn
supports head_dim up to 256.

This PR replaces the allowlist with a >= sm80 range check, matching flash-attn's own gate:
Dao-AILab/flash-attention@bbb21d6

The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear
to be supported based on flash-attn's >= sm80 guarantee.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

  • Replace the exact-match compute capability allowlist with a device_compute_capability < (8, 0) range check
  • Update the debug log message from sm80/90/100+ to sm80+
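The effect of the change can be sketched as follows. This is a minimal, hypothetical simplification (the real gate lives inside get_attention_backend in transformer_engine/pytorch/attention/dot_product_attention/utils.py; the function names here are illustrative only):

```python
def fa2_blocked_by_head_dim_old(head_dim_qk, device_compute_capability):
    # Old gate: exact-match allowlist, which misses sm86, sm89, sm103, etc.
    return head_dim_qk > 192 and device_compute_capability not in (
        (8, 0), (9, 0), (10, 0), (12, 0)
    )

def fa2_blocked_by_head_dim_new(head_dim_qk, device_compute_capability):
    # New gate: range check matching flash-attn's own >= sm80 guarantee.
    return head_dim_qk > 192 and device_compute_capability < (8, 0)

# sm103 (B300) with head_dim 256: blocked by the old gate, allowed by the new one.
print(fa2_blocked_by_head_dim_old(256, (10, 3)))  # True
print(fa2_blocked_by_head_dim_new(256, (10, 3)))  # False
```

Tuple comparison gives the range semantics for free: (10, 3) < (8, 0) is False, so sm103 passes, while (7, 5) < (8, 0) is True, so pre-sm80 parts are still blocked.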

[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures

Replace the exact-match compute capability allowlist with a >= sm80 range
check, matching flash-attn's own gate:
Dao-AILab/flash-attention@bbb21d6

The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89
(L40S), sm86 (A40), and others where FA2 supports head_dim up to 256.
The sm103 case was validated on hardware with head_dim=256; the remaining
architectures appear to be supported based on flash-attn's >= sm80 guarantee.

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>

greptile-apps bot commented Apr 4, 2026

Greptile Summary

This PR fixes a regression where FlashAttention 2 with head_dim > 192 was incorrectly disabled on valid sm80+ architectures (sm86, sm89, sm103, etc.) due to an exact-match compute-capability allowlist. The fix replaces the allowlist with a < (8, 0) range check, matching flash-attn's own gate.

  • The new condition on line 634 is technically dead code: an earlier filter at line 428 already sets use_flash_attention_2 = False for any device_compute_capability < (8, 0), so the sub-expression can never be True. The behaviour is correct, but the branch can be removed entirely for clarity.

Confidence Score: 5/5

Safe to merge — the behavioural fix is correct and only one minor dead-code style note remains.

The only finding is P2: the new device_compute_capability < (8, 0) sub-expression is unreachable because the earlier sm80 filter already guards it, but this does not affect correctness. The intended bug-fix (allowing head_dim > 192 on sm86, sm89, sm103, and other sm80+ architectures) is correctly achieved.

No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/utils.py
Overview: Replaces the exact-match compute-capability allowlist {(8,0),(9,0),(10,0),(12,0)} with a < (8,0) range check for the head_dim > 192 FA2 guard; the condition is now effectively dead code (already covered by the earlier sm80 filter at line 428), but the resulting behaviour is correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{device_compute_capability < sm80?}
    B -- Yes --> C[use_flash_attention_2 = False\nline 431]
    B -- No --> D{use_flash_attention_2?}
    C --> END[continue with FA2 disabled]
    D -- No --> END
    D -- Yes --> E{head_dim_qk > 256\nor head_dim_qk % 8 != 0\nor head_dim_qk > 192 AND cc < sm80?}
    E -- Third branch always False cc >= sm80 guaranteed here --> F[Only first two branches can disable FA2]
    E -- Yes first/second branch --> G[use_flash_attention_2 = False\nline 646]
    F --> H[FA2 stays enabled for head_dim 193-256 on sm80+ sm86 sm89 sm103 etc]
    G --> END


head_dim_qk > 192
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
)
or (head_dim_qk > 192 and device_compute_capability < (8, 0))

P2 Dead code: condition is always False at this point

Lines 428–431 unconditionally set use_flash_attention_2 = False whenever device_compute_capability < (8, 0). By the time execution reaches line 634, use_flash_attention_2 can only be True if device_compute_capability >= (8, 0), so the sub-expression device_compute_capability < (8, 0) is never true and the entire third or branch is unreachable. The bug-fix intent is correct (no longer blocking head_dim > 192 on sm86/sm89/sm103), but the residual condition could be confusing to future readers who might believe it provides a meaningful guard.

Consider removing the dead branch entirely:

Suggested change
- or (head_dim_qk > 192 and device_compute_capability < (8, 0))
+ or head_dim_qk % 8 != 0
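The unreachability argument above can be checked with a small self-contained sketch. This is a hypothetical simplification of the two filters (the function name and line references are illustrative; the real logic is spread across get_attention_backend in utils.py):

```python
def select_fa2(head_dim_qk, cc):
    """Return whether FlashAttention 2 remains enabled after both filters."""
    use_flash_attention_2 = True

    # Earlier filter (around lines 428-431 in utils.py): FA2 requires sm80+.
    if cc < (8, 0):
        use_flash_attention_2 = False

    # Later head_dim filter (around line 634): by the time we get here with
    # use_flash_attention_2 still True, cc >= (8, 0) is guaranteed, so the
    # third branch below can never fire -- it is dead code.
    if use_flash_attention_2 and (
        head_dim_qk > 256
        or head_dim_qk % 8 != 0
        or (head_dim_qk > 192 and cc < (8, 0))  # unreachable here
    ):
        use_flash_attention_2 = False

    return use_flash_attention_2

print(select_fa2(256, (10, 3)))  # True: head_dim 256 stays enabled on sm103
print(select_fa2(256, (7, 5)))   # False: already disabled by the sm80 filter
print(select_fa2(264, (9, 0)))   # False: head_dim > 256 still blocked
```

Deleting the third `or` branch, as the suggestion proposes, leaves the function's observable behaviour identical for every input.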
