Draft
174 commits
e0ae107
initial implementation for mxfp8
cyanguwa Jan 31, 2026
23434b5
semi-working FP8; broken F16
cyanguwa Feb 4, 2026
dbb68b8
clean up last commit
cyanguwa Feb 4, 2026
c627231
comment out F16 pass
cyanguwa Feb 4, 2026
d27a267
Merge branch 'NVIDIA:main' into mxfp8_fwd
cyanguwa Feb 6, 2026
3f3b9e6
pull in grouped_quantize for MXFP8
cyanguwa Feb 6, 2026
850b16e
grouped tensor - pytorch
cyanguwa Feb 7, 2026
46f2eb1
quantize mxfp8
cyanguwa Feb 7, 2026
e86207c
fix shapes/strides
cyanguwa Feb 10, 2026
4e854d5
fix unfused; clean up
cyanguwa Feb 12, 2026
cd06398
split d to d_qk/d_v; attempt at bwd
cyanguwa Feb 13, 2026
d2a63a1
merge main
cyanguwa Feb 13, 2026
730a472
fix last merge
cyanguwa Feb 14, 2026
d9ff566
update FE
cyanguwa Feb 14, 2026
2b264d7
attempt at SWA/MLA
cyanguwa Feb 14, 2026
2008bed
remove prints
cyanguwa Feb 14, 2026
239f58a
remove leftover prints
cyanguwa Feb 14, 2026
f44a775
Revert "update FE"
cyanguwa Feb 14, 2026
965572b
update FE
cyanguwa Feb 14, 2026
91025c7
fix MLA O strides; add bottom_right_diagonal
cyanguwa Feb 17, 2026
d655e7e
attempt at bwd
cyanguwa Feb 18, 2026
a4ab691
fix get_quantizers; attempt at bwd
cyanguwa Feb 19, 2026
a85070d
fix fprop; add o_format
cyanguwa Feb 20, 2026
8909b35
attempt at bwd with o_format/d_out_format/dqkv_layout
cyanguwa Feb 20, 2026
90a636c
fix dtype/o_format/etc in bwd calls
cyanguwa Feb 21, 2026
8c72dea
fix generateMatrixStridesWithFormats and _v1; fix padding for mxfp8
cyanguwa Feb 21, 2026
5f23edd
fix upon last commit for paddedsizes
cyanguwa Feb 21, 2026
18c5580
add mxfp8 env var
cyanguwa Feb 21, 2026
6847645
disable FA for mxfp8
cyanguwa Feb 21, 2026
c5a98d5
add mha test
cyanguwa Feb 21, 2026
7e61ecd
attempt at bwd; force determinism; fix shapes
cyanguwa Feb 24, 2026
6d468da
remove prints
cyanguwa Feb 26, 2026
9f8e856
update FE
cyanguwa Feb 26, 2026
facef79
update FE from pre-merge branch to post-merge develop
cyanguwa Feb 26, 2026
fd33cca
allow MXFP8 linear + f16 attn
cyanguwa Feb 26, 2026
5079d55
test cp a2a
cyanguwa Feb 27, 2026
06b7d49
remove prints temporarily
cyanguwa Feb 27, 2026
7fbe399
test cp p2p
cyanguwa Feb 27, 2026
aa05a2a
minor fixes for mla
cyanguwa Feb 28, 2026
00e6693
open up a2a for mla
cyanguwa Feb 28, 2026
b8d28ce
test ag
cyanguwa Feb 28, 2026
d6ecadc
tweaks for last commit
cyanguwa Feb 28, 2026
3ac48cd
enable mla ag
cyanguwa Mar 1, 2026
169ae8a
merge main
cyanguwa Mar 1, 2026
5d4fa5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2026
81c18fa
fix merge
cyanguwa Mar 1, 2026
1f14f2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2026
ccebe77
fix merge
cyanguwa Mar 1, 2026
c52c5f4
revert to main grouped tensor impl
cyanguwa Mar 1, 2026
5b776ec
minor tweaks to return to main
cyanguwa Mar 1, 2026
4eee2bc
remove prints
cyanguwa Mar 3, 2026
8500121
fix combine_and_quantize for f16
cyanguwa Mar 3, 2026
0c2c466
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
6744aee
minor tweaks
cyanguwa Mar 3, 2026
4cec878
tweak tests
cyanguwa Mar 3, 2026
5c8e939
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2026
7b6b364
fix ds descale_o
cyanguwa Mar 3, 2026
462eb4f
Revert "fix ds descale_o"
cyanguwa Mar 3, 2026
77995d2
minor fixes for p2p and ag
cyanguwa Mar 7, 2026
586b698
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2026
1e7cd70
tweak cp test skips
cyanguwa Mar 7, 2026
6d7766a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2026
6d33db8
update FE
cyanguwa Mar 11, 2026
92e6aac
fix bwd KV tensors
cyanguwa Mar 12, 2026
3cb6f0e
tweak recipe control and backend selection
cyanguwa Mar 12, 2026
c57ece4
tweak quantizer logic
cyanguwa Mar 12, 2026
87a7e1e
minor fixes after last two commits
cyanguwa Mar 13, 2026
3b015f3
improve generate strides
cyanguwa Mar 13, 2026
6717e1a
minor fixes for previous commit
cyanguwa Mar 13, 2026
c918b9d
fix bwd for current/delayed
cyanguwa Mar 13, 2026
af60216
tweak test configs
cyanguwa Mar 13, 2026
6ac41d2
fix dO/dO_f16 strides
cyanguwa Mar 13, 2026
0a0722f
fix tests: SWA logic/test configs
cyanguwa Mar 13, 2026
89b44f8
fix ag
cyanguwa Mar 13, 2026
7c0ba7f
add fp8 sink attn
cyanguwa Mar 13, 2026
e68f785
fix a2a comm for F16
cyanguwa Mar 14, 2026
ae53980
remove nan/inf print in test
cyanguwa Mar 14, 2026
4b314e7
fix fa a2a
cyanguwa Mar 14, 2026
4b5d623
fix fa a2a+p2p f16
cyanguwa Mar 14, 2026
fdab7db
update FE to include new fixes
cyanguwa Mar 16, 2026
39b57e9
fix thd for bwd
cyanguwa Mar 17, 2026
dc49479
refactor a2a for fu/fa
cyanguwa Mar 17, 2026
dea59e4
update FE to fix d64
cyanguwa Mar 17, 2026
9da8ec9
refactor ag
cyanguwa Mar 17, 2026
a250b20
refactor p2p/a2a+p2p; mostly regarding shapes
cyanguwa Mar 18, 2026
630545e
add shadow f16 fwd
cyanguwa Mar 18, 2026
a78ea9a
update FE to fix SWA/BRCM
cyanguwa Mar 18, 2026
59eff74
switch to GH FE temporarily
cyanguwa Mar 19, 2026
6472e66
merge main
cyanguwa Mar 19, 2026
1691747
switch back to GL FE
cyanguwa Mar 19, 2026
d41eca3
update FE to latest commit
cyanguwa Mar 19, 2026
e0b65a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
e51ec9f
update group tensor usage after merge main
cyanguwa Mar 19, 2026
7bb40d5
env vars for qdq(q,k), o_f16 tests
cyanguwa Mar 19, 2026
29c2f4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
c10f05c
allow other recipes than mxfp8
cyanguwa Mar 19, 2026
773c678
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
0ef408b
fix grouped tensor for MLA
cyanguwa Mar 19, 2026
4429e58
change cp test configs
cyanguwa Mar 19, 2026
08af36b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
4dd1418
add shadow f16 bwd
cyanguwa Mar 19, 2026
ad4d4da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
f2266f4
fix a2a+p2p for sbhd
cyanguwa Mar 20, 2026
1674b0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
712d4f9
fix last commit and causal flag for fa
cyanguwa Mar 20, 2026
f9463e2
enable fp8 sink and disable fp8_mha
cyanguwa Mar 21, 2026
299bc63
minor cleanup for cp/non-cp
cyanguwa Mar 21, 2026
ed62903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2026
94ae209
update FE for FP8 sink
cyanguwa Mar 23, 2026
a9028b2
fix TE for FP8 sink
cyanguwa Mar 23, 2026
a6f56e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
706095f
temporary: random sink/print sink
cyanguwa Mar 23, 2026
4c004ee
Revert "temporary: random sink/print sink"
cyanguwa Mar 23, 2026
e023d3b
replace d_out_format with do_format
cyanguwa Mar 24, 2026
7577919
fix compare_and_assert for None cases
cyanguwa Mar 24, 2026
f0b1e2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
ee388e5
remove logic for b and simplify logic for dqkv types
cyanguwa Mar 24, 2026
cacc59d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
de82fe1
minor fix for ndim_q/kv
cyanguwa Mar 24, 2026
706012a
add explanation of fp8_output/grad in MHA
cyanguwa Mar 24, 2026
746010e
tidy up FP8 checks for bhsd/learnable
cyanguwa Mar 24, 2026
2283081
remove leading underscores in nvte_convert_qkv_format
cyanguwa Mar 24, 2026
e693e6f
simplify logic in generateMatrixStridesWithLayout
cyanguwa Mar 24, 2026
edf1b2a
clean up strides/ifelse-recipe logic
cyanguwa Mar 24, 2026
09b21ee
tweak checks in utils.py
cyanguwa Mar 24, 2026
49a54c0
tweak UnfusedDPA
cyanguwa Mar 24, 2026
e5d49d2
enable testing for ag+swa and disable fp8_mha
cyanguwa Mar 24, 2026
2c63d83
tweak FusedAttn, fp8/f16 tensor naming/docstring
cyanguwa Mar 24, 2026
7f62b98
replace d_out_format with do_format
cyanguwa Mar 24, 2026
4b9240c
fix lint
cyanguwa Mar 24, 2026
2a21a3a
clean up a2a
cyanguwa Mar 25, 2026
a18cd7c
clean up ag
cyanguwa Mar 25, 2026
a19ccb3
clean up p2p/a2a+p2p
cyanguwa Mar 25, 2026
4ba2ef5
tweak test configs
cyanguwa Mar 25, 2026
875931c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
f0bf680
qdq dO in bwd shadow f16 path
cyanguwa Mar 25, 2026
2d80d38
tweak qdq dO logic
cyanguwa Mar 25, 2026
0cf9738
remove prints in shadow paths
cyanguwa Mar 25, 2026
813d39d
update FE to allow non-determinism
cyanguwa Mar 25, 2026
bdc0c47
fuse qkv transposes; first pass
cyanguwa Mar 26, 2026
e69a06a
remap parallelism to grid(bh, splits, 3) block(s/splits x d); use nve…
cyanguwa Mar 26, 2026
aab8856
allocate contiguous block for qkv
cyanguwa Mar 26, 2026
78055e4
fix grouped tensor row/col scale_inv offsets
cyanguwa Mar 26, 2026
d8f9ac9
use fused permute kernels
cyanguwa Mar 26, 2026
ca53769
quantize row/col as needed in fwd/bwd, non-cp/cp
cyanguwa Mar 27, 2026
f19e852
Revert "quantize row/col as needed in fwd/bwd, non-cp/cp"
cyanguwa Mar 27, 2026
2d403f9
Reapply "quantize row/col as needed in fwd/bwd, non-cp/cp"
cyanguwa Mar 27, 2026
f9e4e20
fix v_col format when row is quantized
cyanguwa Mar 27, 2026
fde366a
add back necessary bwd quants for shadow paths/cp a2a
cyanguwa Mar 27, 2026
81f723d
remove ZInv for all layouts except T3HD
cyanguwa Mar 27, 2026
89daa49
fix cp p2p with zinv
cyanguwa Mar 28, 2026
60740fa
temporarily switch to GH FE main
cyanguwa Mar 28, 2026
7fdf269
Merge branch 'main' into add_mxfp8
cyanguwa Mar 28, 2026
a7ff000
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2026
b0db79e
switch back to GL FE
cyanguwa Mar 28, 2026
f662a4a
fix ag after merge main
cyanguwa Mar 28, 2026
cbf6edd
add condition for qdq(do) to not affect other tests
cyanguwa Mar 28, 2026
0642251
fix custom_mha_fp8 test
cyanguwa Mar 28, 2026
e6ffc6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2026
fd9a750
fix amax dqkv
cyanguwa Mar 30, 2026
4f2e4f4
fix fp8_recipe in DPA utils
cyanguwa Mar 30, 2026
3869145
remove use of amax for mxfp8
cyanguwa Mar 31, 2026
641c05c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
59db112
add o_format/do_format/dqkv_layout to cache indicators for fp8 and f16
cyanguwa Mar 31, 2026
f1d1809
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
c491908
enable sink attn + FP8 in CP
cyanguwa Mar 31, 2026
6af3105
update FE to GH v1.22.0
cyanguwa Apr 3, 2026
508044b
fix for inconsistent kwarg name in permute to grouped tensor
cyanguwa Apr 4, 2026
2532a50
add TMA permute
cyanguwa Apr 4, 2026
d7c27f6
Revert "add TMA permute"
cyanguwa Apr 4, 2026
ba411a2
TMA load for bhsd transposes
cyanguwa Apr 6, 2026
5ada28d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 6, 2026
4a47e4d
Merge branch 'main' into add_mxfp8
cyanguwa Apr 6, 2026
6911aba
fix some lint
cyanguwa Apr 6, 2026
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 89 files
+1 −1 CMakeLists.txt
+2 −0 README.md
+359 −199 include/cudnn_frontend/graph_interface.h
+14 −0 include/cudnn_frontend/graph_properties.h
+7 −7 include/cudnn_frontend/node/diagonal_band_mask.h
+23 −2 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+38 −5 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −7 include/cudnn_frontend/node/softmax.h
+202 −192 include/cudnn_frontend/plans.h
+1 −1 include/cudnn_frontend_version.h
+1 −0 python/cudnn/README.md
+25 −1 python/cudnn/__init__.py
+137 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
+207 −173 ...cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/discrete_B_blockscaled_grouped_gemm_dglu_dbias.py
+146 −61 python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
+241 −128 ...on/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/discrete_B_blockscaled_grouped_gemm_glu_bias.py
+37 −8 python/cudnn/discrete_grouped_gemm/discrete_kernel_utils.py
+3 −0 python/cudnn/experimental/__init__.py
+3 −0 python/cudnn/experimental/ops/__init__.py
+1,079 −0 python/cudnn/experimental/ops/sdpa.py
+189 −412 python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
+0 −4,427 python/cudnn/grouped_gemm/grouped_gemm_dglu/continugous_blockscaled_grouped_gemm_dglu_quant_dbias_fusion.py
+159 −97 python/cudnn/grouped_gemm/grouped_gemm_dglu/moe_blockscaled_grouped_gemm_dglu_dbias.py
+4 −2 python/cudnn/grouped_gemm/grouped_gemm_dswiglu/grouped_gemm_dswiglu_quant.py
+202 −403 python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
+0 −3,713 python/cudnn/grouped_gemm/grouped_gemm_glu/continugous_blockscaled_grouped_gemm_glu_quant_bias_fusion.py
+218 −90 python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
+349 −60 python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
+10 −5 python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
+6 −4 python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py
+36 −7 python/cudnn/grouped_gemm/moe_kernel_helpers.py
+12 −0 python/cudnn/sdpa/__init__.py
+581 −0 python/cudnn/sdpa/api.py
+438 −0 python/cudnn/sdpa/fmha_backward_sm100_2kernel.py
+3,016 −0 python/cudnn/sdpa/fmha_dkdv_d256_sm100.py
+1,968 −0 python/cudnn/sdpa/fmha_dq_d256_sm100.py
+1,143 −0 python/cudnn/sdpa/fmha_utils.py
+784 −0 python/cudnn/sdpa/utils.py
+24 −0 python/cudnn/wrapper.py
+47 −0 python/pygraph/pygraph.cpp
+23 −2 python/pygraph/pygraph.h
+10 −4 python/pygraph/sdpa.cpp
+2 −4 samples/cpp/misc/serialization.cpp
+2 −2 samples/cpp/sdpa/fp16_fwd_with_max_and_sum_exp.cpp
+2 −1 samples/legacy_samples/fp8_flash_mha_sample.cpp
+2 −2 samples/legacy_samples/fp8_flash_mha_sample.h
+1 −1 samples/legacy_samples/test_list.cpp
+4 −4 test/cpp/tensor.cpp
+9 −1 test/python/conftest.py
+152 −0 test/python/fe_api/test_discrete_grouped_gemm_dswiglu.py
+201 −7 test/python/fe_api/test_discrete_grouped_gemm_dswiglu_utils.py
+148 −0 test/python/fe_api/test_discrete_grouped_gemm_swiglu.py
+15 −1 test/python/fe_api/test_discrete_grouped_gemm_swiglu_utils.py
+3 −0 test/python/fe_api/test_fe_api_utils.py
+384 −0 test/python/fe_api/test_grouped_gemm_dglu.py
+19 −8 test/python/fe_api/test_grouped_gemm_dswiglu_utils.py
+389 −0 test/python/fe_api/test_grouped_gemm_glu.py
+391 −0 test/python/fe_api/test_grouped_gemm_quant.py
+45 −22 test/python/fe_api/test_grouped_gemm_quant_utils.py
+28 −12 test/python/fe_api/test_grouped_gemm_swiglu_utils.py
+157 −0 test/python/fe_api/test_sdpa_bwd.py
+352 −0 test/python/fe_api/test_sdpa_bwd_utils.py
+1 −0 test/python/sdpa/fp16.py
+6 −2 test/python/sdpa/fp8.py
+11 −9 test/python/sdpa/mxfp8.py
+4 −1 test/python/sdpa/mxfp8_ref.py
+1 −0 test/python/sdpa/random_config.py
+579 −0 test/python/test_cudnn_sdpa_op.py
+32 −6 test/python/test_mhas_v2.py
+107 −0 test/python/test_sdpa_fp8_serialization.py
+7 −1 tools/cudnn_repro/README.md
+13 −34 tools/cudnn_repro/cudnn_repro/__main__.py
+44 −0 tools/cudnn_repro/cudnn_repro/repro_command.py
+55 −0 tools/cudnn_repro/cudnn_repro/routing.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage1_annotate.py
+67 −15 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py
+168 −0 tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py
+2 −7 tools/cudnn_repro/cudnn_repro/stage2_build_repro.py
+4 −32 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_bwd.py
+26 −0 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fp8_fwd.py
+4 −31 tools/cudnn_repro/cudnn_repro/stage2_build_repro_sdpa_fwd.py
+61 −0 tools/cudnn_repro/cudnn_repro/utils.py
+172 −0 tools/cudnn_repro/tests/test_cudnn_repro_bwd.py
+90 −0 tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py
+229 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8.py
+25 −0 tools/cudnn_repro/tests/test_cudnn_repro_fp8_closed_loop.py
+94 −0 tools/cudnn_repro/tests/test_cudnn_repro_schema.py
45 changes: 37 additions & 8 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -19,8 +19,14 @@
DotProductAttention,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
)
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
MXFP8BlockScaling,
Format,
)
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
from utils import ModelConfig, compare_and_assert

dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
@@ -180,6 +186,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
deterministic="False",
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
@@ -188,11 +195,15 @@
is_training = is_training == "True"

# set up environment variables and config
if deterministic == "True":
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
else:
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
fp8_mha = fp8_mha == "True" and dtype == "fp8"
f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
fp8_mha = fp8_mha == "True" and dtype == "fp8" and scaling_mode != "mxfp8"
f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True"
os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -247,6 +258,8 @@ def run_dpa_with_cp(
fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
if scaling_mode == "current":
fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
if scaling_mode == "mxfp8":
fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)

# instantiate attention module
core_attn = DotProductAttention(
@@ -302,10 +315,25 @@ def run_dpa_with_cp(
fp8_dtype=tex.DType.kFloat8E5M2,
device="cuda",
)
if scaling_mode == "mxfp8":
qkv_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
)
qkv_quantizer.optimize_for_gemm = True
qkv_quantizer.internal = False
dout_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
)
dout_quantizer.optimize_for_gemm = True
dout_quantizer.internal = False
qkv_layout = "_".join([qkv_format] * 3)
q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
if fp8_mha:
q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
for x in [q, k, v]:
x.requires_grad = True

@@ -413,7 +441,7 @@ def run_dpa_with_cp(
dout_quantizer.scale.fill_(1.0)
dout_quantizer.amax.fill_(0.0)
if fp8_mha:
q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
if is_training:
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
@@ -494,6 +522,7 @@ def run_dpa_with_cp(

# get outputs
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
@@ -502,11 +531,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
for i, tensor in enumerate(tensors):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN"
assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
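The hunks above route a new "mxfp8" scaling mode through the CP test harness: MXFP8BlockScaling as the recipe, MXFP8Quantizer pairs for QKV and dO, fp8_mha forced off for mxfp8, and the f16_O escape hatch extended to MXFP8. As a reference point, here is a minimal non-CP sketch of the recipe path the test exercises; the shapes and the causal mask are illustrative, and the sketch assumes a Transformer Engine build that accepts MXFP8BlockScaling for DPA (which is exactly what this PR adds):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import DotProductAttention
from transformer_engine.common.recipe import MXFP8BlockScaling, Format

# Recipe as constructed in the test: E4M3 MXFP8 for attention compute
# (fp8_dpa=True), F16 module inputs/outputs (fp8_mha=False, which the
# test forces for the mxfp8 scaling mode).
fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True, fp8_mha=False)

dpa = DotProductAttention(16, 128, qkv_format="bshd").to(dtype=torch.bfloat16, device="cuda")
q, k, v = [
    torch.randn(2, 2048, 16, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    for _ in range(3)
]
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = dpa(q, k, v, attn_mask_type="causal")
out.backward(torch.randn_like(out))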
101 changes: 77 additions & 24 deletions tests/pytorch/attention/test_attention.py
@@ -1803,20 +1803,45 @@ def get_model(dtype, config):
return outputs


attn_mask_type = "causal"
model_configs_fp8_vs_f16 = {
# test: ModelConfig(b, sq, hq, dqk)
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
"fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
"fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
"fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
"fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
"fp8_9": ModelConfig(
2,
4096,
128,
192,
head_dim_v=128,
),
"fp8_10": ModelConfig(
2,
4096,
128,
192,
head_dim_v=128,
attn_mask_type="causal",
),
"fp8_11": ModelConfig(
2,
4096,
128,
192,
head_dim_v=128,
attn_mask_type="causal_bottom_right",
),
"fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)),
"fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
"fp8_16": ModelConfig(
2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
),
"fp8_17": ModelConfig(
2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable"
),
"fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
@@ -1833,7 +1858,7 @@ def get_model(dtype, config):
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
def test_mha_fp8_vs_f16(
dtype,
model,
@@ -1864,6 +1889,12 @@ def test_mha_fp8_vs_f16(
fp8_dpa=True,
fp8_mha=True,
)
elif scaling_mode == "mxfp8":
fp8_recipe = recipe.MXFP8BlockScaling(
fp8_format=recipe.Format.E4M3,
fp8_dpa=True,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, _ = get_available_attention_backends(
@@ -2083,7 +2114,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
"""Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]
@@ -2115,6 +2146,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
fp8_format=recipe.Format.HYBRID,
fp8_dpa=True,
)
elif scaling_mode == "mxfp8":
fp8_recipe = recipe.MXFP8BlockScaling(
fp8_format=recipe.Format.E4M3,
fp8_dpa=True,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
available_backends, _, _ = get_available_attention_backends(
@@ -2186,7 +2223,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
bwd_names = ["dq", "dk", "dv", "d_softmax_offset"]
if flash_attn_supported and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
@@ -2275,7 +2312,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
with quantized_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim_qk,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
Expand All @@ -2285,6 +2322,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
softmax_type=config.softmax_type,
).to(dtype=dtype, device="cuda")
if not is_training:
dpa = dpa.eval()
@@ -2320,7 +2358,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim_qk,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
@@ -2336,6 +2375,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
if i == 2:
layout = layout.replace("d", "dv")
else:
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
if config.dropout_p == 0.0:
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
@@ -2360,6 +2403,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:

qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
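The dqk/dv bookkeeping in the two hunks above is what lets these tests cover MLA-style configs where the Q/K head dimension differs from the V head dimension (e.g. fp8_9 with dqk=192, dv=128). A small self-contained sketch of the same derivation, runnable on its own; the helper name and the layout string are hypothetical, and the values are illustrative:

# Hypothetical helper mirroring the test's per-tensor layout expansion.
dim_to_num = {
    "b": 2, "sq": 4096, "skv": 4096, "h": 128, "hg": 128,
    "dqk": 192, "dv": 128, "3": 3, "2": 2, "1": 1,
}

def shape_for(layout: str, i: int) -> list[int]:
    """Expand an underscore-separated layout for tensor i (0=q, 1=k, 2=v):
    k/v use the KV sequence length and GQA group count, and only v uses
    the V head dimension."""
    if i > 0:
        layout = layout.replace("s", "skv").replace("h", "hg")
    else:
        layout = layout.replace("s", "sq")
    layout = layout.replace("d", "dv" if i == 2 else "dqk")
    return [dim_to_num[tok] for tok in layout.split("_")]

print(shape_for("b_s_h_d", 0))  # q: [2, 4096, 128, 192]
print(shape_for("b_s_h_d", 2))  # v: [2, 4096, 128, 128]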
@@ -2370,21 +2414,24 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
inp[1],
inp[2],
qkv_format=qkv_format,
window_size=config.window_size,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
fp8_output=fp8_dpa,
)
if is_training:
out.backward(out_grad)
d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = dpa.softmax_offset.grad

if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
return out, (None, None, None)
return out, (inp[0].grad, inp[1].grad, inp[2].grad, d_softmax_offset)
return out, (None, None, None, d_softmax_offset)


model_configs_fp8 = {
Expand Down Expand Up @@ -2636,6 +2683,8 @@ def forward(
quantization_params=qkv_quantizer,
use_split_accumulator=_2X_ACC_FPROP,
)
qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
qkv = qkv.view(-1, 3, h, d)
qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
torch.save(qkv_fp16, "qkv.pt")
@@ -2664,7 +2713,8 @@ def forward(
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
qkv_layout=qkv_layout,
o_format=o_format,
attn_bias_type="no_bias",
attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
rng_gen=None,
@@ -2687,6 +2737,8 @@ def forward(
ctx.num_heads = num_heads
ctx.mask_type = mask_type
ctx.dtype = inp.dtype
ctx.qkv_layout = qkv_layout
ctx.o_format = o_format

ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer
@@ -2704,7 +2756,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx)

proj_dgrad = ctx.dO_quantizer(grad_output)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
@@ -2717,7 +2768,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
out,
proj_dgrad.view_as(out),
ctx.qkv_dtype,
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None,
@@ -2728,7 +2778,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
qkv_layout=ctx.qkv_layout,
o_format=ctx.o_format,
do_format=ctx.o_format,
dqkv_layout=ctx.qkv_layout,
attn_bias_type="no_bias",
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
)
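The backward hunks above replace the hard-coded "bs3hd"/"t3hd" strings and the removed fp8_dtype_backward argument with layout/format strings derived once in forward and stashed on the autograd context. A toy, runnable sketch of just that plumbing; the stub functions stand in for fused_attn_fwd/fused_attn_bwd, which take many more arguments than shown here:

import torch

def fused_attn_fwd_stub(qkv, *, qkv_layout, o_format):
    return qkv * 1.0  # placeholder for the real fused forward kernel

def fused_attn_bwd_stub(grad, *, qkv_layout, o_format, do_format, dqkv_layout):
    return grad  # placeholder gradient

class FusedAttnSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cudnn_frontend_version=1):
        qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
        o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
        out = fused_attn_fwd_stub(qkv, qkv_layout=qkv_layout, o_format=o_format)
        # Stash the strings so backward reuses them instead of re-deriving them.
        ctx.qkv_layout, ctx.o_format = qkv_layout, o_format
        return out

    @staticmethod
    def backward(ctx, grad_output):
        dqkv = fused_attn_bwd_stub(
            grad_output,
            qkv_layout=ctx.qkv_layout,
            o_format=ctx.o_format,
            do_format=ctx.o_format,      # dO arrives in O's format
            dqkv_layout=ctx.qkv_layout,  # dQKV mirrors the QKV layout
        )
        return dqkv, None

x = torch.randn(2, 3, 16, 64, requires_grad=True)
FusedAttnSketch.apply(x).sum().backward()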