diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..97f6cb3b88 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..cda5c42d50 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -19,8 +19,14 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, +) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -180,6 +186,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" @@ -188,11 +195,15 @@ def run_dpa_with_cp( is_training = is_training == "True" # set up environment variables and config + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" - fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + fp8_mha = fp8_mha == "True" and dtype == "fp8" and scaling_mode != "mxfp8" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -247,6 +258,8 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -302,10 +315,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -413,7 +441,7 @@ def run_dpa_with_cp( dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: @@ -494,6 +522,7 @@ def run_dpa_with_cp( # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): @@ -502,11 +531,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2eb307aa48..58f5ebb7bb 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1803,20 +1803,45 @@ def get_model(dtype, config): return outputs +attn_mask_type = "causal" model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_9": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + ), + "fp8_10": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal", + ), + "fp8_11": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type="causal_bottom_right", + ), + "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_16": ModelConfig( + 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "fp8_17": ModelConfig( + 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), + "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1833,7 +1858,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, @@ -1864,6 +1889,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2083,7 +2114,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -2115,6 +2146,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2186,7 +2223,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 - bwd_names = ["dq", "dk", "dv"] + bwd_names = ["dq", "dk", "dv", "d_softmax_offset"] if flash_attn_supported and fused_attn_supported_f16: logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) logging.debug("========== {:^25s} ==========".format("forward output")) @@ -2275,7 +2312,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2285,6 +2322,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type="self", qkv_format=qkv_format, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: dpa = dpa.eval() @@ -2320,7 +2358,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2336,6 +2375,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2360,6 +2403,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2370,6 +2414,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, @@ -2377,14 +2422,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = dpa.softmax_offset.grad if is_training: - return out, (inp[0].grad, inp[1].grad, inp[2].grad) - return out, (None, None, None) + return out, (inp[0].grad, inp[1].grad, inp[2].grad, d_softmax_offset) + return out, (None, None, None, d_softmax_offset) model_configs_fp8 = { @@ -2636,6 +2683,8 @@ def forward( quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, ) + qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd" + o_format = "bshd" if cudnn_frontend_version == 1 else "thd" qkv = qkv.view(-1, 3, h, d) qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") @@ -2664,7 +2713,8 @@ def forward( attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=qkv_layout, + o_format=o_format, attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, @@ -2687,6 +2737,8 @@ def forward( ctx.num_heads = num_heads ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer @@ -2704,7 +2756,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_func_ctx(ctx) proj_dgrad = ctx.dO_quantizer(grad_output) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, @@ -2717,7 +2768,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out, proj_dgrad.view_as(out), ctx.qkv_dtype, - fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, @@ -2728,7 +2778,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, - qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + do_format=ctx.o_format, + dqkv_layout=ctx.qkv_layout, attn_bias_type="no_bias", attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5aaf67061b..6536da78b9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -26,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -39,13 +47,11 @@ "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA - "cp_2_2": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) - ), # GQA + "cp_2_2": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA - "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_0": ModelConfig(2, 4096, 128, 192, attn_mask_type="causal", head_dim_v=128), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 @@ -73,10 +79,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] + configs = ["cp_2_0", "cp_2_2", "cp_3_0"] # ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] - qkv_formats = ["sbhd", "thd"] + # qkv_formats = ["sbhd", "thd"] @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @@ -94,25 +100,34 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -150,8 +165,22 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_0": ModelConfig( + 2, + 4096, + 32, + 128, + num_gqa_groups=4, + attn_mask_type="causal", + ), # GQA + "cp_2_1": ModelConfig( + 2, + 4096, + 32, + 128, + attn_mask_type="causal", + window_size=(128, 0), + ), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -189,7 +218,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal"), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -206,6 +235,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_4_2": ModelConfig( 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), # GQA + "cp_4_3": ModelConfig( + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + ), # GQA } @@ -214,21 +246,24 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", - "cp_1_5", + # "cp_1_0", + # "cp_1_1", + # "cp_1_4", + # "cp_1_5", "cp_2_0", - "cp_2_2", - "cp_2_3", - "cp_2_4", - "cp_3_2", - "cp_3_4", + "cp_2_1", + # "cp_2_2", + # "cp_2_3", + # "cp_2_4", + "cp_3_1", + # "cp_3_2", + # "cp_3_4", "cp_4_2", + "cp_4_3", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["sbhd", "thd"] + # qkv_formats = ["sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -240,96 +275,81 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") + + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" - ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" - ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") - if ( - dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No supprt for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") - if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" + config.softmax_type != "vanilla" and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" - ) + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") + if scaling_mode == "mxfp8" and fp8_mha: + pytest.skip("No support for scaling_mode=mxfp8 with fp8_mha=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -353,6 +373,12 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=True), + ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -362,6 +388,7 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -381,6 +408,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), ) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 929f02453d..c7e2dff477 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -164,6 +164,10 @@ def reset_rng_states() -> None: def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if a is None and b is None: + logging.debug(f"{name_a} vs {name_b}: both are None") + return + if not is_fp8: torch.testing.assert_close(a, b, atol=atol, rtol=rtol) return diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 6c66746e62..97e9a620ba 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -4,18 +4,115 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include + #include "../common.h" +#include "../util/cuda_driver.h" +#include "../util/ptx.cuh" +#include "../utils.cuh" #include "transformer_engine/fused_attn.h" namespace transformer_engine { namespace flash_attention { +/// Packed vector of N elements of T; alignment matches a single wide load/store of N * sizeof(T) bytes. +template +struct alignas(sizeof(T) * N) Vec { + T data[N]; +}; + constexpr int warp_size = 32; constexpr int type_size = 2; // FP16 or BF16 constexpr int nvec = sizeof(uint64_t) / type_size; +constexpr int nvec128 = sizeof(uint4) / type_size; constexpr int load_size = warp_size * nvec; constexpr int block_size = 512; +// TMA permute kernel configuration +constexpr int tma_permute_threads = 32; +constexpr int tma_permute_s_tile = 32; + +// ---- 4D TMA PTX wrappers ---- + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_global_to_shared( + void *dst_shmem, const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, + uint32_t c3, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t dst = __cvta_generic_to_shared(dst_shmem); + uint32_t bar = __cvta_generic_to_shared(mbar); + asm volatile( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3, %4, %5}], [%6];" ::"r"(dst), + "l"(tensor_map), "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(bar) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_global_to_shared requires SM 10.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_tensor_4d_shared_to_global( + const CUtensorMap *tensor_map, uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3, + void *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t src = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + " [%0, {%1, %2, %3, %4}], [%5];" ::"l"(tensor_map), + "r"(c0), "r"(c1), "r"(c2), "r"(c3), "r"(src) + : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_4d_shared_to_global requires SM 9.0+."); +#endif +} + +// ---- Host-side 4D tensor map creation ---- +// +// Creates a 4D TMA descriptor for a densely-packed tensor whose logical +// dimensions (innermost-first) are [dim0, dim1, dim2, dim3]. +// +// The box (tile) copied per TMA instruction is [box0, box1, box2, box3]. + +static void create_4D_tensor_map(CUtensorMap &tensorMap, void *dataPtr, DType dtype, uint64_t dim0, + uint64_t dim1, uint64_t dim2, uint64_t dim3, uint32_t box0, + uint32_t box1, uint32_t box2, uint32_t box3) { + cuda_driver::ensure_context_exists(); + static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { + void *ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(ptr); + }(); + + CUtensorMapDataType tma_dtype; + size_t elem_bytes; + switch (dtype) { + case DType::kFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bytes = 2; + break; + case DType::kBFloat16: + tma_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bytes = 2; + break; + default: + NVTE_ERROR("create_4D_tensor_map: unsupported dtype"); + } + + constexpr uint32_t rank = 4; + uint64_t size[rank] = {dim0, dim1, dim2, dim3}; + uint64_t stride[rank - 1] = { + dim0 * elem_bytes, + dim0 * dim1 * elem_bytes, + dim0 * dim1 * dim2 * elem_bytes, + }; + uint32_t boxSize[rank] = {box0, box1, box2, box3}; + uint32_t elemStride[rank] = {1, 1, 1, 1}; + + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, tma_dtype, rank, dataPtr, size, stride, boxSize, elemStride, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA)); +} + template __launch_bounds__(block_size) __global__ void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z, @@ -35,8 +132,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size); - *out = *reinterpret_cast(my_input + i * load_size * 3); + Vec *const out = reinterpret_cast *>(my_output + i * load_size); + *out = *reinterpret_cast *>(my_input + i * load_size * 3); } } @@ -61,8 +158,8 @@ __launch_bounds__(block_size) __global__ T *my_output = qkv + offset_output; for (int i = 0; i < Z; ++i) { - uint64_t *out = reinterpret_cast(my_output + i * load_size * 3); - *out = *reinterpret_cast(my_input + i * load_size); + Vec *const out = reinterpret_cast *>(my_output + i * load_size * 3); + *out = *reinterpret_cast *>(my_input + i * load_size); } } @@ -133,6 +230,327 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream NVTE_CHECK_CUDA(cudaGetLastError()); } +// ---- TMA helpers for strided (BSHD/SBHD) tensors ---- +// +// Strided BSHD: TMA dims [D, H, S, B], coords [0, h, s, b] +// Strided SBHD: TMA dims [D, H, B, S], coords [0, h, b, s] + +template +__device__ __forceinline__ void issue_tma_load_strided(T *smem_buf, const CUtensorMap *tma, + size_t h_i, size_t s_tile, size_t b_i, + uint64_t *mbar, size_t tile_bytes) { + ptx::mbarrier_arrive_expect_tx(mbar, static_cast(tile_bytes)); + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), mbar); + } else { + cp_async_bulk_tensor_4d_global_to_shared(smem_buf, tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), mbar); + } +} + +template +__device__ __forceinline__ void issue_tma_store_strided(const CUtensorMap *tma, T *smem_buf, + size_t h_i, size_t s_tile, size_t b_i) { + if constexpr (kIsBshdBshdBshd) { + cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(s_tile), + static_cast(b_i), smem_buf); + } else { + cp_async_bulk_tensor_4d_shared_to_global(tma, 0, static_cast(h_i), + static_cast(b_i), + static_cast(s_tile), smem_buf); + } + ptx::cp_async_bulk_commit_group(); +} + +__device__ __forceinline__ void st_global_cs_uint4(uint4 *ptr, uint4 val) { + asm volatile("st.global.cs.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(val.x), "r"(val.y), + "r"(val.z), "r"(val.w) + : "memory"); +} + +// ---- Forward: BSHD/SBHD → BHSD ---- +// +// TMA load from strided input → smem → non-temporal stores to contiguous output. + +template +__launch_bounds__(tma_permute_threads) __global__ + void permute_to_grouped_tensor_fwd_kernel(const __grid_constant__ CUtensorMap tma_q_in, + const __grid_constant__ CUtensorMap tma_k_in, + const __grid_constant__ CUtensorMap tma_v_in, + T *__restrict__ q_out, T *__restrict__ k_out, + T *__restrict__ v_out, size_t b, size_t s_q, + size_t h_q, size_t d_qk, size_t s_kv, size_t h_kv, + size_t d_v, unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; + const CUtensorMap *tma_in = which == 0 ? &tma_q_in : (which == 1 ? &tma_k_in : &tma_v_in); + T *__restrict__ tensor_out = which == 0 ? q_out : (which == 1 ? k_out : v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + + if (b_i >= b) return; + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + + const size_t out_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + __shared__ __align__(8) uint64_t mbar; + const bool is_leader = (threadIdx.x == 0); + + if (is_leader) { + ptx::mbarrier_init(&mbar, static_cast(blockDim.x)); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + constexpr size_t S_TILE = tma_permute_s_tile; + const uint32_t tile_bytes = static_cast(S_TILE * Ddim * sizeof(T)); + int parity = 0; + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + if (is_leader) { + issue_tma_load_strided(smem, tma_in, h_i, s_tile, b_i, &mbar, tile_bytes); + } else { + ptx::mbarrier_arrive(&mbar); + } + + ptx::mbarrier_wait_parity(&mbar, parity); + parity ^= 1; + + T *__restrict__ out_ptr = tensor_out + out_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + uint4 v = *reinterpret_cast(smem + i); + st_global_cs_uint4(reinterpret_cast(out_ptr + i), v); + } + + __syncthreads(); + } + + if (is_leader) { + ptx::mbarrier_invalid(&mbar); + } +#endif +} + +// ---- Backward: BHSD → BSHD/SBHD ---- +// +// Vectorized loads from contiguous input → smem → TMA store to strided output. + +template +__launch_bounds__(tma_permute_threads) __global__ void permute_to_grouped_tensor_bwd_kernel( + const T *__restrict__ grad_q, const T *__restrict__ grad_k, const T *__restrict__ grad_v, + const __grid_constant__ CUtensorMap tma_q_out, const __grid_constant__ CUtensorMap tma_k_out, + const __grid_constant__ CUtensorMap tma_v_out, size_t b, size_t s_q, size_t h_q, size_t d_qk, + size_t s_kv, size_t h_kv, size_t d_v, unsigned int permute_s_splits) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const int which = blockIdx.z; + const T *__restrict__ tensor_in = which == 0 ? grad_q : (which == 1 ? grad_k : grad_v); + const CUtensorMap *tma_out = which == 0 ? &tma_q_out : (which == 1 ? &tma_k_out : &tma_v_out); + const size_t Sdim = which == 0 ? s_q : s_kv; + const size_t Hdim = which == 0 ? h_q : h_kv; + const size_t Ddim = which == 0 ? d_qk : (which == 1 ? d_qk : d_v); + + const size_t h_grid = h_q > h_kv ? h_q : h_kv; + const size_t b_i = static_cast(blockIdx.x) / h_grid; + const size_t h_i = static_cast(blockIdx.x) % h_grid; + + if (b_i >= b) return; + if (which == 0) { + if (h_i >= h_q) return; + } else { + if (h_i >= h_kv) return; + } + + const unsigned int s_part = blockIdx.y; + const size_t s_begin = + (Sdim * static_cast(s_part)) / static_cast(permute_s_splits); + const size_t s_end = + (Sdim * static_cast(s_part + 1)) / static_cast(permute_s_splits); + if (s_begin >= s_end) return; + + const size_t in_base = b_i * Hdim * Sdim * Ddim + h_i * Sdim * Ddim; + + extern __shared__ __align__(128) char smem_raw[]; + T *smem = reinterpret_cast(smem_raw); + + constexpr size_t S_TILE = tma_permute_s_tile; + constexpr size_t vec_elems = sizeof(uint4) / sizeof(T); + + for (size_t s_tile = s_begin; s_tile < s_end; s_tile += S_TILE) { + const size_t tile_rows = min(S_TILE, s_end - s_tile); + + const T *__restrict__ in_ptr = tensor_in + in_base + s_tile * Ddim; + const size_t total_elems = tile_rows * Ddim; + + for (size_t i = threadIdx.x * vec_elems; i < total_elems; + i += static_cast(blockDim.x) * vec_elems) { + *reinterpret_cast(smem + i) = *reinterpret_cast(in_ptr + i); + } + + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (threadIdx.x == 0) { + issue_tma_store_strided(tma_out, smem, h_i, s_tile, b_i); + } + + ptx::cp_async_bulk_wait_group(); + __syncthreads(); + } +#endif +} + +// Helper: create a 4D TMA descriptor for the strided (BSHD or SBHD) tensor. +// +// For BSHD [B, S, H, D]: TMA dims [D, H, S, B], box [D, 1, S_TILE, 1] +// For SBHD [S, B, H, D]: TMA dims [D, H, B, S], box [D, 1, 1, S_TILE] +static void create_strided_tensor_map(CUtensorMap &map, void *ptr, DType dtype, size_t b, size_t s, + size_t h, size_t d, bool is_bshd) { + if (is_bshd) { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(s), static_cast(b), + static_cast(d), 1, static_cast(tma_permute_s_tile), 1); + } else { + create_4D_tensor_map(map, ptr, dtype, static_cast(d), static_cast(h), + static_cast(b), static_cast(s), + static_cast(d), 1, 1, static_cast(tma_permute_s_tile)); + } +} + +void permute_to_grouped_tensor_fwd(Tensor q, Tensor k, Tensor v, Tensor q_out, Tensor k_out, + Tensor v_out, NVTE_QKV_Layout original_layout, + cudaStream_t stream) { + using namespace transformer_engine; + const size_t b = q_out.shape()[0]; + const size_t h_q = q_out.shape()[1]; + const size_t s_q = q_out.shape()[2]; + const size_t d_qk = q_out.shape()[3]; + const size_t h_kv = k_out.shape()[1]; + const size_t s_kv = k_out.shape()[2]; + const size_t d_v = v_out.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_fwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_in{}, tma_k_in{}, tma_v_in{}; + create_strided_tensor_map(tma_q_in, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); + create_strided_tensor_map(tma_k_in, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_in, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_fwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + tma_q_in, tma_k_in, tma_v_in, reinterpret_cast(q_out.data.dptr), + reinterpret_cast(k_out.data.dptr), reinterpret_cast(v_out.data.dptr), + b, s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void permute_to_grouped_tensor_bwd(Tensor grad_q, Tensor grad_k, Tensor grad_v, Tensor q, Tensor k, + Tensor v, NVTE_QKV_Layout original_layout, cudaStream_t stream) { + using namespace transformer_engine; + const size_t b = grad_q.shape()[0]; + const size_t h_q = grad_q.shape()[1]; + const size_t s_q = grad_q.shape()[2]; + const size_t d_qk = grad_q.shape()[3]; + const size_t h_kv = grad_k.shape()[1]; + const size_t s_kv = grad_k.shape()[2]; + const size_t d_v = grad_v.shape()[3]; + + NVTE_CHECK(d_qk % nvec128 == 0 && d_v % nvec128 == 0, + "permute_to_grouped_tensor_bwd: head dim must be divisible by ", nvec128, "."); + + const bool is_bshd = (original_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + + alignas(64) CUtensorMap tma_q_out{}, tma_k_out{}, tma_v_out{}; + create_strided_tensor_map(tma_q_out, q.data.dptr, q.dtype(), b, s_q, h_q, d_qk, is_bshd); + create_strided_tensor_map(tma_k_out, k.data.dptr, k.dtype(), b, s_kv, h_kv, d_qk, is_bshd); + create_strided_tensor_map(tma_v_out, v.data.dptr, v.dtype(), b, s_kv, h_kv, d_v, is_bshd); + + const size_t s_min = std::min(s_q, s_kv); + const unsigned int permute_s_splits = + std::max(1u, static_cast(s_min / static_cast(tma_permute_threads))); + const size_t h_grid = std::max(h_q, h_kv); + dim3 grid(static_cast(b * h_grid), permute_s_splits, 3); + + const size_t d_max = std::max(d_qk, d_v); + const size_t smem_bytes = tma_permute_s_tile * d_max * sizeof(uint16_t); + + if (is_bshd) { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_q.dtype(), dtype, auto kernel = permute_to_grouped_tensor_bwd_kernel; + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + kernel<<>>( + reinterpret_cast(grad_q.data.dptr), + reinterpret_cast(grad_k.data.dptr), + reinterpret_cast(grad_v.data.dptr), tma_q_out, tma_k_out, tma_v_out, b, + s_q, h_q, d_qk, s_kv, h_kv, d_v, permute_s_splits);); + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} } // namespace flash_attention } // namespace transformer_engine @@ -153,3 +571,27 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv), stream); } + +void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, + NVTETensor k_out, NVTETensor v_out, + NVTE_QKV_Layout original_layout, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute_to_grouped_tensor_fwd); + using namespace transformer_engine; + + flash_attention::permute_to_grouped_tensor_fwd( + *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), *convertNVTETensorCheck(v), + *convertNVTETensorCheck(q_out), *convertNVTETensorCheck(k_out), + *convertNVTETensorCheck(v_out), original_layout, stream); +} + +void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, + NVTETensor q, NVTETensor k, NVTETensor v, + NVTE_QKV_Layout original_layout, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute_to_grouped_tensor_bwd); + using namespace transformer_engine; + + flash_attention::permute_to_grouped_tensor_bwd( + *convertNVTETensorCheck(grad_q), *convertNVTETensorCheck(grad_k), + *convertNVTETensorCheck(grad_v), *convertNVTETensorCheck(q), *convertNVTETensorCheck(k), + *convertNVTETensorCheck(v), original_layout, stream); +} diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3d6e3a0aac..5498c601a6 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -131,6 +131,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_layout_group."); @@ -172,6 +174,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_format."); @@ -192,6 +196,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_q_format."); @@ -212,12 +218,93 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_kv_format."); } } +// map one NVTE_QKV_Format to another +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t) { + size_t b_tmp = 0, h_tmp = 0, s_tmp = 0, d_tmp = 0, t_tmp = 0; + switch (src_format) { + case NVTE_QKV_Format::NVTE_BSHD: + b_tmp = src_shape[0]; + s_tmp = src_shape[1]; + h_tmp = src_shape[2]; + d_tmp = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_SBHD: + s_tmp = src_shape[0]; + b_tmp = src_shape[1]; + h_tmp = src_shape[2]; + d_tmp = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_BHSD: + b_tmp = src_shape[0]; + h_tmp = src_shape[1]; + s_tmp = src_shape[2]; + d_tmp = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_THD: + t_tmp = src_shape[0]; + h_tmp = src_shape[1]; + d_tmp = src_shape[2]; + break; + default: + NVTE_ERROR("src_format not supported!"); + break; + } + switch (dst_format) { + case NVTE_QKV_Format::NVTE_BSHD: + dst_shape[0] = b_tmp; + dst_shape[1] = s_tmp; + dst_shape[2] = h_tmp; + dst_shape[3] = d_tmp; + break; + case NVTE_QKV_Format::NVTE_SBHD: + dst_shape[0] = s_tmp; + dst_shape[1] = b_tmp; + dst_shape[2] = h_tmp; + dst_shape[3] = d_tmp; + break; + case NVTE_QKV_Format::NVTE_BHSD: + dst_shape[0] = b_tmp; + dst_shape[1] = h_tmp; + dst_shape[2] = s_tmp; + dst_shape[3] = d_tmp; + break; + case NVTE_QKV_Format::NVTE_THD: + dst_shape[0] = t_tmp; + dst_shape[1] = h_tmp; + dst_shape[2] = d_tmp; + break; + default: + NVTE_ERROR("dst_format not supported!"); + break; + } + + if (b != nullptr) { + *b = b_tmp; + } + if (h != nullptr) { + *h = h_tmp; + } + if (s != nullptr) { + *s = s_tmp; + } + if (d != nullptr) { + *d = d_tmp; + } + if (t != nullptr) { + *t = t_tmp; + } +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, @@ -269,9 +356,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: d_qk=192, d_v=128 + (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && + head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && + // pre-9.21: {bshd, sbhd}, {vanilla} + // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} + ((cudnn_runtime_version < 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || + (cudnn_runtime_version >= 92100 && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && + !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { @@ -410,12 +510,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -556,19 +659,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -587,23 +688,27 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + if (input_V->scaling_mode != NVTE_MXFP8_1D_SCALING) { + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); + } else { + nvte_convert_qkv_format(kv_format, input_V->columnwise_data.shape, kv_format, tmp_shape, &b, + &h_kv, &s_kv, &d_v, &t_kv); } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + if (q_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } + int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -655,9 +760,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, - return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + return_max_logit, attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, + input_V, input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else @@ -667,10 +772,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + input_Q, input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -687,11 +794,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -712,22 +820,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); @@ -763,11 +869,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, - input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, input_Q, input_K, + input_V, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, + output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -776,14 +883,28 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, - output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + size_t i = 0; + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_ZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_SoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_dO_f16 = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, + qkv_layout, o_format, do_format, dqkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, + input_M, input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, + output_dQ, output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eed6740740..f8c3992587 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -54,11 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -80,8 +80,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -89,7 +89,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -135,7 +135,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, bias_type, mask_type, softmax_type, @@ -202,17 +205,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); if (is_paged_kv) { generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), - layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } else { - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); } @@ -368,7 +371,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); if (is_ragged_q) { @@ -513,7 +516,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, @@ -551,7 +554,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, @@ -578,8 +582,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); - NVTE_QKV_Format q_format = nvte_get_q_format(layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -587,7 +591,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const int sm_arch_ = cuda::sm_arch(device_id); bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); if (is_paged_kv) { NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); @@ -632,7 +636,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( scaling_factor, true, dropout_probability, - layout, + qkv_layout, + o_format, + do_format, + dqkv_layout, bias_type, mask_type, softmax_type, @@ -703,13 +710,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), qkv_layout, NVTE_QKV_Matrix::NVTE_O_Matrix); q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1024,7 +1031,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; } - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, @@ -1067,13 +1074,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1202,12 +1210,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, - is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, + devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1228,6 +1236,7 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -1300,12 +1309,12 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, - devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, + devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4dd7f3d1da..19dc94e755 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -24,18 +24,20 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 80e64370f9..6fa366dc2c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,16 +1652,20 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, + void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1669,19 +1673,25 @@ void fused_attn_fp8_fwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); try { FADescriptor_v1 descriptor{b, @@ -1689,8 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1704,13 +1714,16 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, + o_format, + NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout_NOT_SET, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1736,6 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_o std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -1762,31 +1776,28 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr descale_q, descale_k, descale_v; std::shared_ptr descale_s, scale_s, scale_o; - std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr bias, softmax_offset, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + // Q, K, V, attn_scale + std::vector q_strides(4), k_strides(4), v_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -1794,21 +1805,57 @@ void fused_attn_fp8_fwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); - } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } + } else if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), kv_format); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1818,6 +1865,20 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } + } + // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1855,19 +1916,41 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); - amax_o->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + std::shared_ptr O, Stats, amax_s, amax_o; + if (is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } else if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } - amax_s->set_output(true) + std::vector o_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) + .set_data_type(o_tensor_type); + amax_o->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); @@ -1890,10 +1973,15 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -1904,17 +1992,17 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, seq_q, seq_kv, dropout_seed, dropout_offset] = - get_graph(sdpa_fp8_fprop_cache, descriptor); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, dropout_seed, + dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -1937,17 +2025,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, - {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + variant_pack[amax_o] = devPtrAmaxO; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -1972,6 +2062,10 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1980,20 +2074,26 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, - void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, + void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, + void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, + void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2001,20 +2101,26 @@ void fused_attn_fp8_bwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2024,8 +2130,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2039,13 +2145,16 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, + o_format, + do_format, + dqkv_layout, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, deterministic, qkv_tensor_type, o_tensor_type, @@ -2056,18 +2165,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + std::shared_ptr, // Q_t + std::shared_ptr, // K + std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2084,6 +2200,8 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dP std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -2108,54 +2226,54 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset; + std::shared_ptr seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + // Q, K, V, O, dO, stats, attn_scale + std::vector q_strides(4), k_strides(4), v_strides(4), o_strides(4), dO_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), do_format); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(do_tensor_type)); + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2163,33 +2281,130 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); - } else { - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); - } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } + } else if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + // Q_t, K_t, dO_t, dO_f16 + std::vector q_t_strides(4), k_t_strides(4), dO_t_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), do_format); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_strides) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_strides) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_strides) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + std::vector q_scale_strides(4), q_t_scale_strides(4), k_scale_strides(4), + k_t_scale_strides(4), v_scale_strides(4), dO_scale_strides(4), dO_t_scale_strides(4); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + k_t_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), do_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), do_format); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) + .set_stride(dO_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) + .set_stride(dO_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2198,6 +2413,20 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 92100) { + if (window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } + } + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2251,40 +2480,75 @@ void fused_attn_fp8_bwd_impl_v1( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_dsink_token(d_softmax_offset); + } - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true) + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if (is_delayed_scaling || is_current_scaling) { + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP) = + std::apply([](const auto&... elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, sdpa_backward_options)); + } else if (is_mxfp8) { + std::tie(dQ, dK, dV, amax_dQ, amax_dK, amax_dV) = std::apply( + [](const auto&... elems) { return std::make_tuple(elems...); }, + mha_graph->sdpa_fp8_backward(Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, + descale_q_t, descale_k, descale_k_t, descale_v, descale_dO, + descale_dO_t, sdpa_backward_options)); + } + std::vector dq_strides(4), dk_strides(4), dv_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), + dk_strides.data(), dv_strides.data(), dqkv_layout); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_strides) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_strides) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_strides) + .set_data_type(dqkv_tensor_type); + amax_dQ->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true) + amax_dK->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true) + amax_dV->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + if (is_delayed_scaling || is_current_scaling) { + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } - dO->set_data_type(do_tensor_type); - dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); - - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO std::shared_ptr, // attn_scale std::shared_ptr, // descale_q @@ -2307,10 +2571,16 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto mxfp8_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -2322,17 +2592,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, + dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2349,37 +2620,47 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, - {amax_dQ, devPtrAmaxdQ}, - {amax_dK, devPtrAmaxdK}, - {amax_dV, devPtrAmaxdV}, - {amax_dP, devPtrAmaxdP}, }; - + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + variant_pack[amax_dQ] = devPtrAmaxdQ; + variant_pack[amax_dK] = devPtrAmaxdK; + variant_pack[amax_dV] = devPtrAmaxdV; + } + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { + variant_pack[descale_o] = devPtrDescaleO; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO_t] = devPtrDescaledO_t; } /* if (is_bias) { @@ -2410,11 +2691,16 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); } -} +} // NOLINT(readability/fn_size) #endif @@ -2423,57 +2709,87 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, + const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - + void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; + void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; + void *devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; + void *devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr; + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrO = output_O->data.dptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + devPtrAmaxO = output_O->amax.dptr; + } else if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrV = input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + } + void* devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; + devPtrZInv = nullptr; + if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrZInv = output_ZInv->data.dptr; + } + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = @@ -2488,17 +2804,20 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2521,24 +2840,34 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } } // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, + Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; void* devPtrV = input_V->data.dptr; void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_K->scale_inv.dptr; + void* devPtrDescaleV = input_V->scale_inv.dptr; + void *devPtrQ_t = nullptr, *devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, + *devPtrDescaleK_t = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrQ_t = input_Q->columnwise_data.dptr; + devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + devPtrK_t = input_K->columnwise_data.dptr; + devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; + } void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2548,25 +2877,46 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + void *devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrdO_t = input_dO->columnwise_data.dptr; + devPtrdO_f16 = input_dO_f16->data.dptr; + devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; + } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; + void* devPtrZInv = (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; + + void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, + *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrScaleS = input_S->scale.dptr; + devPtrDescaleS = input_S->scale_inv.dptr; + devPtrAmaxdP = input_output_dP->amax.dptr; + devPtrScaledP = input_output_dP->scale.dptr; + devPtrDescaledP = input_output_dP->scale_inv.dptr; + } - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + void* devPtrSoftmaxOffset = nullptr; + void* devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dQ->amax.dptr; - void* devPtrAmaxdV = output_dQ->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dQ->scale.dptr; - void* devPtrScaledV = output_dQ->scale.dptr; + void *devPtrAmaxdQ = nullptr, *devPtrAmaxdK = nullptr, *devPtrAmaxdV = nullptr, + *devPtrScaledQ = nullptr, *devPtrScaledK = nullptr, *devPtrScaledV = nullptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrAmaxdQ = output_dQ->amax.dptr; + devPtrAmaxdK = output_dK->amax.dptr; + devPtrAmaxdV = output_dV->amax.dptr; + devPtrScaledQ = output_dQ->scale.dptr; + devPtrScaledK = output_dK->scale.dptr; + devPtrScaledV = output_dV->scale.dptr; + } void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); @@ -2582,21 +2932,28 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, - devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, + devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + // remove this when cuDNN FE supports FP8 + THD + NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, + "ZInv tensor required for FP8 fused attention backward with T3HD layout."); fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 225e700eff..2f6c1105bd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,26 +15,31 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, + const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a897b09330..f37eeb0c68 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,6 +293,27 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 1ec1616c4a..b600261f40 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -27,11 +27,198 @@ enum NVTE_QKV_Matrix { NVTE_K_Matrix = 1, // keys NVTE_K_Matrix_Transpose = 2, // keys transposed NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_V_Matrix_Transpose = 4, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output }; +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +// Pad s and d for MXFP8 quantization +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = ((s_q + 127) / 128) * 128; + p.s_kv_padded = ((s_kv + 127) / 128) * 128; + p.s_q_scale = (s_q + block_size - 1) / block_size; + p.s_kv_scale = (s_kv + block_size - 1) / block_size; + p.s_q_scale_padded = ((p.s_q_scale + 3) / 4) * 4; + p.s_kv_scale_padded = ((p.s_kv_scale + 3) / 4) * 4; + p.d_qk_padded = ((d_qk + 127) / 128) * 128; + p.d_v_padded = ((d_v + 127) / 128) * 128; + p.d_qk_scale = (d_qk + block_size - 1) / block_size; + p.d_v_scale = (d_v + block_size - 1) / block_size; + p.d_qk_scale_padded = ((p.d_qk_scale + 3) / 4) * 4; + p.d_v_scale_padded = ((p.d_v_scale + 3) / 4) * 4; + return p; +} + +// Get matrix strides for a 4D tensor [batch_size, num_heads, sequence_len, head_dim] given a QKV format. +// strides must point to at least 4 int64_t elements. +inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[b_dim] = s * h * d; + strides[h_dim] = d; + strides[s_dim] = h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[b_dim] = h * d; + strides[h_dim] = d; + strides[s_dim] = b * h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[b_dim] = h * s * d; + strides[h_dim] = s * d; + strides[s_dim] = d; + strides[d_dim] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on layout and matrix type +inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, + int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *q_strides, int64_t *k_strides, + int64_t *v_strides, NVTE_QKV_Layout layout) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + const NVTE_QKV_Format q_format = nvte_get_q_format(layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_strides, q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_strides, kv_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_v, v_strides, kv_format); + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } +} + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); @@ -106,7 +293,10 @@ struct FADescriptor_v1 { float attnScale; bool isTraining; float dropoutProbability; - NVTE_QKV_Layout layout; + NVTE_QKV_Layout qkv_layout; + NVTE_QKV_Format o_format; + NVTE_QKV_Format do_format; + NVTE_QKV_Layout dqkv_layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; NVTE_Softmax_Type softmax_type; @@ -123,18 +313,19 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, - bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type, return_max_logit) < + bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, + do_format, dqkv_layout, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, bias_type, + qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + return_max_logit) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, - rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.return_max_logit); + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, + rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.return_max_logit); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8d9adeb620..65cdaca7d0 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,8 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ + NVTE_QKV_Layout_NOT_SET, /*!< Not set */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +72,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +94,10 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, + /*! Not set */ + NVTE_QKV_Format_NOT_SET, }; /*! \enum NVTE_Bias_Type @@ -188,6 +196,22 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Convert one NVTE_QKV_Format to another. + * + * \param[in] src_format The source format. + * \param[in] src_shape The source shape. + * \param[in] dst_format The destination format. + * \param[out] dst_shape The destination shape. + * \param[out] b The batch size. + * \param[out] h The number of heads. + * \param[out] s The sequence length. + * \param[out] d The head dimension. + * \param[out] t The number of tokens. + */ +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] is_training Whether the model is in training mode. @@ -274,6 +298,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -283,19 +308,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -347,6 +370,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] do_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -366,11 +392,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * @@ -584,6 +611,36 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream); +/*! \brief Permute Q, K, V to grouped tensors. + * + * \param[in] q Query tensor + * \param[in] k Key tensor + * \param[in] v Value tensor + * \param[out] q_out Output query tensor + * \param[out] k_out Output key tensor + * \param[out] v_out Output value tensor + * \param[in] original_layout Original QKV layout. + * \param[in] stream CUDA stream. + */ +void nvte_permute_to_grouped_tensor_fwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor q_out, + NVTETensor k_out, NVTETensor v_out, + NVTE_QKV_Layout original_layout, cudaStream_t stream); + +/*! \brief Permute Q, K, V back to original layout. + * + * \param[in] grad_q Gradient of query tensor + * \param[in] grad_k Gradient of key tensor + * \param[in] grad_v Gradient of value tensor + * \param[out] q Original query tensor + * \param[out] k Original key tensor + * \param[out] v Original value tensor + * \param[in] original_layout Original QKV layout. + * \param[in] stream CUDA stream. + */ +void nvte_permute_to_grouped_tensor_bwd(NVTETensor grad_q, NVTETensor grad_k, NVTETensor grad_v, + NVTETensor q, NVTETensor k, NVTETensor v, + NVTE_QKV_Layout original_layout, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..96e6803ec5 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,8 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +75,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 442366035a..08e3e9aa46 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,6 +29,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, @@ -36,7 +37,6 @@ ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import ( - TE_DType, QKVLayouts, dist_group_type, ) @@ -152,6 +152,18 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +_run_shadow_f16_fwd = os.getenv("NVTE_RUN_SHADOW_F16_FWD", "0") == "1" +_replace_out_return_with_shadow_f16 = ( + os.getenv("NVTE_REPLACE_OUT_RETURN_WITH_SHADOW_F16", "0") == "1" +) +_replace_out_save_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_SAVE_WITH_SHADOW_F16", "0") == "1" +_replace_aux_with_shadow_f16 = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW_F16", "0") == "1" +_run_shadow_f16_bwd = os.getenv("NVTE_RUN_SHADOW_F16_BWD", "0") == "1" +_replace_dq_with_shadow_f16 = os.getenv("NVTE_REPLACE_DQ_WITH_SHADOW_F16", "0") == "1" +_replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") == "1" +_replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" +_qdq_dO_in_mxfp8_bprop = os.getenv("NVTE_QDQ_DO_IN_MXFP8_BPROP", "0") == "1" +_qdq_dO_in_f16_bprop = os.getenv("NVTE_QDQ_DO_IN_F16_BPROP", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): @@ -173,15 +185,22 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( + # always in sbhd_sbhd_sbhd shape at this point + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -193,16 +212,23 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + # always in sbhd_sbhd_sbhd shape at this point + dq_fp8, dk_fp8, dv_fp8, new_qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( - ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + new_qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + # always in bhsd_bhsd_bhsd shape at this point; permute it back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -405,10 +431,9 @@ def forward( ) ) - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - # [b, np, sq, sk] + # [b, h, sq, sk] output_size = ( query_layer.size(1), query_layer.size(2), @@ -427,12 +452,7 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting result tensor: [b * np, sq, sk] + # preallocting result tensor: [b * h, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], @@ -446,14 +466,15 @@ def forward( scale /= self.layer_number if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -461,25 +482,50 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for quantizer in [ + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: + if isinstance(quantizer, MXFP8Quantizer): + quantizer.optimize_for_gemm = False + quantizer.internal = False - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) - # Raw attention scores. [b * np, sq, sk] + # [sq, b, h, d] -> [sq, b * h, d] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, h, d] -> [sk, b * h, d] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + + # Raw attention scores. [b * h, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ).view(*output_size) @@ -487,8 +533,8 @@ def forward( elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" matmul_result = torch.bmm( - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] ) matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale @@ -513,8 +559,8 @@ def forward( ) matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ) @@ -531,13 +577,13 @@ def forward( # max attention score max_logit = None if self.return_max_logit: - # matmul_result [b, np, sq, dk], max_logit [np] + # matmul_result [b, h, sq, dk], max_logit [h] max_logit = matmul_result if attn_mask_type != "no_mask": max_logit = self.mask_func(matmul_result, attention_mask) max_logit = torch.amax(max_logit, dim=(0, 2, 3)) - # add attention sink to the last column: [b, np, sq, sk+1] + # add attention sink to the last column: [b, h, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( [ @@ -562,7 +608,7 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) - # remove attention sink: [b, np, sq, sk] + # remove attention sink: [b, h, sq, sk] if self.softmax_type != "vanilla": attention_probs = attention_probs[..., :-1] @@ -572,7 +618,7 @@ def forward( attention_probs = self.attention_dropout(attention_probs) # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] + # [sk, b, h, d] --> [b, h, sq, d] output_size = ( value_layer.size(1), value_layer.size(2), @@ -580,10 +626,10 @@ def forward( value_layer.size(3), ) - # change view [sk, b * np, hn] + # change view [sk, b * h, d] value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] + # change view [b * h, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) if fp8: @@ -592,37 +638,37 @@ def forward( attention_probs, None, None, S_quantizer, "S_quantizer", None ) - # matmul: [b * np, sq, hn] + # matmul: [b * h, sq, d] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] + # change view [b, h, sq, d] context_layer = context_layer.view(*output_size) if q_format == "sbhd": - # [b, np, sq, hn] --> [sq, b, np, hn] + # [b, h, sq, d] --> [sq, b, h, d] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + # [sq, b, h, d] --> [sq, b, hd] + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + # [b, sq, h, d] --> [b, sq, hd] + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [tq, np, hn] + # [b, sq, h, d] --> [tq, h, d] context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, ) - # [tq, np, hn] --> [tq, hp] + # [tq, h, d] --> [tq, hd] context_layer = context_layer.view(context_layer.shape[0], -1) if fp8: @@ -1198,21 +1244,26 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - # input types are inferred from the real data while output types are controlled by fp8_output - # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + # qkv_layout may change due to MXFP8 quantization + # o_format should stay the same as original q_format + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + + # input types are inferred from real data while output types are controlled by fp8_output + # fp8_output should be set upstream assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1221,16 +1272,21 @@ def forward( out_nominal_dtype = q.dtype max_logit = None + orig_q, orig_k, orig_v = q, k, v + orig_qkv_layout = qkv_layout if fp8: fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer, used_in_backward=is_training + ) # print quantizers print_quantizers( @@ -1248,6 +1304,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1270,6 +1327,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1280,20 +1338,84 @@ def forward( cuda_graph=is_graph_capturing(), ) - # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + if _run_shadow_f16_fwd: + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + assert all( + x.dtype in [torch.float16, torch.bfloat16] for x in [q, k, v] + ), "q, k, v must be torch.float16 or torch.bfloat16" + out_f16_, aux_ctx_tensors_f16, *_ = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + orig_q, + orig_k, + orig_v, + out_nominal_dtype, + FusedAttnBackend["F16_arbitrary_seqlen"], + attn_bias, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + None, # s_quantizer + None, # o_quantizer + attn_scale, + dropout_p, + fast_zero_fill, + orig_qkv_layout, + o_format, + attn_bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + rng_gen, + softmax_offset, + return_max_logit, + is_graph_capturing(), + ) + # if torch.cuda.current_device() == 0: + # print( + # f"L{layer_number}: real/shadow out min:" + # f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" + # f" {out_.max():.4f}/{out_f16_.max():.4f}" + # ) + # print( + # f"L{layer_number}: real/shadow stats min:" + # f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" + # f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" + # ) + + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ - out = out_ - - if isinstance(out_, Float8Tensor): - if not is_output_fp8 or not is_bwd_fp8: - out = out_.dequantize().view(out_.shape) - else: - if is_output_fp8 or ( + out_f16 = out_ + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if isinstance(out_, QuantizedTensorStorage): + if not is_output_fp8 or bwd_requires_o_f16: + out_f16 = out_.dequantize().view(out_.shape) + else: + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1309,21 +1431,52 @@ def forward( ) # return appropriate tensors - out_ret = out_fp8 if is_output_fp8 else out + out_ret = out_fp8 if is_output_fp8 else out_f16 + if _run_shadow_f16_fwd and _replace_out_return_with_shadow_f16: + out_ret = out_f16_ + if _run_shadow_f16_fwd and _replace_aux_with_shadow_f16: + aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] - # save appropriate tensors + # save q, k, v, o tensors fp8_tensors = (None, None, None, None) - qkvo_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) - qkvo_tensors = (None, None, None, out) - else: + f16_tensors = (None, None, None, out_f16) + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if _run_shadow_f16_bwd: + f16_tensors = (q, k, v, out_f16) else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) - qkvo_tensors = (q, k, v, out) + if _run_shadow_f16_fwd and not _replace_aux_with_shadow_f16: + tmp_quantizer = QKV_quantizer.copy() + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, _, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True + ) + q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) + k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) + if isinstance(tmp_quantizer, MXFP8Quantizer): + qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) + if qkv_format == "bshd": + q = q_.permute(0, 2, 1, 3).contiguous() + k = k_.permute(0, 2, 1, 3).contiguous() + elif qkv_format == "sbhd": + q = q_.permute(2, 0, 1, 3).contiguous() + k = k_.permute(2, 0, 1, 3).contiguous() + else: + q, k = q_, k_ + if _run_shadow_f16_fwd and _replace_out_save_with_shadow_f16: + out_f16 = out_f16_ + f16_tensors = (q, k, v, out_f16) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( @@ -1348,6 +1501,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1358,10 +1512,10 @@ def forward( return_max_logit, is_graph_capturing(), ) - out = out_ + out_f16 = out_ out_ret = out_ fp8_tensors = (None, None, None, None) - qkvo_tensors = (q, k, v, out) + f16_tensors = (q, k, v, out_f16) nvtx_range_pop(f"{nvtx_label}") @@ -1375,7 +1529,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out] + tensor_list = [q, k, v, out_f16] mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) @@ -1385,7 +1539,7 @@ def forward( tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, - *qkvo_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -1433,9 +1587,16 @@ def forward( ctx.qkv_layout = reload_layout[:-1] else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout + ctx.o_format = o_format + # dqkv should have the same layout as the original qkv + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type @@ -1454,15 +1615,50 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring + d_out_shadow_f16 = d_out - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - d_out = ctx.dO_quantizer(d_out) - if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + if _qdq_dO_in_f16_bprop or _qdq_dO_in_mxfp8_bprop: + d_out_qdq_f16 = d_out + d_out_qdq_f16, _ = dpa_utils.permute_to_grouped_tensor(ctx.o_format, d_out_qdq_f16) + tmp_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + ) + tmp_quantizer.optimize_for_gemm = False + d_out_qdq_fp8 = tmp_quantizer(d_out_qdq_f16) + d_out_qdq_f16 = d_out_qdq_fp8.dequantize(dtype=ctx.nominal_dtype) + if ctx.o_format == "bshd": + d_out_qdq_f16 = d_out_qdq_f16.permute(0, 2, 1, 3).contiguous() + elif ctx.o_format == "sbhd": + d_out_qdq_f16 = d_out_qdq_f16.permute(2, 0, 1, 3).contiguous() + swapped_do_with_qdq_do = False + if ctx.fp8 and _qdq_dO_in_mxfp8_bprop: + d_out = d_out_qdq_f16 + swapped_do_with_qdq_do = True + if ctx.fp8 and _qdq_dO_in_mxfp8_bprop and _run_shadow_f16_bwd: + d_out_shadow_f16 = d_out_qdq_f16 + swapped_do_with_qdq_do = True + if not ctx.fp8 and _qdq_dO_in_f16_bprop: + d_out = d_out_qdq_f16 + swapped_do_with_qdq_do = True + # if swapped_do_with_qdq_do: + # print(f"swapped, {ctx.fp8=},{_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") + # else: + # print(f"not swapped, {ctx.fp8=}, {_qdq_dO_in_mxfp8_bprop=}, {_qdq_dO_in_f16_bprop=}, {_run_shadow_f16_bwd=}, {_replace_dq_with_shadow_f16=}, {_replace_dk_with_shadow_f16=}, {_replace_dv_with_shadow_f16=}") + + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() + d_out_fp8 = None + do_format = ctx.o_format + if ctx.fp8: + if ctx.fp8_recipe.mxfp8(): + d_out, do_format = dpa_utils.permute_to_grouped_tensor(do_format, d_out) + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + else: + d_out_fp8 = ctx.dO_quantizer(d_out) ( q_fp8, k_fp8, @@ -1480,6 +1676,10 @@ def backward(ctx, d_out, *_args): ) = restore_from_func_ctx(ctx) aux_ctx_tensors = other_tensors + aux_ctx_tensors_shadow_f16 = aux_ctx_tensors + out_shadow_f16 = out + original_qkv_layout = ctx.dqkv_layout + original_qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() @@ -1523,14 +1723,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1543,27 +1735,31 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype - - # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # DelayedScaling/Float8CurrentScaling/MXFP8BlockScaling: + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling: + # out_: NVTE_DPA_FP8CS_O_in_F16=1: + # torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # NVTE_DPA_FP8CS_O_in_F16=0: + # Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8(): + out_ = out + aux_ctx_tensors.append(d_out) dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1575,7 +1771,6 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1587,6 +1782,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + do_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1595,23 +1793,100 @@ def backward(ctx, d_out, *_args): ctx.deterministic, is_graph_capturing(), ) + if _run_shadow_f16_bwd: + original_qkv_layout = ctx.dqkv_layout + tmp_quantizer = ctx.QKV_quantizer.copy() + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer, used_in_backward=True + ) + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_) + ] + if isinstance(tmp_quantizer, MXFP8Quantizer): + if original_qkv_format == "bshd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(0, 2, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] + elif original_qkv_format == "sbhd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(2, 0, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] + dq_shadow_f16, dk_shadow_f16, dv_shadow_f16, *rest = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_shadow_f16, + k_shadow_f16, + v_shadow_f16, + out_shadow_f16, + d_out_shadow_f16, + dqkv_nominal_dtype, + aux_ctx_tensors_shadow_f16, + FusedAttnBackend["F16_arbitrary_seqlen"], + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + original_qkv_layout, + original_qkv_format, + original_qkv_format, + original_qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ctx.softmax_type, + ctx.window_size, + ctx.bottom_right_diagonal, + ctx.deterministic, + is_graph_capturing(), + ) + if _replace_dq_with_shadow_f16: + dq_ = dq_shadow_f16 + if _replace_dk_with_shadow_f16: + dk_ = dk_shadow_f16 + if _replace_dv_with_shadow_f16: + dv_ = dv_shadow_f16 + # if torch.cuda.current_device() == 0: + # print( + # f"L{ctx.layer_number}: real/shadow dq min:" + # f" {dq_.min():.4f}/{dq_shadow_f16.min():.4f}, max:" + # f" {dq_.max():.4f}/{dq_shadow_f16.max():.4f}" + # ) + # print( + # f"L{ctx.layer_number}: real/shadow dk min:" + # f" {dk_.min():.4f}/{dk_shadow_f16.min():.4f}, max:" + # f" {dk_.max():.4f}/{dk_shadow_f16.max():.4f}" + # ) + # print( + # f"L{ctx.layer_number}: real/shadow dv min:" + # f" {dv_.min():.4f}/{dv_shadow_f16.min():.4f}, max:" + # f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" + # ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( - ctx.qkv_layout, + ctx.dqkv_layout, dq_, dk_, dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( - ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) # print quantizers @@ -1628,7 +1903,6 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1641,7 +1915,6 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1653,6 +1926,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + do_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1817,9 +2093,9 @@ def forward( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" assert all( - x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, QuantizedTensorStorage) for x in [query_layer, key_layer, value_layer] - ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." + ), "FusedAttention only supports FP16 and BF16 data types, or QuantizedTensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." @@ -1925,7 +2201,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..94422b2750 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -22,13 +22,11 @@ ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.graph import is_graph_capturing -from transformer_engine.pytorch.constants import ( - dist_group_type, - TE_DType, -) +from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.distributed import ( get_distributed_world_size, get_distributed_rank, @@ -59,6 +57,18 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim + + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -237,10 +247,10 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -251,12 +261,12 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -410,15 +420,32 @@ def flash_attn_a2a_communicate( cp_stream: torch.cuda.Stream, before_attn: bool, qkv_format: str = "bshd", - cu_seqlens_padded: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - - assert ( - qkv_format != "thd" or cu_seqlens_padded is not None - ), "cu_seqlens_padded is required for THD format!" + assert a2a_input_names in [ + ["q", "k", "v"], + ["out"], + ["dout"], + ["dq", "dk", "dv"], + ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" + if a2a_input_names in [["out"], ["dout"]]: + assert qkv_format != "thd" or cu_seqlens_q_padded is not None, ( + f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with" + " THD format!" + ) + if a2a_input_names in [["q", "k", "v"], ["dq", "dk", "dv"]]: + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), ( + "flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for" + f" {a2a_input_names} with THD format!" + ) a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + _, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,18 +457,24 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] + # [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn] + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 2] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) + # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( @@ -450,14 +483,21 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # [s, b, h, d] -> [s, b, cp, h//cp, d] + # [b, h, s, d] -> [b, cp, h//cp, s, d] + # [t, h, d] -> [t, cp, h//cp, d] + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] + # [t, cp, h//cp, d] -> [cp, t, h//cp, d] + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,30 +507,57 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + if qkv_format in ["bshd", "sbhd", "bhsd"]: + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # reorder the sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) - # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn] + # [cp*t, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - # or [t, cp, np//cp, hn] -> [t, np, hn] + # [cp, 2, b, s//2, h//cp, d] -> [2, b, s//2, cp, h//cp, d] + # [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] + # [cp, t, h//cp, d] -> [t, cp, h//cp, d] + tmp_list = list(qkv_format) + if "t" not in qkv_format: + tmp_list.insert(0, "2") + tmp_list.insert(0, "c") + tmp_format = "".join(tmp_list) + head_dim_ = tmp_format.index("h") - 1 + tmp_list.insert(head_dim_, tmp_list.pop(0)) + x = x.movedim(0, head_dim_) + # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] + # [t, cp, h//cp, d] -> [t, cp, h//cp, d] + if "t" not in qkv_format: + tmp_format = "".join(tmp_list) + seq_dim_ = tmp_format.index("s") - 1 + tmp_list.insert(seq_dim_, tmp_list.pop(0)) + x = x.movedim(0, seq_dim_) + else: + seq_dim_ = 0 + x = x.contiguous() + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] + # [t, cp, h//cp, d] -> [t, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -775,13 +842,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -867,11 +937,17 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -888,7 +964,8 @@ def cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -900,7 +977,10 @@ def cp_p2p_fwd_fused_attn( ) if fp8: - softmax_lse_per_step, _, rng_states = aux_ctx_tensors + if qkv_layout != "t3hd": + softmax_lse_per_step, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors else: softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None @@ -1065,15 +1145,19 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + do_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1083,11 +1167,14 @@ def cp_p2p_bwd_fused_attn( ): """Per-tile backward call of CP P2P with FusedAttention backend""" if fp8: - aux_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] @@ -1106,11 +1193,14 @@ def cp_p2p_bwd_fused_attn( elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] if fp8: - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] + if qkv_layout == "t3hd": + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] else: aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] @@ -1123,16 +1213,32 @@ def cp_p2p_bwd_fused_attn( fp8_meta_kwargs = {} if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer_per_step, + used_in_forward=False, + used_in_backward=True, ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + dout_part, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout_part) + aux_tensors.append(dout_part) + dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1148,7 +1254,6 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1156,6 +1261,9 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + do_format=do_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, @@ -1313,16 +1421,15 @@ def forward( ) # set up attention args - enable_mla = k.shape[-1] != v.shape[-1] - causal = "causal" in attn_mask_type - if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - + causal = "causal" in attn_mask_type + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = q.shape[:-1] + v.shape[-1:] batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None @@ -1337,13 +1444,10 @@ def forward( else: cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - - fused_attn_backend = None amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] @@ -1352,9 +1456,9 @@ def forward( assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." fwd_nominal_dtype = q.dtype - is_input_fp8 = isinstance(q, Float8Tensor) + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -1362,7 +1466,6 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - ( QKV_quantizer, O_quantizer, @@ -1370,43 +1473,58 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) - q_f16 = None + # q, k, v a2a: gather s and split h + # FP8DS/CS: Float8Tensor -> torch.uint8 -> Float8Tensor + # MXFP8/F16: fwd_nominal_dtype q_fp8, k_fp8, v_fp8 = (None, None, None) - # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if fp8 and is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + [q, k, v], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + True, + qkv_format=qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) ] q, k, v = q_fp8, k_fp8, v_fp8 + post_a2a_o_shape = q.shape[:-1] + v.shape[-1:] # convert qkv to the right type + q_f16 = None + fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1427,10 +1545,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1482,7 +1601,6 @@ def forward( attn_bias_ = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -1557,17 +1675,22 @@ def forward( # synchronize fwd results correction across steps fwd_results_correction_done = torch.cuda.Event() + # q, k, v, o: + # causal: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # non-causal: [b, s, h, d] or [s, b, h, d] p2p_comm_buffers = [None for _ in range(cp_size)] k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] # P2P communication and compute: each rank has cp_size steps - # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype - # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 + # MXFP8/F16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # FP8DS/CS attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1621,13 +1744,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1775,8 +1901,8 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, h, sq, 1] -> [b, h, sq] or - # [t, h, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] + # [t, h, 1] -> [t, h] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( @@ -1788,21 +1914,16 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape - ) + if fp8: + out = torch.zeros_like(out_per_step[0]).view(o_shape) else: - # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + out = torch.zeros(o_shape, dtype=q.dtype, device=q.device) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1842,7 +1963,7 @@ def forward( # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1850,10 +1971,7 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - if enable_mla: - out = out.view(v_shape) - else: - out = out.view(q.shape) + out = out.view(o_shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1862,7 +1980,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1873,7 +1991,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1881,7 +1999,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1891,35 +2009,31 @@ def forward( True, softmax_lse_in_packed_format, ) - - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - ctx.batch_size = out.shape[1] + out = out.view(post_a2a_o_shape) + out_part = out.to(fwd_nominal_dtype) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + out, + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + False, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + out = out.view(orig_o_shape) if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False ) - elif not use_fused_attention: - out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1942,7 +2056,11 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() + ) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -1953,7 +2071,7 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1961,17 +2079,28 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv_f16 = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -2009,7 +2138,6 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape @@ -2022,12 +2150,19 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 - ctx.enable_mla = enable_mla + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.orig_o_shape = orig_o_shape + ctx.post_a2a_o_shape = post_a2a_o_shape ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape - + ctx.o_shape = o_shape + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -2036,14 +2171,14 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") - if return_max_logit: return out_ret, max_logit return out_ret @@ -2057,8 +2192,13 @@ def backward(ctx, dout, *_args): nvtx_range_push(f"{nvtx_label}") # dout is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + # but in the case it's not, convert it to FP8 (except for MXFP8) before any operation + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2098,7 +2238,6 @@ def backward(ctx, dout, *_args): # set up attention args causal = "causal" in ctx.attn_mask_type seq_dim = None - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2137,13 +2276,13 @@ def backward(ctx, dout, *_args): if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() # [b, h, sq//2] -> [b, h, sq//2, 1] or - # [t//2, np] -> [t//2, h, 1] + # [t//2, h] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() # [b, h, sq] -> [b, h, sq, 1] or - # [t, np] -> [t, h, 1] + # [t, h] -> [t, h, 1] softmax_lse.unsqueeze_(-1) # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 @@ -2158,28 +2297,29 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None dkv_buffer = None if ctx.fp8: - assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2193,9 +2333,6 @@ def backward(ctx, dout, *_args): ctx.dP_quantizer, ) - # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype - # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): dq_buffer = torch.empty( @@ -2203,7 +2340,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2217,7 +2354,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2230,10 +2367,13 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2244,34 +2384,28 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' + dout = dout.view(*ctx.orig_o_shape) if cp_size_a2a > 1: - if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) - - if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) - else: - # MHA or GQA - out = out.view(*q.shape) - dout = dout.view(*q.shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -2368,10 +2502,11 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), - dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, rng_states, @@ -2388,16 +2523,20 @@ def backward(ctx, dout, *_args): fused_attn_backend, ctx.softmax_scale, ctx.dropout_p, - qkv_layout, + ctx.qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + ctx.qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2471,7 +2610,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2555,7 +2694,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2645,9 +2784,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2661,7 +2801,7 @@ def backward(ctx, dout, *_args): for x in [dq, dk, dv] ] dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.qkv_layout, dq, dk, dv, @@ -2670,7 +2810,7 @@ def backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2686,7 +2826,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -2704,7 +2844,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2714,16 +2855,22 @@ def backward(ctx, dout, *_args): ctx.cp_group_a2a, ctx.cp_stream, False, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) ] - if ctx.qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [ + x.view(y) + for x, y in zip( + [dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape] + ) + ] if attn_dbias is not None: # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] @@ -2821,27 +2968,42 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - qkv_dtype = q.dtype - - causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" - if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." + assert ( + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." assert ( - use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or fa_utils.v2_3_plus + ), ( + "cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention" + f" >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -2874,14 +3036,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": @@ -2890,30 +3044,90 @@ def forward( cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None + if use_fused_attention and attn_mask_type == "causal": + attn_mask_type = attn_mask_type + "_bottom_right" + causal = "causal" in attn_mask_type - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) + fused_attn_backend = None + fp8_meta_kwargs = {} + if fp8: + assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if not is_input_fp8 and not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + orig_q_shape, _, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] + + # q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # reshape: split s + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + # s dim first for all-gather + # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # gather along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # pick out specific chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # reshape/flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # k_ag: [cp*s, b, h, d] + # v_ag: [cp*s, b, h, d] + # out_f16: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + q_shape, k_shape, v_shape = q.shape, k.shape, v.shape + o_shape = q.shape[:-1] + v.shape[-1:] + out_f16 = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) + # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] - + # prepare per-step tensors local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] kv_seq_range_per_step = [None, None] window_size_per_step = [None, None] @@ -2921,16 +3135,15 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) max_logit_per_step = [None, None] max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( local_seq_chunk_ids[i], @@ -2950,13 +3163,27 @@ def forward( cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( k.shape[1], max_seqlen_kv_, k.device ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] if use_fused_attention: + new_qkv_layout = qkv_layout + if fp8: + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -2964,14 +3191,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -2980,9 +3208,19 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if fp8: + if qkv_layout != "t3hd": + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -2999,9 +3237,9 @@ def forward( fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( - q_, - k_, - v_, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, @@ -3017,61 +3255,152 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + # out_per_step[i]: fwd_nominal_dtype, [b, s//2, h, d] or [s//2, b, h, d] + # out_f16: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # max_logit_per_step[i]: torch.float32, [h] + # max_logit: torch.float32, [h] if return_max_logit and i == 0: max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1]) - elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1]) + if o_format == "bshd": + out_f16[:, i - 1].copy_(out_per_step[i - 1]) + elif o_format == "sbhd": + out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + + # all reduce max_logit across ranks if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: - out = out.view(-1, *out.shape[-2:]) + # out_f16: fwd_nominal_dtype + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + out_f16 = out_f16.view(orig_o_shape) - ctx.save_for_backward( - q, - k, - v, + # prepare for forward output and backward saves of out + out_fp8 = None + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if fp8 and (is_output_fp8 or bwd_requires_o_fp8): + out_fp8 = O_quantizer(out_f16) + out_ret = out_fp8 if is_output_fp8 else out_f16 + + # save tensors for backward + ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + # True: q split along s; k/v with s first, i.e. [s, b, h, d] + # False: original [b, s, h, d] or [s, b, h, d] + ctx.qkv_reshaped = True + # no load-balance related token shuffling; original token order in q/k/v/out_f16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out_f16/out_fp8: [b, s, h, d] or [s, b, h, d] + if ctx.fp8: + # q_fp8_save: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k_fp8_save: [s, b, h, d] + # v_fp8_save: [s, b, h, d] + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v in FP8, o in f16 + # MXFP8: q/k/v/o all in f16 + if fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) + f16_tensors = (None, None, None, out_f16) + elif fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_f16) + elif fp8: + # convert q/k/v to F16 if necessary, and save q/k/v/o all in F16 and original format + if is_input_fp8: + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q_f16, k_f16, v_f16, out_f16) + ctx.qkv_reshaped = False + else: + # save all in F16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out_f16: [b, s, h, d] or [s, b, h, d] + f16_tensors = (q, k, v, out_f16) + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, - *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_format = qkv_format + ctx.dqkv_layout = qkv_layout + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape + ctx.o_shape = o_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step + ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3080,22 +3409,94 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + cu_seqlens_kv_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], + ) = restore_from_func_ctx(ctx) kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(ctx.qkv_format) + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_o, _ = get_bsh_dims(ctx.o_format) + causal = "causal" in ctx.attn_mask_type - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + # set up dout: + # FP8DS/CS: torch.uint8, [b, s, h, d] or [s, b, h, d] + # MXFP8/F16: torch.float16 or torch.bfloat16, [b, s, h, d] or [s, b, h, d] + dout_fp8 = None + if ctx.fp8: + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + dout = dout.view(ctx.o_shape) + + # set up q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if not ctx.qkv_reshaped: + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + + # set up out: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): torch.uint8 + # FP8CS+_dpa_fp8_cs_o_in_f16: torch.float16 or torch.bfloat16 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + if ctx.fp8 and ( + ctx.fp8_recipe.delayed() + or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ): + out = out_fp8._data + out = out.view(ctx.o_shape) + + # set up dq, dk, dv: + # dq: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dk: fwd_nominal_dtype, [cp*s, b, h, d] + # dv: fwd_nominal_dtype, [cp*s, b, h, d] + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3105,23 +3506,22 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # gather k and v along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # select appropriate chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) - local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - + # set up flash_attn_bwd flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3153,57 +3553,126 @@ def backward(ctx, dout, *_args): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], kv_seq_range_per_step[i][1], ) max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + out_part = out.select(seq_dim_o, i).contiguous() + dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + if ctx.fp8 and ctx.qkv_layout == "t3hd": + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] + else: + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + new_qkv_layout = ctx.qkv_layout + do_format = ctx.o_format + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o/do all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 + # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 + if not ctx.fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=ctx.fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + if ctx.fp8_recipe.delayed() or ( + ctx.fp8_recipe.float8_current_scaling() + and not _dpa_fp8_cs_o_in_f16 + ): + out_part = Float8Tensor.make_like( + out_fp8, data=out_part, dtype=ctx.fwd_nominal_dtype + ) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + ctx.qkv_layout, + q_part, + k_part, + v_part, + ctx.QKV_quantizer, + used_in_forward=False, + used_in_backward=True, + ) + dout_part, do_format = dpa_utils.permute_to_grouped_tensor( + do_format, dout_part + ) + aux_ctx_tensors.append(dout_part) + dout_part = ctx.dO_quantizer(dout_part) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=ctx.o_format, + do_format=do_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if ctx.fp8 and all( + isinstance(x, QuantizedTensorStorage) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ): + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + x.dequantize(dtype=ctx.fwd_nominal_dtype) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] + torch.empty_like(x) for x in [q_part, k_part, v_part] ] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, @@ -3220,29 +3689,34 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] if ctx.use_flash_attn_3: - fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["is_causal"] = causal else: - fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["causal"] = causal flash_attn_bwd( - dout_, - q_, - k_, - v_, - out_, + dout_part, + q_part, + k_part, + v_part, + out_part, softmax_lse_per_step[i], *fa_backward_args_thd, **fa_backward_kwargs, ) if i > 0: + # dq/dk/dv, dq_per_step/dk_per_step/dv_per_step: ctx.fwd_nominal_dtype with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] + if ctx.dqkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": + elif ctx.dqkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + # dk/dv: [cp*s, b, h, d] + # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] + # move s to first dim: [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() + x.movedim(seq_dim_dqkv, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] ] # wait until dkv update of last step is done @@ -3252,6 +3726,7 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step[i - 1][0], kv_seq_range_per_step[i - 1][1], ) + # add to dk/dv: [cp*s, b, h, d] dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): @@ -3259,23 +3734,33 @@ def backward(ctx, dout, *_args): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + # put back together the right chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) + # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + # reshape to original format: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) + dk = dk.movedim(0, seq_dim_dqkv).contiguous() + dv = dv.movedim(0, seq_dim_dqkv).contiguous() + + # quantize if necessary + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize(ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, dq, @@ -3298,6 +3783,10 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, + None, ) @@ -3342,24 +3831,43 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + original_qkv_layout = qkv_layout + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if qkv_format in ["bshd", "sbhd"]: + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." assert ( - not padding or qkv_format == "thd" - ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='a2a' and {attn_bias_type=}." assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + ), ( + "cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3." + f" Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + ) + assert q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0, ( + "cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. Found num_heads_q =" + f" {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -3399,26 +3907,10 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert ( - q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 - ), "The number of attention heads needs to be divisible by CP size!" - - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") - - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -3426,62 +3918,103 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) + fp8_meta_kwargs = {} if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer - else: - assert False, "FP8 is only supported with Fused Attention!" + assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + elif not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # q, k, v: + # FP8DS/FP8CS: torch.uint8 + # MXFP8: torch.float16 or torch.bfloat16 + # F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, before_attn=True, qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) + + # softmax_offset: split h + # [1, h, 1, 1] -> [1, h//cp, 1, 1] if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - out_fp8 = None - out_f16 = None - batch_size = q.shape[batch_dim] + # _part: inputs to attention kernel and saved for backward + # note: they have post a2a shapes q_part, k_part, v_part = q, k, v - out_part = None + out_part, out_fp8, out_f16 = None, None, None + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) if use_fused_attention: if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, + q_part, + k_part, + v_part, + QKV_quantizer, + used_in_backward=is_training, + ) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + else: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3496,6 +4029,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3508,24 +4042,18 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, Float8Tensor): - out_fp8 = out_ - out_ = out_._data - if is_bwd_fp8 and not ( - fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - ): - out_part = out_fp8 - else: - out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) - else: - out_f16 = out_ - out_part = out_ - if ( - fp8 - and is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): - out_part = O_quantizer(out_) + # construct out_part for backward + # out_fp8 and out_f16 store the FP8 or F16 tensor for backward saves + out_fp8 = out_ + out_f16 = out_ + if bwd_requires_o_fp8: + if not isinstance(out_, QuantizedTensorStorage): + out_fp8 = O_quantizer(out_) + out_part = out_fp8 + if bwd_requires_o_f16: + if isinstance(out_, QuantizedTensorStorage): + out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) + out_part = out_f16 else: fa_forward_args_thd = get_fa_args( True, @@ -3553,60 +4081,94 @@ def forward( aux_ctx_tensors = [softmax_lse, rng_state] out_part = out_ + # a2a: split s and gather h + # [b, s, h//cp, d] -> [b*s//cp, h, d] + # [s, b, h//cp, d] -> [s//cp*b, h, d] + # [t, h//cp, d] -> [t//cp, h, d] + if isinstance(out_, Float8TensorStorage): + out_ = out_._data chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if return_max_logit: - max_logit = flash_attn_a2a_communicate_softmax_offset( - *max_logit, 0, cp_size, cp_group, cp_stream, False - ) - - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + # [b*s//cp, h, d] -> [b, s//cp, h, d] + # [s//cp*b, h, d] -> [s//cp, b, h, d] + # [t//cp, h, d] -> [t//cp, h, d] + out_ = out_.view(orig_o_shape) - if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): - out_f16 = out_ - if is_output_fp8: - out_fp8 = O_quantizer(out_) + # out_ret: output tensor for forward pass + # out_fp8 and out_f16 are reused here to store the FP8 or F16 tensor for forward returns + if fp8: if fp8_recipe.delayed(): out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) - if not is_output_fp8: + if is_output_fp8: + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): + out_fp8 = O_quantizer(out_) + out_f16 = out_ + else: + if fp8_recipe.delayed(): out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ else: out_f16 = out_ - out_ret = out_fp8 if is_output_fp8 else out_f16 + # all gather max logit + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, cp_size, cp_group, cp_stream, False + ) + + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout + ctx.dqkv_format = qkv_format + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.orig_o_shape = orig_o_shape + + # save tensors for backward ctx.fp8 = fp8 and is_bwd_fp8 fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) - if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_part, k_part, v_part, None) - f16_tensors = (None, None, None, out_part) + if is_training: + if ctx.fp8: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # (FP8CS+_dpa_fp8_cs_o_in_f16) or MXFP8: q/k/v in FP8, o in F16 + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + # FP8DS/CS: convert post-a2a FP8 q/k/v to F16; out_part already in F16 + # MXFP8: save post-a2a pre-quantization F16 q/k/v; out_part already in F16 + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout + else: + q_part, k_part, v_part = combine_and_dequantize( + qkv_layout, q_part, k_part, v_part + ) + f16_tensors = (q_part, k_part, v_part, out_part) else: - fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: - q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) - f16_tensors = (q_part, k_part, v_part, out_part) - else: - f16_tensors = (q_part, k_part, v_part, out_part) - + # all tensors are in F16 + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, @@ -3618,16 +4180,13 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.out_shape = out_ret.shape - ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3649,11 +4208,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3681,60 +4242,53 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_func_ctx(ctx) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - causal = "causal" in ctx.attn_mask_type - - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") - + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None fused_attn_backend = None - dout_fp8 = dout + causal = "causal" in ctx.attn_mask_type + + dout_fp8 = None + fp8_meta_kwargs = {} if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): - dout = ctx.dO_quantizer(dout) - dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer - - else: - assert False, "FP8 is only supported with Fused Attention!" + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - - if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) - else: - dout = dout.view(*ctx.out_shape) - + dout = dout.view(*ctx.orig_o_shape) + + # dout: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=True, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) flash_attn_bwd = None @@ -3752,7 +4306,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_right"] = ctx.window_size[1] fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3779,12 +4333,21 @@ def backward(ctx, dout, *_args): dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: + do_format = ctx.o_format q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + # do_format = bhsd for both dout (F16) and dout_part (MXFP8) + dout, do_format = dpa_utils.permute_to_grouped_tensor(do_format, dout) + aux_ctx_tensors.append(dout) + dout_part = ctx.dO_quantizer(dout) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3796,14 +4359,16 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + do_format=do_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, @@ -3812,7 +4377,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if all(isinstance(x, Float8TensorStorage) for x in [dq, dk, dv]): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -3821,7 +4386,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3847,24 +4412,33 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) + # dq, dk, dv: + # FP8DS: torch.uint8 + # FP8CS/MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.dqkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) + dq, dk, dv = [ + x.view(y) + for x, y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape]) + ] - if qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - + # d_bias, d_softmax_offset d_bias = None d_softmax_offset = None if ctx.use_fused_attention: @@ -3876,9 +4450,14 @@ def backward(ctx, dout, *_args): d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False ) + # convert dq, dk, dv to appropriate types if ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3886,7 +4465,7 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, @@ -3894,7 +4473,6 @@ def backward(ctx, dout, *_args): ) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") - return ( None, dq, @@ -4069,17 +4647,6 @@ def attn_forward_func_with_cp( "all_gather", ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" - enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" - - if fp8 and fp8_meta is not None: - if fp8_meta["recipe"].fp8_dpa: - assert ( - softmax_type == "vanilla" - ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" @@ -4131,7 +4698,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..0d4d31f405 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -19,6 +19,7 @@ Recipe, DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( @@ -30,7 +31,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( @@ -98,19 +99,26 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8DS | Pass MXFP8 to autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear MXFP8; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +126,27 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8CS | Pass MXFP8 to autocast(); | +| | | Attention creates a new FP8CS recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear MXFP8, and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | @@ -139,6 +155,18 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS/FP8CS | MXFP8 | Pass FP8DS/FP8CS to autocast(); | +| | | Attention creates a new MXFP8 recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear FP8DS/FP8CS | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | MXFP8 | Pass MXFP8 to autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -600,7 +628,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + elif ( + fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8() + ) and _dpa_fp8_recipe == "DelayedScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe fake_recipe = DelayedScaling( fp8_format=fp8_recipe.fp8_format, @@ -653,6 +683,25 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) fp8_recipe_dpa = fake_recipe fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.mxfp8() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP @@ -673,11 +722,26 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + elif ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ) and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs @@ -1203,7 +1267,9 @@ def forward( cu_seqlens_kv_padded = None # get qkv's memory layout - if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + if all( + isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer] + ): ( qkv_layout, query_layer._data, @@ -1365,6 +1431,7 @@ def forward( attention_dropout=self.attention_dropout, context_parallel=context_parallel, cp_comm_type=self.cp_comm_type, + cp_size=cp_size, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..d6171d04f5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,11 +35,16 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -220,6 +225,8 @@ class AttentionParams: Whether context parallelism is used or not. cp_comm_type : str, default = "p2p" The communication type of context parallelism. + cp_size : int, default = 1 + The group size of context parallelism. deterministic : bool, default = False Whether to run `DotProductAttention` with determinism or not. is_training : bool, default = True @@ -261,6 +268,7 @@ class AttentionParams: attention_dropout: float = 0.0 context_parallel: bool = False cp_comm_type: str = "p2p" + cp_size: int = 1 deterministic: bool = False is_training: bool = True fp8: bool = False @@ -338,6 +346,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type + cp_size = attention_params.cp_size # pylint: disable=unused-variable deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -359,6 +368,7 @@ def get_attention_backend( "transformer_engine_version": te.__version__, "compute_capability": "sm" + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "cuda_version": torch.version.cuda, "flash_attn_version": ( str(FlashAttentionUtils.version) if FlashAttentionUtils.is_installed @@ -446,24 +456,33 @@ def get_attention_backend( qkv_dtype, ) use_flash_attention_2 = False - if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( torch.Tensor, Float8Tensor, - ]: + Float8TensorStorage, + ): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( - "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s." + " Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}," + " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage}. ", qkv_dtype, qkv_type, ) use_flash_attention_3 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( + torch.Tensor, + Float8Tensor, + Float8TensorStorage, + MXFP8Tensor, + MXFP8TensorStorage, + ): if use_fused_attention: logger.debug( - "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. Supported:" + " qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, qkv_type =" + " {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) @@ -471,6 +490,9 @@ def get_attention_backend( # Filter: Execution type if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -478,6 +500,12 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and not ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for %s", fp8_recipe.__class__.__name__) + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -485,15 +513,21 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] + if use_fused_attention and fp8_recipe.delayed(): + if ( + device_compute_capability >= (10, 0) + and deterministic + and cudnn_version < (9, 18, 0) + ): + logger.debug( + "Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with" + " determinism for cuDNN < 9.18.0" + ) + use_fused_attention = False if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False - # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling - # determinism for Blackwell else: if cudnn_version < (9, 14, 0): logger.debug( @@ -503,10 +537,27 @@ def get_attention_backend( else: if deterministic and cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for FP8 current scaling requiring determinism" - " with cuDNN < 9.18.0" + "Disabling FusedAttention for FP8 current scaling with determinism" + " for cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + elif fp8_recipe.fp8_mha: + logger.debug("Disabling FusedAttention for MXFP8 with fp8_mha=True") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False + if use_fused_attention and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()): + logger.debug("Disabling FusedAttention for %s", fp8_recipe.__class__.__name__) + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -729,29 +780,36 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False if fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) - use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type - ) - use_unfused_attention = False - if qkv_format == "thd": - if cudnn_version < (9, 18, 0): + if use_fused_attention and ( + device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0) + ): logger.debug( - "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" - " version < 9.18", + "Disabling FusedAttention for softmax_type = %s in FP8 on sm < 100 with cuDNN" + " version < 9.21", softmax_type, ) use_fused_attention = False - if context_parallel: - if cp_comm_type != "a2a": + if use_unfused_attention: logger.debug( - "Disabling FusedAttention for context parallelism with softmax_type = %s and" - " cp_comm_type = %s", + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type, - cp_comm_type, ) - use_fused_attention = False + use_unfused_attention = False + if qkv_format == "thd" and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False + if context_parallel and cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -829,10 +887,50 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + elif fp8 and fp8_meta["recipe"].fp8_dpa and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 and fp8_meta["recipe"].fp8_dpa and core_attention_bias_type != "no_bias": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" + " attention and bias" + ) + use_fused_attention = False + elif core_attention_bias_type != "no_bias" and cp_comm_type != "p2p": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, ) use_fused_attention = False @@ -885,9 +983,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + if ( + fp8 + and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha) + and (device_compute_capability < (10, 0) or cudnn_version < (9, 21, 0)) + ): logger.debug( "Disabling FusedAttention as it does not support sliding window attention for FP8" + " on sm < 100 with cuDNN version < 9.21" ) use_fused_attention = False elif attention_dropout != 0.0: @@ -1025,8 +1128,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and window_size is not None - and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling FusedAttention as only sub-backend %s does not support " @@ -1071,15 +1174,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic: - if softmax_type != "vanilla": - logger.debug( - "Disabling FusedAttention for determinism reasons with softmax_type = %s. " - "Sink attention (off-by-one and learnable softmax) requires " - "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", - softmax_type, - ) - use_fused_attention = False - fused_attention_backend = None + # if softmax_type != "vanilla": + # logger.debug( + # "Disabling FusedAttention for determinism reasons with softmax_type = %s. " + # "Sink attention (off-by-one and learnable softmax) requires " + # "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + # softmax_type, + # ) + # use_fused_attention = False + # fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["FP8"] and is_training @@ -2108,28 +2211,45 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + dP_quantizer.set_usage(rowwise=True, columnwise=False) + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8_recipe.mxfp8(): + QKV_quantizer.columnwise_usage = True + QKV_quantizer.optimize_for_gemm = True + S_quantizer = None + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True + dO_quantizer.optimize_for_gemm = True + dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2183,18 +2303,149 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print(f"{label} >> {names[i]:14s}: {type_str}") + + +def permute_to_grouped_tensor(src_format, tensor): + """Permute tensor from src_format = {bshd, sbhd, thd} to des_format = {bhsd, htd} for MXFP8 quantization.""" + if src_format in ["bhsd", "htd"]: + return tensor, src_format + des_format = "bhsd" if src_format != "thd" else "htd" + # make tensor contiguous bshd/sbhd/thd + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + # permute bshd/sbhd to bhsd, and thd to htd + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") + dim_others = [i for i in range(len(tensor.shape)) if i != dim_s_or_t] + new_dims = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*new_dims).contiguous() + return tensor, des_format + + +class PermuteToGroupedTensor(torch.autograd.Function): + """Permute Q, K, V from {bshd_bshd_bshd, sbhd_sbhd_sbhd} to bhsd_bhsd_bhsd.""" + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + original_layout: str = "bshd_bshd_bshd", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + ctx.original_layout = QKVLayout[original_layout] + return tex.permute_to_grouped_tensor_fwd(query, key, value, ctx.original_layout) + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + query_grad: torch.Tensor, + key_grad: torch.Tensor, + value_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + q, k, v = tex.permute_to_grouped_tensor_bwd( + query_grad, + key_grad, + value_grad, + ctx.original_layout, + ) + return q, k, v, None -def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): +def combine_and_quantize( + qkv_layout, q, k, v, qkv_quantizer, used_in_forward=True, used_in_backward=False +): """Combine q,k,v based on qkv_layout and quantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate + if isinstance(qkv_quantizer, MXFP8Quantizer): + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + # permute q, k, v to bhsd/htd format + qkv_contiguous_block = False + if qkv_layout in ["bshd_bshd_bshd", "sbhd_sbhd_sbhd"]: + q, k, v = PermuteToGroupedTensor.apply(q, k, v, qkv_layout) + qkv_contiguous_block = True + else: + if q_format not in ["bhsd", "htd"]: + q, _ = permute_to_grouped_tensor(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) + + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" + # check shapes + original_shapes = [x.shape for x in [q, k, v]] + s_q, d_qk = q.shape[-2:] + s_kv, d_v = v.shape[-2:] + assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( + "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" + f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." + ) + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + # quantize q, k, v + # if qkv_contiguous_block: + # if d_qk == d_v: + # first_dims = torch.tensor( + # [q.shape[0], k.shape[0], v.shape[0]], dtype=torch.int64, device=q.device + # ) + # qkv_2d = torch.cat([q, k, v], dim=0) + # grouped_tensor = tex.group_quantize(qkv_2d, qkv_quantizer, 3, first_dims) + # quantized_tensors = grouped_tensor.split_into_quantized_tensors() + # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] + # else: + # first_dims = torch.tensor([q.shape[0], k.shape[0]], dtype=torch.int64, device=q.device) + # qk_2d = torch.cat([q, k], dim=0) + # grouped_tensor = tex.group_quantize(qk_2d, qkv_quantizer, 2, first_dims) + # q_fp8, k_fp8 = grouped_tensor.split_into_quantized_tensors() + # v_fp8 = qkv_quantizer(v) + # else: + # input_tensors = [q, k, v] + # num_tensors = len(input_tensors) + # shapes = [x.shape for x in input_tensors] + # grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + # num_tensors=num_tensors, + # shapes=shapes, + # quantizer=qkv_quantizer, + # device="cuda", + # dtype=q.dtype, + # ) + # quantized_tensors = grouped_tensor.quantize(input_tensors) + # q_fp8, k_fp8, v_fp8 = quantized_tensors[0], quantized_tensors[1], quantized_tensors[2] + # else: + # q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and used_in_backward: + q_fp8, k_fp8, v_fp8 = [qkv_quantizer(x) for x in [q, k, v]] + if used_in_forward and not used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = False + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = False + qkv_quantizer.columnwise_usage = True + v_fp8 = qkv_quantizer(v) + if (not used_in_forward) and used_in_backward: + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = True + q_fp8, k_fp8 = [qkv_quantizer(x) for x in [q, k]] + qkv_quantizer.rowwise_usage = True + qkv_quantizer.columnwise_usage = False + v_fp8 = qkv_quantizer(v) + + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv + q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + + return q_fp8, k_fp8, v_fp8, qkv_layout + qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") @@ -2234,24 +2485,28 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout def combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None ): """Combine q,k,v based on qkv_layout and dequantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensorStorage) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..afc4622b22 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -795,15 +795,31 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) - # Proj Gemm: match DPA output except for Float8CurrentScaling + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + + # QKV Gemm: do not produce FP8 output when fp8_mha = True if + # 1. RoPE is on: RoPE is only implemented in F16 currently + # 2. FP8CS recipe: due to cuBLAS limitation, FP8CS Gemms can not produce FP8 output + # 3. MXFP8 recipe: QKV Gemm produces QKV in bs(hd), sb(hd), t(hd) shapes, quantization of which would be along + # s/b/t and (hd) dimensions, whereas MXFP8 attention requires quantization along s and d, e.g. bhsd, sbhd, thd + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) + # DPA: produce FP8 output to take advantage of O amax from DPA; Projection Gemm can take FP8 or F16 inputs + # 1. FP8DS/FP8CS recipe: produce FP8 output + # 2. MXFP8 recipe: produce F16 output; again, due to quantization dimensions mismatch + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) and not mxfp8_scaling + # Projection Gemm: match DPA output except + # 1. FP8CS recipe: produce F16 grads; again, due to cuBLAS limitation proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 06bfb6ef3c..f086c4bcd0 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,6 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -70,6 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { @@ -134,6 +136,7 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -203,6 +206,8 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -251,7 +256,7 @@ def fused_attn_fwd( M: torch.Tensor max(Q*K.T) shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor + ZInv: torch.Tensor, only allocated for T3HD path 1/sum(e^(x - max(x))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen @@ -302,17 +307,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if o_quantizer is None: - raise ValueError( - "o_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -326,6 +320,7 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -415,7 +410,6 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -427,6 +421,9 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + do_format: str = "sbhd", + dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -465,8 +462,6 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -493,6 +488,15 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + do_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -553,29 +557,6 @@ def fused_attn_bwd( f" for backend={fused_attention_backend}." ) - if fused_attention_backend == FusedAttnBackend["FP8"]: - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dp_quantizer is None: - raise ValueError( - "dp_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dqkv_dtype is None: - raise ValueError( - "dqkv_dtype is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if len(aux_ctx_tensors) != 3: - raise ValueError( - "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," - f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" - f" (backend={fused_attention_backend})." - ) - output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -583,6 +564,9 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[do_format], + QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -597,7 +581,6 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9d2513835c..fee2c9771f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -301,6 +301,14 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); + std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4bc744e7e..2ecef7d79d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -84,7 +84,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -98,11 +98,12 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -111,6 +112,12 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +std::tuple permute_to_grouped_tensor_fwd( + at::Tensor query, at::Tensor key, at::Tensor value, NVTE_QKV_Layout input_layout); +std::tuple permute_to_grouped_tensor_bwd( + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, + NVTE_QKV_Layout input_layout); + at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index ff60bb87bb..e3c25b396a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -91,6 +91,21 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK(!data.has_value(), + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -98,7 +113,7 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -134,8 +149,13 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + size_t h = 0, d = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, nullptr, &h, nullptr, &d, + nullptr); const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -146,9 +166,7 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -156,7 +174,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -235,9 +253,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -260,7 +278,7 @@ std::vector fused_attn_fwd( // f16_arbitrary: // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + // fp8 : M [b, h, sq, 1], optional ZInv [b, h, sq, 1] (T3HD path), rng_state [2] size_t i = 0; at::Tensor output_tensor; // intermediate softmax tensor, S or M (for fp8) @@ -268,8 +286,10 @@ std::vector fused_attn_fwd( allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // fp8 T3HD has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor + if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || + return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -295,9 +315,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory @@ -310,11 +330,12 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -343,25 +364,35 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - + const DType dqkv_fake_dtype = GetTransformerEngineDType(fake_dtype); + size_t h_q = 0, h_kv = 0, d_qk = 0, d_v = 0; + size_t ndim_q = q_shape.size(); + size_t ndim_kv = k_shape.size(); + std::vector dQ_shape(ndim_q), dK_shape(ndim_kv), dV_shape(ndim_kv); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); + NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); + nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, nullptr, &h_q, nullptr, &d_qk, + nullptr); + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, nullptr, &h_kv, nullptr, + nullptr, nullptr); + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, nullptr, nullptr, nullptr, &d_v, + nullptr); at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - std::vector tmp_shape; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + // FP16/BF16: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + // FP8DS: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.uint8 + // FP8CS/MXFP8: dqkv_fake_dtype = kFloat16/kBFloat16, dQ/dK/dV.dtype = torch.float16/torch.bfloat16 + auto options = torch::TensorOptions().dtype(fake_dtype).device(torch::kCUDA); + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { - options = options.dtype(fake_dtype); - } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -378,7 +409,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -392,9 +423,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -407,9 +438,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -420,27 +451,29 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, dqkv_fake_dtype, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, dqkv_fake_dtype, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, dqkv_fake_dtype, true, dV); // construct NVTE tensors - if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + if (detail::IsFloat8Quantizers(dqkv_quantizer.ptr())) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD)) { if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + ((h_kv * d_v) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && + dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -450,8 +483,10 @@ std::vector fused_attn_bwd( dV.fill_(0); } } - } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + } else if (dqkv_quantizer.is_none() || + detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { + if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); @@ -538,9 +573,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -555,9 +590,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers @@ -614,6 +649,119 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } +std::tuple permute_to_grouped_tensor_fwd( + at::Tensor query, at::Tensor key, at::Tensor value, NVTE_QKV_Layout original_layout) { + NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, + "permute_to_grouped_tensor_fwd: original_layout must be NVTE_SBHD_SBHD_SBHD or " + "NVTE_BSHD_BSHD_BSHD."); + NVTE_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda()); + NVTE_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous()); + NVTE_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4); + NVTE_CHECK(query.scalar_type() == at::ScalarType::Half || + query.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(key.scalar_type() == query.scalar_type() && + value.scalar_type() == query.scalar_type()); + + int64_t B = 0; + int64_t S_q = 0, H_q = 0, D_qk = 0; + int64_t S_kv = 0, H_kv = 0, D_v = 0; + if (original_layout == NVTE_SBHD_SBHD_SBHD) { + S_q = query.size(0); + B = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); + S_kv = key.size(0); + H_kv = key.size(2); + D_v = value.size(3); + } else { + B = query.size(0); + S_q = query.size(1); + H_q = query.size(2); + D_qk = query.size(3); + S_kv = key.size(1); + H_kv = key.size(2); + D_v = value.size(3); + } + NVTE_CHECK(key.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B && + value.size(original_layout == NVTE_SBHD_SBHD_SBHD ? 1 : 0) == B, + "permute_to_grouped_tensor_fwd: Q/K/V batch dimension must match."); + + const int64_t numel_q = B * H_q * S_q * D_qk; + const int64_t numel_k = B * H_kv * S_kv * D_qk; + const int64_t numel_v = B * H_kv * S_kv * D_v; + at::Tensor qkv_out_flat = at::empty({numel_q + numel_k + numel_v}, query.options()); + at::Tensor q_out = qkv_out_flat.narrow(0, 0, numel_q).view({B, H_q, S_q, D_qk}); + at::Tensor k_out = qkv_out_flat.narrow(0, numel_q, numel_k).view({B, H_kv, S_kv, D_qk}); + at::Tensor v_out = qkv_out_flat.narrow(0, numel_q + numel_k, numel_v).view({B, H_kv, S_kv, D_v}); + + auto te_q = makeTransformerEngineTensor(query); + auto te_k = makeTransformerEngineTensor(key); + auto te_v = makeTransformerEngineTensor(value); + auto te_qo = makeTransformerEngineTensor(q_out); + auto te_ko = makeTransformerEngineTensor(k_out); + auto te_vo = makeTransformerEngineTensor(v_out); + + nvte_permute_to_grouped_tensor_fwd(te_q.data(), te_k.data(), te_v.data(), te_qo.data(), + te_ko.data(), te_vo.data(), original_layout, + at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + +std::tuple permute_to_grouped_tensor_bwd( + at::Tensor query_grad, at::Tensor key_grad, at::Tensor value_grad, + NVTE_QKV_Layout original_layout) { + NVTE_CHECK(original_layout == NVTE_SBHD_SBHD_SBHD || original_layout == NVTE_BSHD_BSHD_BSHD, + "permute_to_grouped_tensor_bwd: original_layout must be NVTE_SBHD_SBHD_SBHD or " + "NVTE_BSHD_BSHD_BSHD."); + NVTE_CHECK(query_grad.is_cuda() && key_grad.is_cuda() && value_grad.is_cuda()); + NVTE_CHECK(query_grad.is_contiguous() && key_grad.is_contiguous() && value_grad.is_contiguous()); + NVTE_CHECK(query_grad.dim() == 4 && key_grad.dim() == 4 && value_grad.dim() == 4); + NVTE_CHECK(query_grad.scalar_type() == at::ScalarType::Half || + query_grad.scalar_type() == at::ScalarType::BFloat16); + NVTE_CHECK(key_grad.scalar_type() == query_grad.scalar_type() && + value_grad.scalar_type() == query_grad.scalar_type()); + + const int64_t B = query_grad.size(0); + const int64_t H_q = query_grad.size(1); + const int64_t S_q = query_grad.size(2); + const int64_t D_qk = query_grad.size(3); + const int64_t H_kv = key_grad.size(1); + const int64_t S_kv = key_grad.size(2); + const int64_t D_v = value_grad.size(3); + + const int64_t numel_q = S_q * B * H_q * D_qk; + const int64_t numel_k = S_kv * B * H_kv * D_qk; + const int64_t numel_v = S_kv * B * H_kv * D_v; + at::Tensor qkv_grad_flat = at::empty({numel_q + numel_k + numel_v}, query_grad.options()); + + at::Tensor query; + at::Tensor key; + at::Tensor value; + if (original_layout == NVTE_SBHD_SBHD_SBHD) { + query = qkv_grad_flat.narrow(0, 0, numel_q).view({S_q, B, H_q, D_qk}); + key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({S_kv, B, H_kv, D_qk}); + value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({S_kv, B, H_kv, D_v}); + } else { + query = qkv_grad_flat.narrow(0, 0, numel_q).view({B, S_q, H_q, D_qk}); + key = qkv_grad_flat.narrow(0, numel_q, numel_k).view({B, S_kv, H_kv, D_qk}); + value = qkv_grad_flat.narrow(0, numel_q + numel_k, numel_v).view({B, S_kv, H_kv, D_v}); + } + + auto te_gq = makeTransformerEngineTensor(query_grad); + auto te_gk = makeTransformerEngineTensor(key_grad); + auto te_gv = makeTransformerEngineTensor(value_grad); + auto te_q = makeTransformerEngineTensor(query); + auto te_k = makeTransformerEngineTensor(key); + auto te_v = makeTransformerEngineTensor(value); + + nvte_permute_to_grouped_tensor_bwd(te_gq.data(), te_gk.data(), te_gv.data(), te_q.data(), + te_k.data(), te_v.data(), original_layout, + at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(query, key, value); +} + /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 18da5d0e9f..7f32892015 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -399,6 +399,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", py::call_guard()); + m.def("permute_to_grouped_tensor_fwd", + &transformer_engine::pytorch::permute_to_grouped_tensor_fwd, + "Permute Q, K, V to grouped tensors.", py::arg("query"), py::arg("key"), py::arg("value"), + py::arg("original_layout"), py::call_guard()); + m.def( + "permute_to_grouped_tensor_bwd", &transformer_engine::pytorch::permute_to_grouped_tensor_bwd, + "Permute Q, K, V back to original layout.", py::arg("query_grad"), py::arg("key_grad"), + py::arg("value_grad"), py::arg("original_layout"), py::call_guard()); m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd, "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..9610880093 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1445,6 +1445,22 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data) { + // static std::once_flag once; + // static at::Tensor amax_tensor; + // std::call_once(once, []() { + // amax_tensor = at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + // }); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); + // out_cpp.set_amax(amax_tensor.data_ptr(), GetTransformerEngineDType(amax_tensor.scalar_type()), + // getTensorShape(amax_tensor)); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index ff1c78f695..34f6cb0a59 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -634,7 +634,7 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = quantizer.get_scale_shape(s, True) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -871,15 +871,25 @@ def split_into_quantized_tensors( # populate scale_inv_offsets from the tensor offsets if self.scale_inv is not None and self.scale_inv_offsets is None: - if recipe.nvfp4(): - self.scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + scale_inv_offsets: List[int] = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + cum += math.prod(scale_shape) + scale_inv_offsets.append(cum) + self.scale_inv_offsets = scale_inv_offsets if self.columnwise_scale_inv is not None and self.columnwise_scale_inv_offsets is None: - if recipe.nvfp4(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 16 - if recipe.mxfp8(): - self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 + if recipe.nvfp4() or recipe.mxfp8() or recipe.float8_block_scaling(): + cum = 0 + columnwise_scale_inv_offsets: List[int] = [0] + for i in range(self.num_tensors): + tensor_shape = self.tensor_shapes[i] + scale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + cum += math.prod(scale_shape) + columnwise_scale_inv_offsets.append(cum) + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets for i in range(self.num_tensors): quantizer = self.quantizer