diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..8c7646c00d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,8 @@ url = https://github.com/google/googletest.git [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend - url = https://github.com/NVIDIA/cudnn-frontend.git + url = https://gitlab-master.nvidia.com/cudnn/cudnn_frontend.git + branch = develop [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..69432369f3 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit 69432369f3060467c72d01bd08cfeb9271178c22 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..a5804c6888 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -19,8 +19,14 @@ DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, +) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -180,19 +186,23 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + deterministic="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # When is_training is False, gradient outputs are None. 
is_training = is_training == "True" - + if deterministic == "True": + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + else: + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" # set up environment variables and config fp8_bwd = fp8_bwd == "True" and dtype == "fp8" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" fp8_dpa = fp8_dpa == "True" and dtype == "fp8" fp8_mha = fp8_mha == "True" and dtype == "fp8" - f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + f16_O = dtype == "fp8" and scaling_mode in ["current", "mxfp8"] and f16_O == "True" os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -247,6 +257,10 @@ def run_dpa_with_cp( fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.HYBRID, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( @@ -302,10 +316,25 @@ def run_dpa_with_cp( fp8_dtype=tex.DType.kFloat8E5M2, device="cuda", ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False qkv_layout = "_".join([qkv_format] * 3) q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: - q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + if fp8_mha and scaling_mode != "mxfp8": + q, k, v, qkv_layout = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) for x in [q, k, v]: x.requires_grad = True @@ -351,12 +380,12 @@ 
def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out, max_logit = out if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -412,8 +441,8 @@ def run_dpa_with_cp( qkv_quantizer.amax.fill_(0.0) dout_quantizer.scale.fill_(1.0) dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if fp8_mha and scaling_mode != "mxfp8": + q_, k_, v_, qkv_layout = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) if is_training: q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: @@ -468,12 +497,12 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # fp8_output=fp8_mha, ) if config.return_max_logit: out_, max_logit_ = out_ if is_training: - if fp8_bwd and fp8_mha: + if fp8_bwd and fp8_mha and scaling_mode != "mxfp8": dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: @@ -502,7 +531,7 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for i, tensor in enumerate(tensors): # dbias/dbias_ could be None, so skip check for it if tensor is not None: assert torch.all(~torch.isnan(tensor)) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 60ade522e3..e73e16eb3d 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1803,23 +1803,38 @@ def get_model(dtype, config): return outputs +attn_mask_type = "causal" +# attn_mask_type = "no_mask" +# attn_mask_type = "causal_bottom_right" model_configs_fp8_vs_f16 = { # 
test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), - "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_9": ModelConfig( + 2, + 4096, + 128, + 192, + head_dim_v=128, + attn_mask_type=attn_mask_type, + ), + "fp8_10": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type=attn_mask_type), + "fp8_11": ModelConfig(2, 8192, 32, 128, attn_mask_type=attn_mask_type, window_size=(128, 0)), + "fp8_12": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type=attn_mask_type), + "fp8_13": ModelConfig(2, 8192, 64, 64, attn_mask_type=attn_mask_type, window_size=(128, 0)), + # "fp8_14": ModelConfig( + # 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + # ), + # "fp8_15": ModelConfig( + # 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + # ), + # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + # "fp8_18": ModelConfig(2, 2048, 16, 128, 
attn_mask_type="padding_causal"), + # "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } -param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] +param_types_fp8_vs_f16 = [torch.bfloat16] # [torch.float16, torch.bfloat16] qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] @@ -1833,7 +1848,7 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_mha_fp8_vs_f16( dtype, model, @@ -1864,6 +1879,12 @@ def test_mha_fp8_vs_f16( fp8_dpa=True, fp8_mha=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=True, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2047,7 +2068,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: hidden_states.requires_grad = True tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, @@ -2083,7 +2103,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", ["delayed", "current", "mxfp8"]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test 
DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -2115,6 +2135,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8_format=recipe.Format.HYBRID, fp8_dpa=True, ) + elif scaling_mode == "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling( + fp8_format=recipe.Format.E4M3, + fp8_dpa=True, + fp8_mha=False, + ) fp8_meta = {} fp8_meta["recipe"] = fp8_recipe available_backends, _, _ = get_available_attention_backends( @@ -2275,7 +2301,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim_qk, + (config.head_dim_qk, config.head_dim_v), num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -2285,6 +2311,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type="self", qkv_format=qkv_format, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: dpa = dpa.eval() @@ -2320,7 +2347,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim_qk, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -2336,6 +2364,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] if config.dropout_p == 0.0: tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") @@ -2360,6 +2392,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = 
qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") @@ -2370,6 +2403,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: inp[1], inp[2], qkv_format=qkv_format, + window_size=config.window_size, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=config.max_seqlen_q, @@ -2377,7 +2411,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ecd0090a3b..94e009e377 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -17,6 +17,8 @@ from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, + MXFP8BlockScaling, + Format, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils @@ -26,6 +28,12 @@ pytest_logging_level = logging.getLevelName(logging.root.level) +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) @@ -39,13 +47,11 @@ "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_0": ModelConfig(2, 4096, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_1": 
ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA - "cp_2_2": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) - ), # GQA + "cp_2_2": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA - "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_0": ModelConfig(2, 4096, 128, 192, attn_mask_type="causal", head_dim_v=128), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 @@ -73,10 +79,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"] + configs = ["cp_2_0", "cp_3_0", "cp_2_2"] # , "cp_1_2", "cp_2_1"]#, "cp_1_1", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] - qkv_formats = ["sbhd", "thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @@ -94,25 +100,34 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config.context_parallel = True config.cp_comm_type = cp_comm_type - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - 
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if ( + config.window_size != (-1, 0) + and config.window_size != (-1, -1) + and cp_comm_type + in [ + "p2p", + "a2a+p2p", + ] + ): + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 + ): pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" 
) + + # FlashAttention / CP implementation specific: MLA only with KV P2P if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} @@ -150,9 +165,18 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 4096, 12, 128, attn_bias_type="post_scale_bias", bias_shape="bhss" ), # MHA - "cp_1_5": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA - "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA - "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_1_5": ModelConfig(2, 4096, 32, 128, attn_mask_type="causal", window_size=(128, 0)), # MHA + "cp_2_0": ModelConfig( + 2, + 4096, + 32, + 128, + num_gqa_groups=4, + attn_mask_type="causal", + ), # GQA + "cp_2_1": ModelConfig( + 2, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal" + ), # num_gqa_groups=4, attn_mask_type="causal"), # GQA "cp_2_2": ModelConfig( 2, 4096, @@ -190,7 +214,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA - "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA "cp_3_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA @@ -215,21 +239,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: configs = [ - "cp_1_0", - "cp_1_1", - "cp_1_4", + # "cp_1_0", + # "cp_1_1", + # "cp_1_4", "cp_1_5", "cp_2_0", - "cp_2_2", - "cp_2_3", - "cp_2_4", - "cp_3_2", - "cp_3_4", - "cp_4_2", + "cp_2_1", + # "cp_2_2", + 
# "cp_2_3", + # "cp_2_4", + # "cp_3_1", + # "cp_3_2", + # "cp_3_4", + # "cp_4_2", ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] - qkv_formats = ["sbhd", "thd"] + qkv_formats = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @@ -241,96 +267,89 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("fp8_bwd", [True, False]) @pytest.mark.parametrize("fp8_mha", [True, False]) @pytest.mark.parametrize("fp8_dpa", [True, False]) -@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O ): + config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") - if qkv_format == "thd" and get_device_compute_capability() < (9, 0): - pytest.skip("THD format is only supported on sm90+!") - if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): - pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") - if dtype == "fp8" and get_device_compute_capability() < (9, 0): - pytest.skip("FP8 attention is only supported on sm90+!") + if get_device_compute_capability() < (9, 0) and qkv_format == "thd": + pytest.skip("Only sm90+ architectures support THD format!") + if get_device_compute_capability() < (9, 0) and dtype == "fp8": + pytest.skip("Only sm90+ architectures support FP8 attention!") + + if dtype 
== "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("dtype=fp8 requires fp8_dpa=True or fp8_mha=True!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: - pytest.skip("Only fp8 works with fp8_bwd=True!") - - config = model_configs_fused_attn[model] - config.context_parallel = True - config.cp_comm_type = cp_comm_type + pytest.skip("fp8_bwd=True requires dtype=fp8!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("dtype!=fp8 requires fp8_dpa=False and fp8_mha=False!") - if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": - pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if cp_comm_type == "a2a+p2p": - pytest.skip( - "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" - " yet!" - ) - if dtype == "fp8" and cp_comm_type == "all_gather": - pytest.skip( - "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" 
- ) if dtype == "fp8" and qkv_format == "thd": - pytest.skip("FP8 attention cannot work with THD format yet!") + pytest.skip("No support for FP8 attention with THD format!") if dtype == "fp8" and config.attn_bias_type != "no_bias": - pytest.skip("FP8 attention cannot work with bias yet!") - if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("FP8 attention cannot work with sliding window yet!") - if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): - pytest.skip("CP implementation with KV P2P does not support sliding window yet!") - if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with KV all-gather does not support bias yet!") - if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": - pytest.skip("CP implementation with QKVO A2A does not support bias yet!") - if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0): - pytest.skip( - f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" - f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
- ) - if dtype != "fp8" and (fp8_mha or fp8_dpa): - pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") - if dtype == "fp8" and not (fp8_mha or fp8_dpa): - pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") - if dtype != "fp8" and scaling_mode is not None: - pytest.skip("Only fp8 works with scaling_mode != None!") - if dtype == "fp8" and scaling_mode is None: - pytest.skip("fp8 only works with scaling_mode != None!") + pytest.skip("No support for FP8 attention with bias!") + + if config.attn_bias_type != "no_bias" and qkv_format == "thd": + pytest.skip("No support for bias with THD format!") + if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: + pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") + + if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + + if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [ + "p2p", + "a2a+p2p", + ]: + pytest.skip("No support for SWA with cp_comm_type={p2p, a2a+p2p}!") + + # TODO: Remove this once the issue is fixed! 
if ( dtype == "fp8" - and scaling_mode == "current" - and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + and (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) + and cp_comm_type == "all_gather" + ): + pytest.skip("No support for SWA with FP8 attention and cp_comm_type=all_gather!") + + if cp_comm_type in ["a2a", "a2a+p2p"] and ( + config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0 ): - pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") - if f16_O and (dtype != "fp8" or scaling_mode != "current"): - pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") - if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently only support KV P2P!") - if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: - pytest.skip("MLA CP currently does not support FP8 attention!") - if dtype == "fp8" and config.softmax_type != "vanilla": - pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") - if config.softmax_type != "vanilla" and cp_comm_type != "a2a": pytest.skip( - "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + f"cp_comm_type=a2a requires num_heads ({config.num_heads}) and" + f" num_gqa_groups ({config.num_gqa_groups}) divisible by 2!" ) + + if config.softmax_type != "vanilla" and dtype == "fp8": + pytest.skip("No support for non-vanilla softmax with FP8 attention!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip(f"No support for non-vanilla softmax with cp_comm_type={cp_comm_type}!") if ( - get_cudnn_version() < (9, 18, 0) - and config.softmax_type != "vanilla" + config.softmax_type != "vanilla" and qkv_format == "thd" + and get_cudnn_version() < (9, 18, 0) ): - pytest.skip( - "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" - " non-vanilla softmax types!" 
- ) + pytest.skip("No support for non-vanilla softmax with THD format and cuDNN < 9.18.0!") + + if dtype == "fp8" and scaling_mode is None: + pytest.skip("dtype=fp8 requires scaling_mode != None!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("dtype!=fp8 requires scaling_mode = None!") + if dtype != "fp8" and not f16_O: + pytest.skip("dtype!=fp8 requires f16_O=True!") + if scaling_mode == "delayed" and f16_O: + pytest.skip("scaling_mode=delayed requires f16_O=False!") + if scaling_mode == "mxfp8" and not f16_O: + pytest.skip("scaling_mode=mxfp8 requires f16_O=True!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -354,6 +373,12 @@ def test_cp_with_fused_attention( Float8CurrentScaling(fp8_dpa=True), DelayedScaling(fp8_dpa=True), ] + if fp8 and scaling_mode == "mxfp8": + fp8_meta["recipe"] = MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True) + fp8_meta["local_recipes"] = [ + MXFP8BlockScaling(fp8_format=Format.HYBRID, fp8_dpa=True), + ] + # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. 
is_training = False if config.bias_shape == "111s" else True available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -363,6 +388,7 @@ def test_cp_with_fused_attention( fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -382,6 +408,7 @@ def test_cp_with_fused_attention( scaling_mode=scaling_mode, f16_O=f16_O, is_training=is_training, + deterministic=_deterministic, log_level=pytest_logging_level, ), check=True, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 317240fb78..1747d75676 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -177,6 +177,7 @@ def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + # rmse_tol = rmse_tol * 1.1 assert rmse < rmse_tol * rmse_range, ( name_a + " vs " diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6a136c67e4..b5ecde6750 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -131,6 +131,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Layout_Group::NVTE_SD_SD_SD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_layout_group."); @@ -172,6 +174,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: case 
NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_format."); @@ -192,6 +196,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2BSHD: case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_q_format."); @@ -212,12 +218,93 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Format::NVTE_BHSD: + return NVTE_QKV_Format::NVTE_BHSD; default: NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_kv_format."); } } +// map one NVTE_QKV_Format to another +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t) { + size_t _b = 0, _h = 0, _s = 0, _d = 0, _t = 0; + switch (src_format) { + case NVTE_QKV_Format::NVTE_BSHD: + _b = src_shape[0]; + _s = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_SBHD: + _s = src_shape[0]; + _b = src_shape[1]; + _h = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_BHSD: + _b = src_shape[0]; + _h = src_shape[1]; + _s = src_shape[2]; + _d = src_shape[3]; + break; + case NVTE_QKV_Format::NVTE_THD: + _t = src_shape[0]; + _h = src_shape[1]; + _d = src_shape[2]; + break; + default: + NVTE_ERROR("src_format not supported!"); + break; + } + switch (dst_format) { + case NVTE_QKV_Format::NVTE_BSHD: + dst_shape[0] = _b; + 
dst_shape[1] = _s; + dst_shape[2] = _h; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_SBHD: + dst_shape[0] = _s; + dst_shape[1] = _b; + dst_shape[2] = _h; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_BHSD: + dst_shape[0] = _b; + dst_shape[1] = _h; + dst_shape[2] = _s; + dst_shape[3] = _d; + break; + case NVTE_QKV_Format::NVTE_THD: + dst_shape[0] = _t; + dst_shape[1] = _h; + dst_shape[2] = _d; + break; + default: + NVTE_ERROR("dst_format not supported!"); + break; + } + + if (b != nullptr) { + *b = _b; + } + if (h != nullptr) { + *h = _h; + } + if (s != nullptr) { + *s = _s; + } + if (d != nullptr) { + *d = _d; + } + if (t != nullptr) { + *t = _t; + } +} + // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, @@ -269,9 +356,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: mxfp8, d_qk=128, d_v=192 + (cudnn_runtime_version >= 92100 && head_dim_qk <= 192 && head_dim_v <= 128)) && + !requires_64bit_ragged_offset && + // pre-9.21: softmax_type=vanilla, 9.21+: softmax_type={vanilla, off-by-one, learnable} + ((cudnn_runtime_version < 92100 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || + cudnn_runtime_version >= 92100) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { @@ -410,12 +501,15 @@ NVTE_Fused_Attn_Backend 
nvte_get_fused_attn_backend( bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || + qkv_format == NVTE_QKV_Format::NVTE_BHSD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || cudnn_runtime_version >= 90600)) || ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + q_format == NVTE_QKV_Format::NVTE_BHSD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + kv_format == NVTE_QKV_Format::NVTE_BHSD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && cudnn_runtime_version >= 90700)) && // sliding window @@ -426,7 +520,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && window_size_right == -1 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + ((window_size_left == -1 || window_size_left >= 0) && + (window_size_right == -1 || window_size_right >= 0) && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && @@ -535,19 +630,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, 
- size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -566,23 +659,22 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, 
s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, &s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } + int64_t num_pages_k = 0; int64_t num_pages_v = 0; int64_t page_size_k = 0; @@ -646,10 +738,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, + attn_scale, dropout, qkv_layout, o_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + input_Q, input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. 
\n"); #endif @@ -666,11 +760,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -691,22 +786,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); - auto ndim = input_Q->data.shape.size(); - auto ndim_kv = input_K->data.shape.size(); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim_kv - 2]; - size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim_kv - 1]; - size_t t_q = 0; - size_t t_kv = 0; + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector tmp_shape(4); + nvte_convert_qkv_format(q_format, input_Q->data.shape, q_format, tmp_shape, &b, &h_q, 
&s_q, &d_qk, + &t_q); + nvte_convert_qkv_format(kv_format, input_K->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, input_V->data.shape, kv_format, tmp_shape, &b, &h_kv, &s_kv, + &d_v, &t_kv); if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_K->data.shape[0]; + b = input_cu_seqlens_q->data.shape[0] - 1; + } else if (kv_format == NVTE_QKV_Format::NVTE_THD) { + b = input_cu_seqlens_kv->data.shape[0] - 1; } auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); @@ -755,14 +848,25 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, - output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + size_t i = 0; + const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + const Tensor *input_dO_f16 = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + const Tensor *input_SoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + 
input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, + qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, + input_M, input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, + output_dQ, output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 80e64370f9..6c3d9a8161 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1652,16 +1652,20 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, - void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, - void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + bool is_training, float 
scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, + void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + NVTEScalingMode scaling_mode, void* workspace, size_t* workspace_size, cudaStream_t stream, + cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -1669,19 +1673,25 @@ void fused_attn_fp8_fwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || - o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode 
== NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); try { FADescriptor_v1 descriptor{b, @@ -1689,8 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -1704,13 +1714,13 @@ void fused_attn_fp8_fwd_impl_v1( scaling_factor, is_training, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, true, qkv_tensor_type, o_tensor_type, @@ -1736,6 +1746,7 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // amax_o std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -1762,31 +1773,30 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr descale_q, descale_k, descale_v; std::shared_ptr descale_s, scale_s, scale_o; - std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr bias, softmax_offset, seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - 
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + // Q, K, V, attn_scale + std::vector q_strides(4); + std::vector k_strides(4); + std::vector v_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -1794,21 +1804,58 @@ void fused_attn_fp8_fwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - - if (is_delayed_scaling) { - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + // Descale_q, Descale_k, Descale_v, Descale_s, 
Scale_s, Scale_o + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_o"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } } - if (is_current_scaling) { - scale_o = mha_graph->tensor(1.0f); + if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + std::vector q_scale_strides(4); + std::vector k_scale_strides(4); + std::vector v_scale_strides(4); + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_v_padded, + v_scale_strides.data(), kv_format); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_v_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_attributes sdpa_options; @@ -1818,6 +1865,18 @@ void fused_attn_fp8_fwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } + // sdpa_options.set_alibi_mask(is_alibi); // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -1855,19 +1914,41 @@ void fused_attn_fp8_fwd_impl_v1( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( - Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_options.set_sink_token(softmax_offset); + } - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); - amax_o->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); + std::shared_ptr O, Stats, amax_s, amax_o; + if 
(is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_s = outputs[2]; + amax_o = outputs[3]; + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } else if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, sdpa_options); + O = outputs[0]; + Stats = outputs[1]; + amax_o = outputs[2]; + } - amax_s->set_output(true) + std::vector o_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + O->set_output(true) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) + .set_data_type(o_tensor_type); + amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); @@ -1890,10 +1971,15 @@ void fused_attn_fp8_fwd_impl_v1( std::shared_ptr, // O std::shared_ptr, // amax_s std::shared_ptr> // amax_o - key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, - scale_s, scale_o, attn_scale, O, amax_s, amax_o); + key_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, nullptr, nullptr, + nullptr, attn_scale, O, nullptr, amax_o) + : std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) @@ -1904,17 +1990,17 @@ void fused_attn_fp8_fwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, seq_q, seq_kv, dropout_seed, dropout_offset] = - get_graph(sdpa_fp8_fprop_cache, descriptor); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, dropout_seed, + dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -1937,17 +2023,19 @@ void fused_attn_fp8_fwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_s, devPtrDescaleS}, - {scale_s, devPtrScaleS}, {attn_scale, &scaling_factor}, {O, devPtrO}, - {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; if (is_delayed_scaling) { variant_pack[scale_o] = devPtrScaleO; } + if (!is_mxfp8) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[amax_s] = devPtrAmaxS; + } /* if (is_bias) { variant_pack[bias] = devPtrBias; @@ -1972,6 +2060,10 @@ void fused_attn_fp8_fwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) 
{ NVTE_ERROR(e.what()); @@ -1980,20 +2072,26 @@ void fused_attn_fp8_fwd_impl_v1( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl_v1( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, - void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, + void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, + void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, + void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, + void* 
devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, + void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, + void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -2001,20 +2099,26 @@ void fused_attn_fp8_bwd_impl_v1( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); auto bias_b = b; auto bias_h = h; - const auto cudnn_runtime_version = cudnnGetVersion(); auto bias_sq = s_q; auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || - dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); - bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + bool is_delayed_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == 
cudnn_frontend::DataType_t::FP8_E4M3 || dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); - NVTE_CHECK(is_current_scaling || is_delayed_scaling, - "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " - "kFloat8E5M2!"); + bool is_current_scaling = (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_mxfp8 = (scaling_mode == NVTE_MXFP8_1D_SCALING) && + (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + NVTE_CHECK( + is_delayed_scaling || is_current_scaling || is_mxfp8, + "FP8 fused attention only supports FP8DelayedScaling or FP8CurrentScaling or MXFP8 recipes!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -2024,8 +2128,8 @@ void fused_attn_fp8_bwd_impl_v1( hg, s_q, s_kv, - d, - d, + d_qk, + d_v, 0, 0, 0, @@ -2039,13 +2143,13 @@ void fused_attn_fp8_bwd_impl_v1( scaling_factor, true, dropout_probability, - layout, + qkv_layout, bias_type, mask_type, - NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, - 0, - 0, - true, + softmax_type, + window_size_left, + window_size_right, + bottom_right_diagonal, deterministic, qkv_tensor_type, o_tensor_type, @@ -2056,18 +2160,25 @@ void fused_attn_fp8_bwd_impl_v1( namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, - std::shared_ptr, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::shared_ptr, // Q + std::shared_ptr, // Q_t + std::shared_ptr, // K + std::shared_ptr, // K_t + std::shared_ptr, // V + std::shared_ptr, // O + std::shared_ptr, // Stats std::shared_ptr, // dO + std::shared_ptr, // dO_t + std::shared_ptr, // dO_f16 std::shared_ptr, // attn_scale std::shared_ptr, // descale_q + std::shared_ptr, // descale_q_t std::shared_ptr, // 
descale_k + std::shared_ptr, // descale_k_t std::shared_ptr, // descale_v std::shared_ptr, // descale_o std::shared_ptr, // descale_dO + std::shared_ptr, // descale_dO_t std::shared_ptr, // descale_s std::shared_ptr, // descale_dP std::shared_ptr, // scale_dQ @@ -2084,6 +2195,8 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dP std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // dropout_seed @@ -2108,54 +2221,58 @@ void fused_attn_fp8_bwd_impl_v1( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr Q, Q_t, K, K_t, V, O, dO, dO_t, dO_f16, Stats, + attn_scale; + std::shared_ptr descale_q, descale_q_t, descale_k, descale_k_t, + descale_v; std::shared_ptr descale_s, descale_o; - std::shared_ptr descale_dP, descale_dO; + std::shared_ptr descale_dP, descale_dO, descale_dO_t; std::shared_ptr scale_s, scale_dP; std::shared_ptr scale_dQ, scale_dK, scale_dV; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset; + std::shared_ptr seq_q, seq_kv; std::shared_ptr dropout_seed, dropout_offset; - std::vector q_stride(4); - std::vector k_stride(4); - std::vector v_stride(4); - std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() + // Q, K, V, O, dO, stats, attn_scale 
+ std::vector q_strides(4); + std::vector k_strides(4); + std::vector v_strides(4); + std::vector o_strides(4); + std::vector dO_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, q_strides.data(), + k_strides.data(), v_strides.data(), qkv_layout); + generateMatrixStridesWithFormat(b, h, s_q, d_v, o_strides.data(), o_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_strides.data(), d_out_format); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_strides) + .set_data_type(qkv_tensor_type)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_strides) + .set_data_type(qkv_tensor_type)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_strides) + .set_data_type(qkv_tensor_type)); + O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride) + .set_dim({b, h, s_q, d_v}) + .set_stride(o_strides) .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(do_tensor_type)); + Stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") .set_dim({b, h, s_q, 1}) .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); - attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -2163,33 +2280,138 @@ void fused_attn_fp8_bwd_impl_v1( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Descale_q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); - descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); - descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); - if (is_O_in_F16) { - descale_o = mha_graph->tensor(1.0f); - } else { - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + // Descale_q, Descale_k, Descale_v, Descale_s, Scale_s, Descale_dP, Scale_dP, Descale_o, Descale_dO, Scale_dQ, Scale_dK, Scale_dV + if (is_delayed_scaling || is_current_scaling) { + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_v"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_s"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_s"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + if (is_current_scaling && is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = 
mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } } - descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); - scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - - if (is_delayed_scaling) { - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); - } - if (is_current_scaling) { - scale_dQ = mha_graph->tensor(1.0f); - scale_dK = mha_graph->tensor(1.0f); - scale_dV = mha_graph->tensor(1.0f); + if (is_mxfp8) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + // Q_t, K_t, dO_t, dO_f16 + std::vector q_t_strides(4); + std::vector k_t_strides(4); + std::vector dO_t_strides(4); + generateMatrixStridesWithFormat(b, h, s_q, d_qk, q_t_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, s_kv, d_qk, k_t_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, s_q, d_v, dO_t_strides.data(), d_out_format); + Q_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q_t") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_t_strides) + .set_data_type(qkv_tensor_type)); + K_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K_t") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_t_strides) + .set_data_type(qkv_tensor_type)); + dO_t = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_t") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_t_strides) + .set_data_type(do_tensor_type)); + dO_f16 = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO_f16") + .set_dim({b, h, s_q, d_v}) + .set_stride(dO_strides) + .set_data_type(o_tensor_type)); + // Descale_q, Descale_q_t, Descale_k, Descale_k_t, Descale_v, Descale_dO, Descale_dO_t + auto padded = pad_s_d_for_mxfp8(s_q, s_kv, d_qk, d_v); + std::vector q_scale_strides(4); + std::vector 
q_t_scale_strides(4); + std::vector k_scale_strides(4); + std::vector k_t_scale_strides(4); + std::vector v_scale_strides(4); + std::vector dO_scale_strides(4); + std::vector dO_t_scale_strides(4); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_qk_scale_padded, + q_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_qk_padded, + q_t_scale_strides.data(), q_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_qk_scale_padded, + k_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_scale_padded, padded.d_qk_padded, + k_t_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, hg, padded.s_kv_padded, padded.d_v_scale_padded, + v_scale_strides.data(), kv_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_padded, padded.d_v_scale_padded, + dO_scale_strides.data(), d_out_format); + generateMatrixStridesWithFormat(b, h, padded.s_q_scale_padded, padded.d_v_padded, + dO_t_scale_strides.data(), d_out_format); + descale_q = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({b, h, padded.s_q_padded, padded.d_qk_scale_padded}) + .set_stride(q_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_q_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_qk_padded}) + .set_stride(q_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k") + .set_dim({b, hg, padded.s_kv_padded, padded.d_qk_scale_padded}) + .set_stride(k_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_k_t = + 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_k_t") + .set_dim({b, hg, padded.s_kv_scale_padded, padded.d_qk_padded}) + .set_stride(k_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_v") + .set_dim({b, hg, padded.s_kv_padded, padded.d_v_scale_padded}) + .set_stride(v_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO") + .set_dim({b, h, padded.s_q_padded, padded.d_v_scale_padded}) + .set_stride(dO_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); + descale_dO_t = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_dO_t") + .set_dim({b, h, padded.s_q_scale_padded, padded.d_v_padded}) + .set_stride(dO_t_scale_strides) + .set_data_type(fe::DataType_t::FP8_E8M0) + .set_reordering_type(fe::TensorReordering_t::F8_128x4)); } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; @@ -2198,6 +2420,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_causal_mask(is_causal) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const& diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); + } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } + // sdpa_backward_options.set_alibi_mask(is_alibi); // if (is_bias) { @@ -2251,13 +2485,63 @@ void fused_attn_fp8_bwd_impl_v1( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( - q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, - descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, sdpa_backward_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + // sdpa_backward_options.set_dsink_token(d_softmax_offset); + } - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + std::shared_ptr dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP; + if (is_delayed_scaling || is_current_scaling) { + auto outputs = mha_graph->sdpa_fp8_backward(Q, K, V, O, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, + scale_dV, scale_dP, 
sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + amax_dP = outputs[6]; + } + if (is_mxfp8) { + auto outputs = mha_graph->sdpa_fp8_backward( + Q, Q_t, K, K_t, V, O, dO_f16, dO, dO_t, Stats, descale_q, descale_q_t, descale_k, + descale_k_t, descale_v, descale_dO, descale_dO_t, sdpa_backward_options); + dQ = outputs[0]; + dK = outputs[1]; + dV = outputs[2]; + amax_dQ = outputs[3]; + amax_dK = outputs[4]; + amax_dV = outputs[5]; + } + std::vector dq_strides(4); + std::vector dk_strides(4); + std::vector dv_strides(4); + generateMatrixStridesWithLayout(b, h, hg, s_q, s_kv, d_qk, d_v, dq_strides.data(), + dk_strides.data(), dv_strides.data(), dqkv_layout); + dQ->set_output(true) + .set_dim({b, h, s_q, d_qk}) + .set_stride(dq_strides) + .set_data_type(dqkv_tensor_type); + dK->set_output(true) + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(dk_strides) + .set_data_type(dqkv_tensor_type); + dV->set_output(true) + .set_dim({b, hg, s_kv, d_v}) + .set_stride(dv_strides) + .set_data_type(dqkv_tensor_type); amax_dQ->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -2270,21 +2554,18 @@ void fused_attn_fp8_bwd_impl_v1( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true) - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT); - - dO->set_data_type(do_tensor_type); - dQ->set_data_type(dqkv_tensor_type); - dK->set_data_type(dqkv_tensor_type); - dV->set_data_type(dqkv_tensor_type); + if (is_delayed_scaling || is_current_scaling) { + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + } - std::tuple, // q - std::shared_ptr, // k - std::shared_ptr, // v - std::shared_ptr, // o - std::shared_ptr, // stats + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // O + 
std::shared_ptr, // Stats std::shared_ptr, // dO std::shared_ptr, // attn_scale std::shared_ptr, // descale_q @@ -2307,10 +2588,16 @@ void fused_attn_fp8_bwd_impl_v1( std::shared_ptr, // amax_dV std::shared_ptr> // amax_dP key_tensors_tuple = std::make_tuple( - q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP); + auto mxfp8_tensors_tuple = + is_mxfp8 ? std::make_tuple(Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) @@ -2322,17 +2609,18 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, + bias_tuple, softmax_offset_tuple, padding_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - - auto [mha_graph, q, k, v, o, stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, + auto [mha_graph, Q, K, V, O, Stats, dO, attn_scale, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, - dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, bias, dBias, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); + dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, + descale_k_t, descale_dO_t, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, + dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -2349,37 +2637,47 @@ void fused_attn_fp8_bwd_impl_v1( // build variant pack std::unordered_map, void*> variant_pack = { - {q, devPtrQ}, - {k, devPtrK}, - {v, devPtrV}, - {o, devPtrO}, - {stats, devPtrM}, + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {Stats, devPtrM}, {dO, devPtrdO}, {attn_scale, &scaling_factor}, {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, {descale_dO, devPtrDescaledO}, - {descale_s, devPtrDescaleS}, - {descale_dP, devPtrDescaledP}, - {scale_s, devPtrScaleS}, - {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV}, {amax_dQ, devPtrAmaxdQ}, {amax_dK, 
devPtrAmaxdK}, {amax_dV, devPtrAmaxdV}, - {amax_dP, devPtrAmaxdP}, }; - + if (is_delayed_scaling || is_current_scaling) { + variant_pack[descale_s] = devPtrDescaleS; + variant_pack[descale_dP] = devPtrDescaledP; + variant_pack[scale_s] = devPtrScaleS; + variant_pack[scale_dP] = devPtrScaledP; + variant_pack[amax_dP] = devPtrAmaxdP; + } + if (is_delayed_scaling || (is_current_scaling && !is_O_in_F16)) { + variant_pack[descale_o] = devPtrDescaleO; + } if (is_delayed_scaling) { variant_pack[scale_dQ] = devPtrScaledQ; variant_pack[scale_dK] = devPtrScaledK; variant_pack[scale_dV] = devPtrScaledV; } - if (!is_O_in_F16) { - variant_pack[descale_o] = devPtrDescaleO; + if (is_mxfp8) { + variant_pack[Q_t] = devPtrQ_t; + variant_pack[K_t] = devPtrK_t; + variant_pack[dO_f16] = devPtrdO_f16; + variant_pack[dO_t] = devPtrdO_t; + variant_pack[descale_q_t] = devPtrDescaleQ_t; + variant_pack[descale_k_t] = devPtrDescaleK_t; + variant_pack[descale_dO_t] = devPtrDescaledO_t; } /* if (is_bias) { @@ -2410,6 +2708,11 @@ void fused_attn_fp8_bwd_impl_v1( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -2423,33 +2726,50 @@ void fused_attn_fp8_bwd_impl_v1( #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, Tensor* input_output_S, Tensor* output_O, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t 
head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, + const Tensor* input_K, const Tensor* input_V, + const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - void* devPtrQ = input_Q->data.dptr; - void* devPtrK = input_K->data.dptr; - void* devPtrV = input_V->data.dptr; - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - + void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; + void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; + void *devPtrO = nullptr, *devPtrAmaxO = nullptr, *devPtrScaleO = nullptr; + void *devPtrAmaxS = nullptr, *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr; + devPtrQ = input_Q->data.dptr; + devPtrDescaleQ = input_Q->scale_inv.dptr; + devPtrK = input_K->data.dptr; + devPtrDescaleK = input_K->scale_inv.dptr; + devPtrO = output_O->data.dptr; + devPtrAmaxO = output_O->amax.dptr; + if (input_Q->scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + devPtrV = input_V->data.dptr; + devPtrDescaleV = input_V->scale_inv.dptr; + devPtrScaleO = output_O->scale.dptr; + devPtrAmaxS = input_output_S->amax.dptr; + devPtrScaleS = input_output_S->scale.dptr; + devPtrDescaleS = input_output_S->scale_inv.dptr; + } else if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrV = 
input_V->columnwise_data.dptr; + devPtrDescaleV = input_V->columnwise_scale_inv.dptr; + } + void* devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; @@ -2459,21 +2779,29 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 3) { + int i = 0; + Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_ZInv = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; devPtrZInv = output_ZInv->data.dptr; output_rng_state->data.dptr = rng_state->data.dptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor* output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = @@ -2488,17 +2816,20 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, 
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, @@ -2521,24 +2852,34 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } } // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, - const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { +void fused_attn_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t 
max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, + const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, + const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, + const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, + Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, + const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; void* devPtrV = input_V->data.dptr; void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_Q->scale_inv.dptr; - void* devPtrDescaleV = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_K->scale_inv.dptr; + void* devPtrDescaleV = input_V->scale_inv.dptr; + void *devPtrQ_t = nullptr, *devPtrK_t = nullptr, *devPtrDescaleQ_t = nullptr, + *devPtrDescaleK_t = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrQ_t = input_Q->columnwise_data.dptr; + devPtrDescaleQ_t = input_Q->columnwise_scale_inv.dptr; + devPtrK_t = input_K->columnwise_data.dptr; + devPtrDescaleK_t = input_K->columnwise_scale_inv.dptr; + } void* devPtrO = input_O->data.dptr; const DType O_type = input_O->data.dtype; @@ -2548,6 +2889,12 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; + 
void *devPtrdO_t = nullptr, *devPtrdO_f16 = nullptr, *devPtrDescaledO_t = nullptr; + if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { + devPtrdO_t = input_dO->columnwise_data.dptr; + devPtrdO_f16 = input_dO_f16->data.dptr; + devPtrDescaledO_t = input_dO->columnwise_scale_inv.dptr; + } void* devPtrM = input_M->data.dptr; void* devPtrZInv = input_ZInv->data.dptr; @@ -2558,6 +2905,13 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrScaledP = input_output_dP->scale.dptr; void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + void* devPtrSoftmaxOffset = nullptr; + void* devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } + void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; @@ -2582,21 +2936,25 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); + if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, deterministic, devPtrQ, devPtrK, devPtrV, - devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + 
mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, + devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, + devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + input_dO->scaling_mode, workspace->data.dptr, &workspace_size, stream, handle); + } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 225e700eff..617efa8f42 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,26 +15,31 @@ namespace 
transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, + size_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, + const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V -void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, bool deterministic, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); +void fused_attn_fp8_bwd( + 
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, + bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, + const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // end of CUDNN>=8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a897b09330..f37eeb0c68 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -293,6 +293,27 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * s_q * d; + strideA[head_dim_idx] = s_q * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_dim_idx] = d; + strideA[hidden_dim_idx] = 1; 
+ } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * s_kv * d; + strideA[head_dim_idx] = s_kv * d; + strideA[seqlen_transpose_dim_idx] = d; + strideA[hidden_transpose_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 08a56cda6b..fdda4dfe9c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -27,11 +27,269 @@ enum NVTE_QKV_Matrix { NVTE_K_Matrix = 1, // keys NVTE_K_Matrix_Transpose = 2, // keys transposed NVTE_V_Matrix = 3, // values - NVTE_V_Matrix_Transpose = 4, // value matrix transposed + NVTE_V_Matrix_Transpose = 4, // values transposed NVTE_S_Matrix = 5, // output of GEMM1 NVTE_O_Matrix = 6, // final output }; +// Padded sizes for MXFP8 layout (s_q/s_kv/d_qk/d_v and their scaled dimensions) +struct MXFP8PaddedSizes { + int64_t s_q_padded; + int64_t s_kv_padded; + int64_t s_q_scale; + int64_t s_kv_scale; + int64_t s_q_scale_padded; + int64_t s_kv_scale_padded; + int64_t d_qk_padded; + int64_t d_v_padded; + int64_t d_qk_scale; + int64_t d_v_scale; + int64_t d_qk_scale_padded; + int64_t d_v_scale_padded; +}; + +// Pad s and d for MXFP8 quantization +inline MXFP8PaddedSizes pad_s_d_for_mxfp8(int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v) { + constexpr int64_t block_size = 32; + MXFP8PaddedSizes p; + p.s_q_padded = ((s_q + 127) / 128) * 128; + p.s_kv_padded = ((s_kv + 127) / 128) * 128; + p.s_q_scale = (s_q + block_size - 1) / block_size; + p.s_kv_scale = (s_kv + block_size - 1) / block_size; + p.s_q_scale_padded = ((p.s_q_scale + 3) / 4) * 4; + p.s_kv_scale_padded = ((p.s_kv_scale + 3) / 4) * 4; + p.d_qk_padded = ((d_qk + 127) / 128) * 128; + p.d_v_padded = ((d_v + 127) / 128) * 128; + p.d_qk_scale = (d_qk + block_size - 1) / block_size; + p.d_v_scale = (d_v + 
block_size - 1) / block_size; + p.d_qk_scale_padded = ((p.d_qk_scale + 3) / 4) * 4; + p.d_v_scale_padded = ((p.d_v_scale + 3) / 4) * 4; + return p; +} + +// Get matrix strides for a 4D tensor [batch_size, num_heads, sequence_len, head_dim] given a QKV format. +// strides must point to at least 4 int64_t elements. +inline void generateMatrixStridesWithFormat(int64_t b, int64_t h, int64_t s, int64_t d, + int64_t *strides, NVTE_QKV_Format format) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + + switch (format) { + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_THD: + strides[b_dim] = s * h * d; + strides[h_dim] = d; + strides[s_dim] = h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_SBHD: + strides[b_dim] = h * d; + strides[h_dim] = d; + strides[s_dim] = b * h * d; + strides[d_dim] = 1; + break; + case NVTE_QKV_Format::NVTE_BHSD: + strides[b_dim] = h * s * d; + strides[h_dim] = s * d; + strides[s_dim] = d; + strides[d_dim] = 1; + break; + default: + NVTE_CHECK(false, "Invalid format."); + break; + } +} + +// get matrix strides based on layout and matrix type +inline void generateMatrixStridesWithLayout(int64_t b, int64_t h, int64_t hg, int64_t s_q, + int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t *q_strides, int64_t *k_strides, + int64_t *v_strides, NVTE_QKV_Layout layout) { + constexpr int b_dim = 0; + constexpr int h_dim = 1; + constexpr int s_dim = 2; + constexpr int d_dim = 3; + + switch (layout) { + case NVTE_QKV_Layout::NVTE_SB3HD: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBH3D: + q_strides[b_dim] = 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = b * 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = 
v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = b * 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = b * hg * d_v; + v_strides[d_dim] = 1; + break; + case NVTE_QKV_Layout::NVTE_BS3HD: + case NVTE_QKV_Layout::NVTE_T3HD: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSH3D: + case NVTE_QKV_Layout::NVTE_TH3D: + q_strides[b_dim] = s_q * 3 * h * d_qk; + q_strides[h_dim] = 3 * d_qk; + q_strides[s_dim] = 3 * h * d_qk; + q_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + k_strides[i] = v_strides[i] = q_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: + case NVTE_QKV_Layout::NVTE_THD_T2HD: + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + 
q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: + case NVTE_QKV_Layout::NVTE_THD_TH2D: + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * 2 * hg * d_qk; + k_strides[h_dim] = 2 * d_qk; + k_strides[s_dim] = 2 * hg * d_qk; + k_strides[d_dim] = 1; + for (int i = 0; i < 4; i++) { + v_strides[i] = k_strides[i]; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = s_kv * hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = hg * d_v; + v_strides[d_dim] = 1; + break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + q_strides[b_dim] = h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = b * h * d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = s_kv * hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = s_kv * hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = hg * d_v; + v_strides[d_dim] = 1; + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + q_strides[b_dim] = s_q * h * d_qk; + q_strides[h_dim] = d_qk; + q_strides[s_dim] = h * 
d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * d_qk; + k_strides[h_dim] = d_qk; + k_strides[s_dim] = b * hg * d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * d_v; + v_strides[h_dim] = d_v; + v_strides[s_dim] = b * hg * d_v; + v_strides[d_dim] = 1; + break; + case NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD: + q_strides[b_dim] = h * s_q * d_qk; + q_strides[h_dim] = s_q * d_qk; + q_strides[s_dim] = d_qk; + q_strides[d_dim] = 1; + k_strides[b_dim] = hg * s_kv * d_qk; + k_strides[h_dim] = s_kv * d_qk; + k_strides[s_dim] = d_qk; + k_strides[d_dim] = 1; + v_strides[b_dim] = hg * s_kv * d_v; + v_strides[h_dim] = s_kv * d_v; + v_strides[s_dim] = d_v; + v_strides[d_dim] = 1; + break; + default: + NVTE_CHECK(false, "Invalid layout."); + break; + } +} + void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8169bf22e2..04f7ec4a6c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -52,6 +52,7 @@ enum NVTE_QKV_Layout { NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ + NVTE_BHSD_BHSD_BHSD = 25, /*!< BHSD_BHSD_BHSD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -70,6 +71,8 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_HD_HD = 4, /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ NVTE_Paged_KV_HD_HD_HD = 5, + /*! SD_SD_SD QKV layouts, e.g. BHSD_BHSD_BHSD */ + NVTE_SD_SD_SD = 6, }; /*! \enum NVTE_QKV_Format @@ -90,6 +93,8 @@ enum NVTE_QKV_Format { NVTE_THD_2BSHD = 5, /*! THD format for Q and SBHD format for KV, i.e. 
THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ NVTE_THD_2SBHD = 6, + /*! BHSD QKV format, e.g. BHSD_BHSD_BHSD */ + NVTE_BHSD = 7, }; /*! \enum NVTE_Bias_Type @@ -188,6 +193,22 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Convert one NVTE_QKV_Format to another. + * + * \param[in] src_format The source format. + * \param[in] src_shape The source shape. + * \param[in] dst_format The destination format. + * \param[out] dst_shape The destination shape. + * \param[out] b The batch size. + * \param[out] h The number of heads. + * \param[out] s The sequence length. + * \param[out] d The head dimension. + * \param[out] t The number of tokens. + */ +void nvte_convert_qkv_format(NVTE_QKV_Format src_format, std::vector src_shape, + NVTE_QKV_Format dst_format, std::vector &dst_shape, size_t *b, + size_t *h, size_t *s, size_t *d, size_t *t); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] is_training Whether the model is in training mode. @@ -274,6 +295,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -283,19 +305,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -347,6 +367,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. 
* \param[in] qkv_layout QKV tensors' layout. + * \param[in] o_format Output format. + * \param[in] d_out_format Output gradient's format. + * \param[in] dqkv_layout QKV gradient tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. * \param[in] softmax_type Attention softmax type. @@ -366,11 +389,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format d_out_format, NVTE_QKV_Layout dqkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6adba23a8f..96e6803ec5 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -48,7 +48,8 @@ .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ - .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD) \ + .value("NVTE_BHSD", NVTE_QKV_Format::NVTE_BHSD); \ pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -74,7 +75,8 @@ .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ - .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ + .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..5ab3ff2507 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,12 +29,14 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) 
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8TensorStorage from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -152,6 +154,16 @@ # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +_run_shadow_f16_fwd = os.getenv("NVTE_RUN_SHADOW_F16_FWD", "0") == "1" +_replace_out_return_with_shadow_f16 = ( + os.getenv("NVTE_REPLACE_OUT_RETURN_WITH_SHADOW_F16", "0") == "1" +) +_replace_out_save_with_shadow_f16 = os.getenv("NVTE_REPLACE_OUT_SAVE_WITH_SHADOW_F16", "0") == "1" +_replace_aux_with_shadow_f16 = os.getenv("NVTE_REPLACE_AUX_WITH_SHADOW_F16", "0") == "1" +_run_shadow_f16_bwd = os.getenv("NVTE_RUN_SHADOW_F16_BWD", "0") == "1" +_replace_dq_with_shadow_f16 = os.getenv("NVTE_REPLACE_DQ_WITH_SHADOW_F16", "0") == "1" +_replace_dk_with_shadow_f16 = os.getenv("NVTE_REPLACE_DK_WITH_SHADOW_F16", "0") == "1" +_replace_dv_with_shadow_f16 = os.getenv("NVTE_REPLACE_DV_WITH_SHADOW_F16", "0") == "1" class FP8EmulationFunc(torch.autograd.Function): @@ -173,15 +185,23 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] ] - q_fp8, k_fp8, v_fp8 = combine_and_quantize( + # sbhd_sbhd_sbhd should always be the shape at this point + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( qkv_layout, query_layer, key_layer, value_layer, quantizer ) tensors = combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype ) + if isinstance(quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd should always be the shape at this point 
+ # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] elif quantizer_name in ["S_quantizer", "O_quantizer"]: - t_fp8 = quantizer(tensor1) - tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + if quantizer is not None: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) else: tensors = (tensor1, tensor2, tensor3) ctx.quantizer = quantizer @@ -193,16 +213,23 @@ def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layou def backward(ctx, grad1, grad2, grad3): # pylint: disable=missing-function-docstring if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: - dt_fp8 = ctx.quantizer(grad1) - tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + if ctx.quantizer is not None: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + else: + tensors = grad1, grad2, grad3 elif ctx.quantizer_name == "dQKV_quantizer": query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] - dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + dq_fp8, dk_fp8, dv_fp8, ctx.qkv_layout = combine_and_quantize( ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer ) tensors = combine_and_dequantize( ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype ) + if isinstance(ctx.quantizer, MXFP8Quantizer): + # bhsd_bhsd_bhsd should always be the shape at this point + # permute back to sbhd_sbhd_sbhd + tensors = [x.permute(2, 0, 1, 3).contiguous() for x in tensors] else: tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None @@ -379,6 +406,7 @@ def forward( query_layer.shape[0], key_layer.shape[0], ) + apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 if "padding" in attn_mask_type and attention_mask is None: attention_mask = 
dpa_utils.get_padding_mask( @@ -405,10 +433,7 @@ def forward( ) ) - batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] - apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 - - # [b, np, sq, sk] + # [b, h, sq, sk] output_size = ( query_layer.size(1), query_layer.size(2), @@ -427,12 +452,7 @@ def forward( int(query_layer.shape[2] / value_layer.shape[2]), dim=2 ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting result tensor: [b * np, sq, sk] + # preallocting result tensor: [b * h, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], @@ -446,14 +466,15 @@ def forward( scale /= self.layer_number if fp8: + # get fp8 recipe for DPA + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=S_quantizer.dtype, device="cuda" @@ -461,25 +482,50 @@ def forward( dP_quantizer = Float8CurrentScalingQuantizer( fp8_dtype=dP_quantizer.dtype, device="cuda" ) + # disable swizzle for MXFP8Quantizer + for q in [ + QKV_quantizer, + O_quantizer, + 
S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ]: + if isinstance(q, MXFP8Quantizer): + q.optimize_for_gemm = False + q.internal = False - if "2" in qkv_layout or "3" in qkv_layout: - qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) - qkv_layout = "_".join([qkv_format] * 3) + # q, k, v are in sbhd after previous reshaping # quantize and dequantize QKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + QKV_quantizer, + "QKV_quantizer", + "sbhd_sbhd_sbhd", ) # quantize and dequantize dQKV to emulate FP8 query_layer, key_layer, value_layer = FP8EmulationFunc.apply( - query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + query_layer, + key_layer, + value_layer, + dQKV_quantizer, + "dQKV_quantizer", + "sbhd_sbhd_sbhd", ) - # Raw attention scores. [b * np, sq, sk] + # [sq, b, h, d] -> [sq, b * h, d] + query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, h, d] -> [sk, b * h, d] + key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) + + # Raw attention scores. [b * h, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ).view(*output_size) @@ -487,8 +533,8 @@ def forward( elif core_attention_bias_type == "pre_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" 
matmul_result = torch.bmm( - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] ) matmul_result = matmul_result.view(*output_size) + core_attention_bias matmul_result *= scale @@ -513,8 +559,8 @@ def forward( ) matmul_result = torch.baddbmm( matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * h, sq, d] + key_layer.transpose(0, 1).transpose(1, 2), # [b * h, d, sk] beta=0.0, alpha=scale, ) @@ -531,13 +577,13 @@ def forward( # max attention score max_logit = None if self.return_max_logit: - # matmul_result [b, np, sq, dk], max_logit [np] + # matmul_result [b, h, sq, dk], max_logit [h] max_logit = matmul_result if attn_mask_type != "no_mask": max_logit = self.mask_func(matmul_result, attention_mask) max_logit = torch.amax(max_logit, dim=(0, 2, 3)) - # add attention sink to the last column: [b, np, sq, sk+1] + # add attention sink to the last column: [b, h, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( [ @@ -562,7 +608,7 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) - # remove attention sink: [b, np, sq, sk] + # remove attention sink: [b, h, sq, sk] if self.softmax_type != "vanilla": attention_probs = attention_probs[..., :-1] @@ -572,7 +618,7 @@ def forward( attention_probs = self.attention_dropout(attention_probs) # value_layer -> context layer. 
- # [sk, b, np, hn] --> [b, np, sq, hn] + # [sk, b, h, d] --> [b, h, sq, d] output_size = ( value_layer.size(1), value_layer.size(2), @@ -580,10 +626,10 @@ def forward( value_layer.size(3), ) - # change view [sk, b * np, hn] + # change view [sk, b * h, d] value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] + # change view [b * h, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) if fp8: @@ -592,37 +638,37 @@ def forward( attention_probs, None, None, S_quantizer, "S_quantizer", None ) - # matmul: [b * np, sq, hn] + # matmul: [b * h, sq, d] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] + # change view [b, h, sq, d] context_layer = context_layer.view(*output_size) if q_format == "sbhd": - # [b, np, sq, hn] --> [sq, b, np, hn] + # [b, h, sq, d] --> [sq, b, h, d] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(seqlen, batch_size, -1) + # [sq, b, h, d] --> [sq, b, hd] + context_layer = context_layer.view(max_seqlen_q, batch_size, -1) if q_format == "bshd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] - context_layer = context_layer.view(batch_size, seqlen, -1) + # [b, sq, h, d] --> [b, sq, hd] + context_layer = context_layer.view(batch_size, max_seqlen_q, -1) if q_format == "thd": - # [b, np, sq, hn] --> [b, sq, np, hn] + # [b, h, sq, d] --> [b, sq, h, d] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # [b, sq, np, hn] --> [tq, np, hn] + # [b, sq, h, d] --> [tq, h, d] context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, ) - # [tq, np, hn] --> [tq, hp] + # [tq, h, d] --> [tq, hd] context_layer = context_layer.view(context_layer.shape[0], -1) if 
fp8: @@ -1198,21 +1244,26 @@ def forward( if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + # qkv_layout may change due to MXFP8 quantization + # o_format should stay the same as original qkv_format + original_qkv_layout = qkv_layout + _, o_format, _ = dpa_utils.get_qkv_format(qkv_layout) + # input types are inferred from the real data while output types are controlled by fp8_output # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output - # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) - # whether bwd kernel in FP8: + # whether fwd kernel will be run in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel will be run in FP8: is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) # get nominal data type for out @@ -1221,16 +1272,21 @@ def forward( out_nominal_dtype = q.dtype max_logit = None + orig_q, orig_k, orig_v = q, k, v + orig_qkv_layout = qkv_layout if fp8: fused_attention_backend = FusedAttnBackend["FP8"] # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E4M3 + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; + # dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = 
tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) # print quantizers print_quantizers( @@ -1248,6 +1304,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, @@ -1270,6 +1327,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1280,20 +1338,84 @@ def forward( cuda_graph=is_graph_capturing(), ) - # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + if _run_shadow_f16_fwd: + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + assert all( + x.dtype in [torch.float16, torch.bfloat16] for x in [q, k, v] + ), "q, k, v must be torch.float16 or torch.bfloat16" + out_f16_, aux_ctx_tensors_f16, *_ = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + orig_q, + orig_k, + orig_v, + out_nominal_dtype, + FusedAttnBackend["F16_arbitrary_seqlen"], + attn_bias, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + None, # s_quantizer + None, # o_quantizer + attn_scale, + dropout_p, + fast_zero_fill, + orig_qkv_layout, + o_format, + attn_bias_type, + attn_mask_type, + softmax_type, + window_size, + bottom_right_diagonal, + rng_gen, + softmax_offset, + return_max_logit, + is_graph_capturing(), + ) + if torch.cuda.current_device() == 0: + print( + f"L{layer_number}: real/shadow out min:" + f" {out_.min():.4f}/{out_f16_.min():.4f}, max:" + f" {out_.max():.4f}/{out_f16_.max():.4f}" + ) + print( + f"L{layer_number}: 
real/shadow stats min:" + f" {aux_ctx_tensors[0].min():.4f}/{aux_ctx_tensors_f16[0].min():.4f}, max:" + f" {aux_ctx_tensors[0].max():.4f}/{aux_ctx_tensors_f16[0].max():.4f}" + ) + + # out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_fp8 = out_ out = out_ - - if isinstance(out_, Float8Tensor): - if not is_output_fp8 or not is_bwd_fp8: + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if isinstance(out_, QuantizedTensorStorage): + if not is_output_fp8 or bwd_requires_o_f16: out = out_.dequantize().view(out_.shape) else: - if is_output_fp8 or ( - is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): + if is_output_fp8 or bwd_requires_o_fp8: out_fp8 = O_quantizer(out_) # print quantizers @@ -1310,19 +1432,50 @@ def forward( # return appropriate tensors out_ret = out_fp8 if is_output_fp8 else out + if _run_shadow_f16_fwd and _replace_out_return_with_shadow_f16: + out_ret = out_f16_ + if _run_shadow_f16_fwd and _replace_aux_with_shadow_f16: + aux_ctx_tensors[0] = aux_ctx_tensors_f16[0] # save appropriate tensors fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) if is_bwd_fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): fp8_tensors = (q_fp8, k_fp8, v_fp8, None) qkvo_tensors = (None, None, None, out) - else: + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + if 
_run_shadow_f16_bwd: + qkvo_tensors = (q, k, v, out) else: if is_input_fp8: q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + if _run_shadow_f16_fwd and not _replace_aux_with_shadow_f16: + tmp_quantizer = QKV_quantizer.copy() + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, _, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer + ) + q_ = q_fp8_.dequantize(dtype=out_nominal_dtype) + k_ = k_fp8_.dequantize(dtype=out_nominal_dtype) + if isinstance(tmp_quantizer, MXFP8Quantizer): + qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) + if qkv_format == "bshd": + q = q_.permute(0, 2, 1, 3).contiguous() + k = k_.permute(0, 2, 1, 3).contiguous() + elif qkv_format == "sbhd": + q = q_.permute(2, 0, 1, 3).contiguous() + k = k_.permute(2, 0, 1, 3).contiguous() + else: + q, k = q_, k_ + if _run_shadow_f16_fwd and _replace_out_save_with_shadow_f16: + out = out_f16_ qkvo_tensors = (q, k, v, out) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1348,6 +1501,7 @@ def forward( dropout_p, fast_zero_fill, qkv_layout, + o_format, attn_bias_type, attn_mask_type, softmax_type, @@ -1433,9 +1587,16 @@ def forward( ctx.qkv_layout = reload_layout[:-1] else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout else: ctx.qkv_layout = qkv_layout + if fp8 and not ctx.fp8: + ctx.qkv_layout = original_qkv_layout + ctx.o_format = o_format + # dqkv should have the same layout as the original qkv + ctx.dqkv_layout = original_qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type @@ -1454,14 +1615,21 @@ def forward( @staticmethod def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - - # d_out is expected to be in FP8 if is_output_fp8=True, - # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and 
ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): - d_out = ctx.dO_quantizer(d_out) - if not ctx.use_FAv2_bwd: - d_out._data = d_out._data.contiguous() - elif not ctx.use_FAv2_bwd: + d_out_shadow_f16 = d_out + + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + d_out_fp8 = None + d_out_format = ctx.o_format + if ctx.fp8: + if ctx.fp8_recipe.mxfp8(): + d_out, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, d_out) + if isinstance(d_out, QuantizedTensorStorage): + d_out_fp8 = d_out + else: + d_out_fp8 = ctx.dO_quantizer(d_out) + if not isinstance(d_out, QuantizedTensorStorage) and not ctx.use_FAv2_bwd: d_out = d_out.contiguous() ( q_fp8, @@ -1480,6 +1648,10 @@ def backward(ctx, d_out, *_args): ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) aux_ctx_tensors = other_tensors + aux_ctx_tensors_shadow_f16 = aux_ctx_tensors + out_shadow_f16 = out + original_qkv_layout = ctx.dqkv_layout + original_qkv_format, *_ = dpa_utils.get_qkv_format(original_qkv_layout) if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() @@ -1523,14 +1695,6 @@ def backward(ctx, d_out, *_args): dqkv_nominal_dtype = ctx.nominal_dtype if ctx.fp8: - # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - if ctx.is_output_fp8: - d_out_fp8 = d_out - else: - d_out_fp8 = ctx.dO_quantizer(d_out) - # print quantizers print_quantizers( "FusedAttnFunc.backward >> before: ", @@ -1543,27 +1707,26 @@ def backward(ctx, d_out, *_args): ctx.dP_quantizer, ) - # get tex.DType for dq, dk, dv data - dqkv_te_dtype = d_out_fp8._fp8_dtype - - # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # q_fp8, k_fp8, v_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 
or torch.bfloat16, # fp8_dtype = tex.DType.kFloat8E4M3 - # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor/MXFP8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E5M2 - # out_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # DelayedScaling: + # out_, dq_, dk_, dv_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - # - # dq_, dk_, dv_: - # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 - # fp8_dtype = tex.DType.kFloat8E5M2 - # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_ = ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8 - ) + # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=1: + # out_, dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # Float8CurrentScaling + NVTE_DPA_FP8CS_O_in_F16=0: + # out_: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # dq_, dk_, dv_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # MXFP8BlockScaling: + # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_ = out + if ctx.fp8_recipe.mxfp8(): + out_ = out + aux_ctx_tensors.append(d_out) dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1575,7 +1738,6 @@ def backward(ctx, d_out, *_args): out_, d_out_fp8, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1587,6 +1749,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1595,11 +1760,88 @@ def backward(ctx, d_out, *_args): 
ctx.deterministic, is_graph_capturing(), ) + if _run_shadow_f16_bwd: + original_qkv_layout = ctx.dqkv_layout + tmp_quantizer = ctx.QKV_quantizer.copy() + if isinstance(tmp_quantizer, MXFP8Quantizer): + tmp_quantizer.optimize_for_gemm = False + q_fp8_, k_fp8_, v_fp8_, _ = combine_and_quantize( + original_qkv_layout, q, k, v, tmp_quantizer + ) + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.dequantize(dtype=dqkv_nominal_dtype) for x in (q_fp8_, k_fp8_, v_fp8_) + ] + if isinstance(tmp_quantizer, MXFP8Quantizer): + if original_qkv_format == "bshd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(0, 2, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] + elif original_qkv_format == "sbhd": + q_shadow_f16, k_shadow_f16, v_shadow_f16 = [ + x.permute(2, 0, 1, 3).contiguous() + for x in (q_shadow_f16, k_shadow_f16, v_shadow_f16) + ] + dq_shadow_f16, dk_shadow_f16, dv_shadow_f16, *rest = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_shadow_f16, + k_shadow_f16, + v_shadow_f16, + out_shadow_f16, + d_out_shadow_f16, + dqkv_nominal_dtype, + aux_ctx_tensors_shadow_f16, + FusedAttnBackend["F16_arbitrary_seqlen"], + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + None, + None, + None, + ctx.attn_scale, + ctx.dropout_p, + ctx.fast_zero_fill, + original_qkv_layout, + original_qkv_format, + original_qkv_format, + original_qkv_layout, + ctx.attn_bias_type, + ctx.attn_mask_type, + ctx.softmax_type, + ctx.window_size, + ctx.bottom_right_diagonal, + ctx.deterministic, + is_graph_capturing(), + ) + if _replace_dq_with_shadow_f16: + dq_ = dq_shadow_f16 + if _replace_dk_with_shadow_f16: + dk_ = dk_shadow_f16 + if _replace_dv_with_shadow_f16: + dv_ = dv_shadow_f16 + if torch.cuda.current_device() == 0: + print( + f"L{ctx.layer_number}: real/shadow dq min:" + f" {dq_.min():.4f}/{dq_shadow_f16.min():.4f}, max:" + f" {dq_.max():.4f}/{dq_shadow_f16.max():.4f}" + ) + print( + f"L{ctx.layer_number}: real/shadow 
dk min:" + f" {dk_.min():.4f}/{dk_shadow_f16.min():.4f}, max:" + f" {dk_.max():.4f}/{dk_shadow_f16.max():.4f}" + ) + print( + f"L{ctx.layer_number}: real/shadow dv min:" + f" {dv_.min():.4f}/{dv_shadow_f16.min():.4f}, max:" + f" {dv_.max():.4f}/{dv_shadow_f16.max():.4f}" + ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 dq, dk, dv = dq_, dk_, dv_ - is_float8tensor = isinstance(dq_, Float8Tensor) - if is_float8tensor and not ctx.is_input_fp8: + is_quantized_tensor = isinstance(dq_, QuantizedTensorStorage) + if is_quantized_tensor and not ctx.is_input_fp8: # return in F16 dq, dk, dv = combine_and_dequantize( ctx.qkv_layout, @@ -1608,9 +1850,9 @@ def backward(ctx, d_out, *_args): dv_, src_nominal_dtype=dq_.dtype, ) - if not is_float8tensor and ctx.is_input_fp8: + if not is_quantized_tensor and ctx.is_input_fp8: # return in FP8 - dq, dk, dv = combine_and_quantize( + dq, dk, dv, _ = combine_and_quantize( ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer ) @@ -1628,7 +1870,6 @@ def backward(ctx, d_out, *_args): else: if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) - dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, @@ -1641,7 +1882,6 @@ def backward(ctx, d_out, *_args): out, d_out, dqkv_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1653,6 +1893,9 @@ def backward(ctx, d_out, *_args): ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, + ctx.o_format, + d_out_format, + ctx.dqkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, ctx.softmax_type, @@ -1817,9 +2060,9 @@ def forward( fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), "No fused attention backend supports this input combination!" 
assert all( - x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor) + x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, QuantizedTensorStorage) for x in [query_layer, key_layer, value_layer] - ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors." + ), "FusedAttention only supports FP16 and BF16 data types, or QuantizedTensors." assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "FusedAttention only supports CUDA tensors." @@ -1925,7 +2168,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 10ba99595b..90df998e09 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -59,6 +59,18 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def get_bsh_dims(tensor_format): + """Get batch dimension and sequence dimension from tensor format""" + if tensor_format in ["bshd", "sbhd", "bhsd"]: + batch_dim = tensor_format.index("b") + seq_dim = tensor_format.index("s") + head_dim = tensor_format.index("h") + else: # tensor_format == "thd" + batch_dim = seq_dim = tensor_format.index("t") + head_dim = tensor_format.index("h") + return batch_dim, seq_dim, head_dim + + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -237,10 +249,10 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): def 
reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -251,12 +263,12 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -410,15 +422,32 @@ def flash_attn_a2a_communicate( cp_stream: torch.cuda.Stream, before_attn: bool, qkv_format: str = "bshd", - cu_seqlens_padded: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + a2a_input_names: List[str] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" - - assert ( - qkv_format != "thd" or cu_seqlens_padded is not None - ), "cu_seqlens_padded is required for THD format!" 
+ assert a2a_input_names in [ + ["q", "k", "v"], + ["out"], + ["dout"], + ["dq", "dk", "dv"], + ], "a2a_input_names must be one of ['q', 'k', 'v'], ['out'], ['dout'], ['dq', 'dk', 'dv']!" + if a2a_input_names in [["out"], ["dout"]]: + assert qkv_format != "thd" or cu_seqlens_q_padded is not None, ( + f"flash_attn_a2a_communicate requires cu_seqlens_q_padded for {a2a_input_names} with" + " THD format!" + ) + if a2a_input_names in [["q", "k", "v"], ["dq", "dk", "dv"]]: + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), ( + "flash_attn_a2a_communicate requires cu_seqlens_q_padded and cu_seqlens_kv_padded for" + f" {a2a_input_names} with THD format!" + ) a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + _, _, head_dim = get_bsh_dims(qkv_format) if before_attn: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -430,18 +459,24 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd", "bhsd"]: # reorder the sequence chunks x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] + # [b, h//cp, cp*2, s//2, d] -> [b, h//cp, cp*s, d] a2a_outputs[i - 2] = x.view( *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] ) else: # qkv_format == "thd" - # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn] + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i - 2] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) + # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks 
a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( @@ -450,14 +485,21 @@ def flash_attn_a2a_communicate( if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] - # or [t, np, hn] -> [t, cp, np//cp, hn] - x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] - # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] - a2a_inputs[i] = x.movedim(-3, 0).contiguous() + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # [s, b, h, d] -> [s, b, cp, h//cp, d] + # [b, h, s, d] -> [b, cp, h//cp, s, d] + # [t, h, d] -> [t, cp, h//cp, d] + x = x.view( + *x.shape[:head_dim], + cp_size, + x.shape[head_dim] // cp_size, + *x.shape[head_dim + 1 :], + ) + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # [b, cp, h//cp, s, d] -> [cp, b, h//cp, s, d] + # [t, cp, h//cp, d] -> [cp, t, h//cp, d] + a2a_inputs[i] = x.movedim(head_dim, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): if 0 < i < len(a2a_inputs) + 1: @@ -467,30 +509,57 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - if qkv_format in ["bshd", "sbhd"]: - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + if qkv_format in ["bshd", "sbhd", "bhsd"]: + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] + # [b, h//cp, cp*s, d] -> [b, h//cp, cp*2, s//2, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) else: # qkv_format == "thd" + cu_seqlens_padded = ( + cu_seqlens_q_padded + if a2a_input_names[i] in ["q", "out", "dout", "dq"] + else cu_seqlens_kv_padded + ) # reorder the 
sequence chunks x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) - # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn] + # [cp*t, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] - # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] - x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] - # or [t, cp, np//cp, hn] -> [t, np, hn] + # [cp, 2, b, s//2, h//cp, d] -> [2, b, s//2, cp, h//cp, d] + # [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [cp, 2, b, h//cp, s//2, d] -> [2, b, cp, h//cp, s//2, d] + # [cp, t, h//cp, d] -> [t, cp, h//cp, d] + tmp_list = [x for x in qkv_format] + if "t" not in qkv_format: + tmp_list.insert(0, "2") + tmp_list.insert(0, "c") + tmp_format = "".join(tmp_list) + head_dim_ = tmp_format.index("h") - 1 + tmp_list.insert(head_dim_, tmp_list.pop(0)) + x = x.movedim(0, head_dim_) + # [2, b, s//2, cp, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # [2, s//2, b, cp, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [2, b, cp, h//cp, s//2, d] -> [b, cp, h//cp, 2, s//2, d] + # [t, cp, h//cp, d] -> [t, cp, h//cp, d] + if "t" not in qkv_format: + tmp_format = "".join(tmp_list) + seq_dim_ = tmp_format.index("s") - 1 + tmp_list.insert(seq_dim_, tmp_list.pop(0)) + x = x.movedim(0, seq_dim_) + else: + seq_dim_ = 0 + x = x.contiguous() + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # [b, cp, h//cp, 2, s//2, d] -> [b*h, s, d] + # [t, cp, h//cp, d] -> [t, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else 
a2a_outputs @@ -775,13 +844,16 @@ def cp_p2p_fwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step, O_quantizer_per_step, rank, @@ -867,11 +939,17 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded_ = cu_seqlens_kv_padded fp8_meta_kwargs = {} + new_qkv_layout = qkv_layout if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step @@ -888,7 +966,8 @@ def cp_p2p_fwd_fused_attn( fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs, @@ -1065,15 +1144,19 @@ def cp_p2p_bwd_fused_attn( softmax_scale, dropout_p, qkv_layout, + o_format, + d_out_format, + dqkv_layout, attn_mask_type, attn_bias_type, deterministic, fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, S_quantizer, dP_quantizer_per_step, dQKV_quantizer_per_step, + QKV_quantizer_per_step, + dO_quantizer_per_step, q_part, k_part, v_part, @@ -1123,16 +1206,26 @@ def cp_p2p_bwd_fused_attn( fp8_meta_kwargs = {} if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip( - [q_fp8, kv_fp8, kv_fp8], - [q_part, k_part, v_part], + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + 
Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + else: + q_part, k_part, v_part, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer_per_step ) - ] - if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): - out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) - dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + if not fp8_recipe.mxfp8(): + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + else: + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout_part) + aux_tensors.append(dout_part) + dout_part = dO_quantizer_per_step(dout_part) fp8_meta_kwargs["s_quantizer"] = S_quantizer fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step @@ -1148,7 +1241,6 @@ def cp_p2p_bwd_fused_attn( out_part, dout_part, bwd_nominal_dtype, - bwd_output_te_dtype, aux_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -1156,6 +1248,9 @@ def cp_p2p_bwd_fused_attn( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, + d_out_format=d_out_format, + dqkv_layout=dqkv_layout, attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, @@ -1313,16 +1408,15 @@ def forward( ) # set up attention args - enable_mla = k.shape[-1] != v.shape[-1] - causal = "causal" in attn_mask_type - if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - + causal = "causal" in attn_mask_type + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + 
orig_o_shape = q.shape[:-1] + v.shape[-1:] batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None @@ -1337,13 +1431,10 @@ def forward( else: cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - - fused_attn_backend = None amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] @@ -1352,9 +1443,9 @@ def forward( assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
fwd_nominal_dtype = q.dtype - is_input_fp8 = isinstance(q, Float8Tensor) + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -1362,7 +1453,6 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - ( QKV_quantizer, O_quantizer, @@ -1370,43 +1460,58 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) - q_f16 = None + # q, k, v a2a: gather s and split h + # FP8DS/CS: Float8Tensor -> torch.uint8 -> Float8Tensor + # MXFP8/F16: fwd_nominal_dtype q_fp8, k_fp8, v_fp8 = (None, None, None) - # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if fp8 and is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = (q._data, k._data, v._data) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + [q, k, v], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + True, + qkv_format=qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) - if fp8 and is_input_fp8: + if fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): q_fp8, k_fp8, v_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) ] q, k, v = q_fp8, k_fp8, v_fp8 + post_a2a_o_shape = q.shape[:-1] + v.shape[-1:] # convert qkv to the right type + q_f16 = None + 
fused_attn_backend = None if fp8: assert use_fused_attention, "FP8 is only supported with Fused Attention!" fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - else: + elif not fp8_recipe.mxfp8(): # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=torch.uint8 q_f16 = q - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] # print quantizers @@ -1427,10 +1532,11 @@ def forward( # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + S_quantizer_per_step[i] = S_quantizer.copy() if S_quantizer is not None else None O_quantizer_per_step[i] = O_quantizer.copy() - O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not fp8_recipe.mxfp8(): + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: # q_f16: torch.Tensor, dtype=fwd_nominal_dtype # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype @@ -1482,7 +1588,6 @@ def forward( attn_bias_ = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) @@ -1553,17 +1658,22 @@ def forward( # synchronize fwd results 
correction across steps fwd_results_correction_done = torch.cuda.Event() + # q, k, v, o: + # causal: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # non-causal: [b, s, h, d] or [s, b, h, d] p2p_comm_buffers = [None for _ in range(cp_size)] k_shape = k.shape k_numel = k.numel() v_shape = v.shape + o_shape = q.shape[:-1] + v.shape[-1:] p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] # P2P communication and compute: each rank has cp_size steps - # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype - # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 + # MXFP8/F16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # FP8DS/CS attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None + o_format = qkv_format for i in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1617,13 +1727,16 @@ def forward( softmax_scale, dropout_p, qkv_layout, + o_format, attn_mask_type, attn_bias_type, fp8, + fp8_recipe, q_fp8, k_fp8, v_fp8, fwd_nominal_dtype, + QKV_quantizer, S_quantizer_per_step[i], O_quantizer_per_step[i], rank, @@ -1771,8 +1884,8 @@ def forward( with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, h, sq, 1] -> [b, h, sq] or - # [t, h, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] + # [t, h, 1] -> [t, h] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( @@ -1784,21 +1897,16 @@ def forward( out_per_step[i - 1] = out_per_step[i - 1].dequantize( dtype=torch.float32 ) - if fp8_recipe.float8_current_scaling(): + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": - if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape - ) + if fp8: + out = 
torch.zeros_like(out_per_step[0]).view(o_shape) else: - # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + out = torch.zeros(o_shape, dtype=q.dtype, device=q.device) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -1838,7 +1946,7 @@ def forward( # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: if i == 0: out = flash_attn_fwd_out_correction_init( out_per_step[0], @@ -1846,10 +1954,7 @@ def forward( softmax_lse_per_step[0], seq_dim, ) - if enable_mla: - out = out.view(v_shape) - else: - out = out.view(q.shape) + out = out.view(o_shape) else: flash_attn_fwd_out_correction( out.view(*out_per_step[i].shape), @@ -1858,7 +1963,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1869,7 +1974,7 @@ def forward( softmax_lse_in_packed_format, ) else: - if qkv_format in ["bshd", "sbhd"]: + if o_format in ["bshd", "sbhd"]: flash_attn_fwd_second_half_out_correction( out, out_per_step[i], @@ -1877,7 +1982,7 @@ def forward( softmax_lse_per_step[i], seq_dim, ) - elif qkv_format == "thd": + elif o_format == "thd": tex.thd_out_correction( out, out_per_step[i], @@ -1887,35 +1992,31 @@ def forward( True, softmax_lse_in_packed_format, ) - - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - ctx.batch_size = out.shape[0] - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - ctx.batch_size = out.shape[1] + out = out.view(post_a2a_o_shape) + out_part = out.to(fwd_nominal_dtype) if cp_size_a2a > 1: chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + out, + 
chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + cp_group_a2a, + cp_stream, + False, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + out = out.view(orig_o_shape) if return_max_logit: max_logit = flash_attn_a2a_communicate_softmax_offset( max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False ) - elif not use_fused_attention: - out = out.view(-1, *out.shape[-2:]) # update FP8 quantizers: amax across cp_size steps - if fp8 and use_fused_attention: + if fp8 and use_fused_attention and not fp8_recipe.mxfp8(): amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) O_quantizer.amax.copy_(amax_cp_fwd[1]) @@ -1938,7 +2039,11 @@ def forward( out_f16 = out.to(fwd_nominal_dtype) if fp8 and ( is_output_fp8 - or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + and not fp8_recipe.mxfp8() + ) ): out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 @@ -1949,7 +2054,7 @@ def forward( kv_fp8 = None kv = p2p_comm_buffers[-1] - if fp8: + if fp8 and not fp8_recipe.mxfp8(): q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8], [q, kv]) @@ -1957,17 +2062,28 @@ def forward( # q, kv, out fp8_tensors = (None, None, None) f16_tensors = (None, None, None) + out_f16 = out_part if ctx.fp8: # fwd: fp8, bwd: fp8, save all fp8 fp8_tensors = (q_fp8, kv_fp8, out_fp8) if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: f16_tensors = (None, None, out_f16) - elif fp8 and is_input_fp8: + elif fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) + elif 
fp8 and is_input_fp8 and not fp8_recipe.mxfp8(): # fwd: fp8, bwd: f16, save all f16 # dequantize fp8 inputs q_f16 = q_fp8.dequantize() kv_f16 = kv_fp8.dequantize() f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8 and is_input_fp8 and fp8_recipe.mxfp8(): + # fwd: fp8, bwd: f16, save all f16 + # there is already an F16 version of the inputs + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q, k, v) + kv = torch.cat((k_f16.view(-1), v_f16.view(-1)), dim=-1) + f16_tensors = (q_f16, kv, out_f16) + elif fp8 and not is_input_fp8 and fp8_recipe.mxfp8(): + f16_tensors = (q, kv, out_f16) elif fp8: # fwd: fp8, bwd: f16, save all f16 # inputs are already in f16 @@ -2005,7 +2121,6 @@ def forward( ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape @@ -2018,12 +2133,19 @@ def forward( ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 - ctx.enable_mla = enable_mla ctx.k_numel = k_numel ctx.k_shape = k_shape ctx.v_shape = v_shape - + ctx.o_shape = o_shape + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.orig_o_shape = orig_o_shape + ctx.post_a2a_o_shape = post_a2a_o_shape + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -2032,14 +2154,14 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer 
is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop(f"{nvtx_label}") - if return_max_logit: return out_ret, max_logit return out_ret @@ -2054,7 +2176,12 @@ def backward(ctx, dout, *_args): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + if ( + ctx.fp8 + and ctx.is_output_fp8 + and not isinstance(dout, QuantizedTensorStorage) + and not ctx.fp8_recipe.mxfp8() + ): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -2094,7 +2221,6 @@ def backward(ctx, dout, *_args): # set up attention args causal = "causal" in ctx.attn_mask_type seq_dim = None - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") @@ -2133,13 +2259,13 @@ def backward(ctx, dout, *_args): if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() # [b, h, sq//2] -> [b, h, sq//2, 1] or - # [t//2, np] -> [t//2, h, 1] + # [t//2, h] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() # [b, h, sq] -> [b, h, sq, 1] or - # [t, np] -> [t, h, 1] + # [t, h] -> [t, h, 1] softmax_lse.unsqueeze_(-1) # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 @@ -2154,28 +2280,29 @@ def backward(ctx, dout, *_args): buffer_dtype = torch.uint8 dq_buffer = None dout_fp8 = None - bwd_output_te_dtype = None dkv_buffer = None if ctx.fp8: - assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" 
+ assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" fused_attn_backend = FusedAttnBackend["FP8"] - q, kv, out = ( - q_fp8._data, - kv_fp8._data, - ( - out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - else out_fp8._data - ), - ) + if not ctx.fp8_recipe.mxfp8(): + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype # dout: torch.Tensor, dtype=torch.uint8 - if ctx.is_output_fp8: + if isinstance(dout, QuantizedTensorStorage): dout_fp8 = dout - else: + elif not ctx.fp8_recipe.mxfp8(): dout_fp8 = ctx.dO_quantizer(dout) - dout = dout_fp8._data + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data # print quantizers print_quantizers( @@ -2189,9 +2316,6 @@ def backward(ctx, dout, *_args): ctx.dP_quantizer, ) - # dout_fp8._fp8_dtype - bwd_output_te_dtype = ctx.dO_quantizer.dtype - # create buffers for reduction in float32 if ctx.fp8_recipe.delayed(): dq_buffer = torch.empty( @@ -2199,7 +2323,7 @@ def backward(ctx, dout, *_args): dtype=buffer_dtype, device=q.device, ) - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_buffer = torch.empty( q.shape, dtype=torch.float32, @@ -2213,7 +2337,7 @@ def backward(ctx, dout, *_args): ) dkv_recv_buffer = torch.empty_like(dkv_send_buffer) p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dkv_buffer = torch.zeros( kv.shape, dtype=torch.float32, @@ -2226,10 +2350,13 @@ def backward(ctx, dout, *_args): # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; # only used to hold temporary scale/amax values (output only, no quantization op) for i in range(cp_size): - 
dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dP_quantizer_per_step[i] = ( + ctx.dP_quantizer.copy() if ctx.dP_quantizer is not None else None + ) dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() - dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if not ctx.fp8_recipe.mxfp8(): + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) @@ -2240,34 +2367,28 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # communicate for the 'a2a' part of 'a2a+p2p' + dout = dout.view(*ctx.orig_o_shape) if cp_size_a2a > 1: - if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( cp_size_a2a, out.device ) - out, dout = flash_attn_a2a_communicate( - [out, dout], + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size_a2a, ctx.cp_group_a2a, ctx.cp_stream, True, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) - - if ctx.enable_mla: - out = out.view(*ctx.v_shape) - dout = dout.view(*ctx.v_shape) - else: - # MHA or GQA - out = out.view(*q.shape) - dout = dout.view(*q.shape) + out = out.view(*ctx.o_shape) + dout = dout.view(*ctx.o_shape) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -2364,10 +2485,11 @@ def backward(ctx, dout, *_args): kv_fp8, ( out - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + if (ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or ctx.fp8_recipe.mxfp8() else out_fp8 ), - 
dout_fp8, + dout_fp8 if not ctx.fp8_recipe.mxfp8() else dout, softmax_lse, softmax_lse_, rng_states, @@ -2384,16 +2506,20 @@ def backward(ctx, dout, *_args): fused_attn_backend, ctx.softmax_scale, ctx.dropout_p, - qkv_layout, + ctx.qkv_layout, + ctx.qkv_format, + ctx.qkv_format, + ctx.qkv_layout, ctx.attn_mask_type, ctx.attn_bias_type, ctx.deterministic, ctx.fwd_nominal_dtype, bwd_nominal_dtype, - bwd_output_te_dtype, ctx.S_quantizer, dP_quantizer_per_step[i], dQKV_quantizer_per_step[i], + ctx.QKV_quantizer, + ctx.dO_quantizer, ] else: flash_attn_inputs = [ @@ -2467,7 +2593,7 @@ def backward(ctx, dout, *_args): if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8_recipe.delayed(): dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] # copy dq_ into the right buffer position @@ -2551,7 +2677,7 @@ def backward(ctx, dout, *_args): # dkv correction if ctx.fp8 and ctx.fp8_recipe.delayed(): dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] - elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + elif ctx.fp8 and (ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8()): dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] @@ -2641,9 +2767,10 @@ def backward(ctx, dout, *_args): # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: - amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + if not ctx.fp8_recipe.mxfp8(): + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) dq = dq_buffer if ctx.fp8_recipe.delayed(): @@ -2657,7 +2784,7 @@ def backward(ctx, dout, *_args): for x in [dq, dk, dv] ] dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.qkv_layout, dq, dk, dv, @@ -2666,7 +2793,7 @@ def 
backward(ctx, dout, *_args): ) dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if ctx.fp8_recipe.float8_current_scaling(): + if ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8(): dk = dkv[: ctx.k_numel].view(ctx.k_shape) dv = dkv[ctx.k_numel :].view(ctx.v_shape) @@ -2682,7 +2809,7 @@ def backward(ctx, dout, *_args): dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + dq, dk, dv, _ = combine_and_quantize(ctx.qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) if ctx.fp8: # print quantizers @@ -2700,7 +2827,8 @@ def backward(ctx, dout, *_args): if cp_size_a2a > 1: if ctx.fp8 and ctx.is_input_fp8: dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv - dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) + if not ctx.fp8_recipe.mxfp8(): + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2710,16 +2838,22 @@ def backward(ctx, dout, *_args): ctx.cp_group_a2a, ctx.cp_stream, False, + qkv_format=ctx.qkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) - if ctx.fp8 and ctx.is_input_fp8: + if ctx.fp8 and ctx.is_input_fp8 and not ctx.fp8_recipe.mxfp8(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) ] - if ctx.qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + dq, dk, dv = [ + x.view(y) + for x, y in zip( + [dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape] + ) + ] if attn_dbias is not None: # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] @@ -2817,27 
+2951,42 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + o_format = qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(qkv_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - qkv_dtype = q.dtype - - causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" - if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." assert ( - use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='all_gather' and {attn_bias_type=}." + assert ( + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or fa_utils.v2_3_plus + ), ( + "cp_comm_type='all_gather' only supports SWA through FusedAttention or FlashAttention" + f" >= 2.3. Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='all_gather' requires seq_len % 2 == 0 for Q, K, V. 
Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -2870,14 +3019,6 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention or qkv_format == "thd": @@ -2886,30 +3027,90 @@ def forward( cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None + if use_fused_attention and attn_mask_type == "causal": + attn_mask_type = attn_mask_type + "_bottom_right" + causal = "causal" in attn_mask_type - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + # FP8 setup + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." 
+ is_input_fp8 = isinstance(q, QuantizedTensorStorage) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + ( + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + fwd_nominal_dtype = q.dtype + q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) + q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) + fused_attn_backend = None + fp8_meta_kwargs = {} + if fp8: + assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + if not is_input_fp8 and not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer + elif use_fused_attention: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] + + # q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # reshape: split s + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + # s dim first for all-gather + # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # gather along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, 
_ = gather_along_first_dim(v, cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # pick out specific chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # reshape/flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # k_ag: [cp*s, b, h, d] + # v_ag: [cp*s, b, h, d] + # out: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + q_shape, k_shape, v_shape = q.shape, k.shape, v.shape + o_shape = q.shape[:-1] + v.shape[-1:] + out = torch.empty(o_shape, dtype=fwd_nominal_dtype, device=q.device) + # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] - + # prepare per-step tensors local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] kv_seq_range_per_step = [None, None] window_size_per_step = [None, None] @@ -2917,16 +3118,15 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) max_logit_per_step = [None, None] max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, 
i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( local_seq_chunk_ids[i], @@ -2946,13 +3146,32 @@ def forward( cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( k.shape[1], max_seqlen_kv_, k.device ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] if use_fused_attention: + new_qkv_layout = qkv_layout + if fp8: + if not fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_part, dtype=fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_part, dtype=fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_part, dtype=fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, new_qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) ( out_per_step[i], - [softmax_lse_per_step[i], rng_states[i]], + aux_ctx_tensors, *max_logit_, ) = fused_attn_fwd( is_training, @@ -2960,14 +3179,15 @@ def forward( max_seqlen_kv_, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - qkv_dtype, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + q_part, + k_part, + v_part, + fwd_nominal_dtype, + fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout=qkv_layout, + qkv_layout=new_qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -2976,9 +3196,16 @@ def forward( window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if fp8: + softmax_lse_per_step[i], 
_, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors if return_max_logit: max_logit_per_step[i] = max_logit_[0] + if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): + out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: fa_forward_args_thd = get_fa_args( True, @@ -2995,9 +3222,9 @@ def forward( fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( - q_, - k_, - v_, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, @@ -3013,61 +3240,154 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + # out_per_step[i]: fwd_nominal_dtype, [b, s//2, h, d] or [s//2, b, h, d] + # out: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # max_logit_per_step[i]: torch.float32, [h] + # max_logit: torch.float32, [h] if return_max_logit and i == 0: max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if qkv_format == "bshd": + if o_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) - elif qkv_format == "sbhd": + elif o_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + + # all reduce max_logit across ranks if return_max_logit: torch.distributed.all_reduce( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: - if qkv_format == "bshd": - out = out.view(out.shape[0], -1, *out.shape[-2:]) - elif qkv_format == "sbhd": - out = out.view(-1, *out.shape[-3:]) - else: - out = out.view(-1, *out.shape[-2:]) + # out: fwd_nominal_dtype + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + out = out.view(orig_o_shape) - ctx.save_for_backward( - q, - k, - v, + # prepare 
for forward output and backward saves of out + out_fp8 = None + bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) + if fp8 and (is_output_fp8 or bwd_requires_o_fp8): + out_fp8 = O_quantizer(out) + out_ret = out_fp8 if is_output_fp8 else out + + # save tensors for backward + ctx.fp8 = fp8 and is_bwd_fp8 + ctx.fp8_recipe = fp8_recipe + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + # True: q split along s; k/v with s first, i.e. [s, b, h, d] + # False: original [b, s, h, d] or [s, b, h, d] + ctx.qkv_reshaped = True + # no load-balance related token shuffling; original token order in q/k/v/out + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out/out_fp8: [b, s, h, d] or [s, b, h, d] + if ctx.fp8: + # q_fp8_save: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k_fp8_save: [s, b, h, d] + # v_fp8_save: [s, b, h, d] + q_fp8_save, k_fp8_save, v_fp8_save = None, None, None + if fp8_recipe.delayed() or fp8_recipe.float8_current_scaling(): + q_fp8_save = Float8Tensor.make_like(q_fp8, data=q, dtype=fwd_nominal_dtype) + k_fp8_save = Float8Tensor.make_like(k_fp8, data=k, dtype=fwd_nominal_dtype) + v_fp8_save = Float8Tensor.make_like(v_fp8, data=v, dtype=fwd_nominal_dtype) + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v in FP8, o in f16 + # MXFP8: q/k/v/o all in f16 + if fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, out_fp8) + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8_save, k_fp8_save, v_fp8_save, None) + f16_tensors = (None, None, None, out) + elif fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out) + elif fp8: + # convert q/k/v to F16 if necessary, and save q/k/v/o all in F16 and 
original format + if is_input_fp8: + q_f16, k_f16, v_f16 = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + f16_tensors = (q_f16, k_f16, v_f16, out) + ctx.qkv_reshaped = False + else: + # save all in F16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + # out: [b, s, h, d] or [s, b, h, d] + f16_tensors = (q, k, v, out) + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_q_padded, *cu_seqlens_kv_per_step, - *out_per_step, *softmax_lse_per_step, *rng_states, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects - ctx.qkv_dtype = qkv_dtype + ctx.qkv_format = qkv_format + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_format = qkv_format + ctx.dqkv_layout = qkv_layout + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.orig_o_shape = orig_o_shape + ctx.o_shape = o_shape + ctx.q_shape = q_shape + ctx.k_shape = k_shape + ctx.v_shape = v_shape ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step + ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.O_quantizer = O_quantizer.copy() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + 
ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: - return out, max_logit - return out + return out_ret, max_logit + return out_ret @staticmethod def backward(ctx, dout, *_args): @@ -3076,22 +3396,94 @@ def backward(ctx, dout, *_args): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (*saved_tensors,) = ctx.saved_tensors - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] - cu_seqlens_kv_per_step = saved_tensors[5:7] - out_per_step = saved_tensors[7:9] - softmax_lse_per_step = saved_tensors[9:11] - rng_states = saved_tensors[11:13] + cu_seqlens_kv_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_q_padded, + cu_seqlens_kv_per_step[0], + cu_seqlens_kv_per_step[1], + softmax_lse_per_step[0], + softmax_lse_per_step[1], + rng_states[0], + rng_states[1], + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + _, seq_dim_qkv, _ = get_bsh_dims(ctx.qkv_format) + _, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_o, _ = get_bsh_dims(ctx.o_format) + causal = "causal" in ctx.attn_mask_type - dout = dout.view(q.shape) - dq = torch.empty_like(q) - dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) - dv = torch.zeros_like(dk) + # set up dout: + # FP8DS/CS: torch.uint8, [b, s, h, d] or [s, b, h, d] + # MXFP8/F16: torch.float16 or torch.bfloat16, [b, s, h, d] or [s, b, h, d] + dout_fp8 = None + if ctx.fp8: + assert 
ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): + dout = dout_fp8._data + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + dout = dout.view(ctx.o_shape) + + # set up q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # k: [s, b, h, d] + # v: [s, b, h, d] + if ctx.fp8 and not ctx.fp8_recipe.mxfp8(): + q, k, v = [x._data for x in [q_fp8, k_fp8, v_fp8]] + if not ctx.qkv_reshaped: + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + + # set up out: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): torch.uint8 + # FP8CS+_dpa_fp8_cs_o_in_f16: torch.float16 or torch.bfloat16 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + if ctx.fp8 and ( + ctx.fp8_recipe.delayed() + or (ctx.fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ): + out = out_fp8._data + out = out.view(ctx.o_shape) + + # set up dq, dk, dv: + # dq: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dk: fwd_nominal_dtype, [cp*s, b, h, d] + # dv: fwd_nominal_dtype, [cp*s, b, h, d] + dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dk = torch.zeros( + (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=k.device, + ) + dv = torch.zeros( + (ctx.v_shape[0] * cp_size, *ctx.v_shape[1:]), + dtype=ctx.fwd_nominal_dtype, + device=v.device, + ) dq_per_step = [None, None] dk_per_step = [None, None] dv_per_step = [None, None] @@ -3101,23 +3493,22 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps 
dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # gather k and v along s: [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # select appropriate chunks for each rank chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) - local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] - + # set up flash_attn_bwd flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} @@ -3149,57 +3540,119 @@ def backward(ctx, dout, *_args): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], kv_seq_range_per_step[i][1], ) max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, 
seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + out_part = out.select(seq_dim_o, i).contiguous() + dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + ] + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + fp8_meta_kwargs = {} + qkv_layout = ctx.qkv_layout + d_out_format = ctx.o_format + if ctx.fp8: + fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_FP8 + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o/do all in FP8 + # FP8CS+_dpa_fp8_cs_o_in_f16: q/k/v/do in FP8, o in f16 + # MXFP8: q/k/v/do all in MXFP8, o/do_f16 in F16 + if not ctx.fp8_recipe.mxfp8(): + q_part = Float8Tensor.make_like( + q_fp8, data=q_part, dtype=ctx.fwd_nominal_dtype + ) + k_part = Float8Tensor.make_like( + k_fp8, data=k_part, dtype=ctx.fwd_nominal_dtype + ) + v_part = Float8Tensor.make_like( + v_fp8, data=v_part, dtype=ctx.fwd_nominal_dtype + ) + if ctx.fp8_recipe.delayed() or ( + ctx.fp8_recipe.float8_current_scaling() + and not _dpa_fp8_cs_o_in_f16 + ): + out_part = Float8Tensor.make_like( + out_fp8, data=out_part, dtype=ctx.fwd_nominal_dtype + ) + dout_part = Float8Tensor.make_like( + dout_fp8, data=dout_part, dtype=ctx.fwd_nominal_dtype + ) + else: + q_part, k_part, v_part, 
qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, ctx.QKV_quantizer + ) + dout_part, d_out_format = dpa_utils.permute_to_grouped_tensor( + d_out_format, dout_part + ) + aux_ctx_tensors.append(dout_part) + dout_part = ctx.dO_quantizer(dout_part) dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv_per_step[i], - q_, - k_, - v_, - out_, - dout_, - ctx.qkv_dtype, - TE_DType[dout.dtype], + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.fwd_nominal_dtype, aux_ctx_tensors, - tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, + o_format=ctx.o_format, + d_out_format=d_out_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, ) + if ctx.fp8 and all( + isinstance(x, QuantizedTensorStorage) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ): + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + x.dequantize(dtype=ctx.fwd_nominal_dtype) + for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] + ] else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_, k_, v_] + torch.empty_like(x) for x in [q_part, k_part, v_part] ] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, @@ -3216,29 +3669,34 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] if ctx.use_flash_attn_3: - 
fa_backward_kwargs["is_causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["is_causal"] = causal else: - fa_backward_kwargs["causal"] = "causal" in ctx.attn_mask_type + fa_backward_kwargs["causal"] = causal flash_attn_bwd( - dout_, - q_, - k_, - v_, - out_, + dout_part, + q_part, + k_part, + v_part, + out_part, softmax_lse_per_step[i], *fa_backward_args_thd, **fa_backward_kwargs, ) if i > 0: + # dq/dk/dv, dq_per_step/dk_per_step/dv_per_step: ctx.fwd_nominal_dtype with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] + if ctx.dqkv_format == "bshd": dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": + elif ctx.dqkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + # dk/dv: [cp*s, b, h, d] + # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] + # move s to first dim: [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() + x.movedim(seq_dim_dqkv, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] ] # wait until dkv update of last step is done @@ -3248,6 +3706,7 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step[i - 1][0], kv_seq_range_per_step[i - 1][1], ) + # add to dk/dv: [cp*s, b, h, d] dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) if i < len(local_seq_chunk_ids): @@ -3255,23 +3714,33 @@ def backward(ctx, dout, *_args): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + # put back together the right chunks for each rank chunk_ids_for_kv_ag = 
get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) + # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + # reshape to original format: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) + dk = dk.movedim(0, seq_dim_dqkv).contiguous() + dv = dv.movedim(0, seq_dim_dqkv).contiguous() + # quantize if necessary + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, dq, @@ -3294,6 +3763,10 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, + None, ) @@ -3338,24 +3811,42 @@ def forward( ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + original_qkv_layout = qkv_layout + orig_q_shape, orig_k_shape, orig_v_shape = q.shape, k.shape, v.shape + o_format = qkv_format + batch_dim_qkv, seq_dim_qkv, _ = 
get_bsh_dims(qkv_format) + _, seq_dim_o, _ = get_bsh_dims(o_format) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if qkv_format in ["bshd", "sbhd"]: + assert ( + "padding" not in attn_mask_type + ), f"No support for cp_comm_type='a2a', {attn_mask_type=} and {qkv_format=}." assert ( - not padding or qkv_format == "thd" - ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!" - assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" - assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + attn_bias_type == "no_bias" + ), f"No support for cp_comm_type='a2a' and {attn_bias_type=}." assert ( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention or fa_utils.v2_3_plus - ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + ), ( + "cp_comm_type='a2a' only supports SWA through FusedAttention or FlashAttention >= 2.3." + f" Found {use_fused_attention=} and {fa_utils.v2_3_plus=}." + ) + assert q.shape[seq_dim_qkv] % 2 == 0 and k.shape[seq_dim_qkv] % 2 == 0, ( + "cp_comm_type='a2a' requires seq_len % 2 == 0 for Q, K, V. Found seq_len_q =" + f" {q.shape[seq_dim_qkv]}, seq_len_kv = {k.shape[seq_dim_qkv]}, cp_size = {cp_size}." + ) + assert q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0, ( + "cp_comm_type='a2a' requires num_heads % cp_size == 0 for Q, K, V. Found num_heads_q =" + f" {q.shape[-2]}, num_heads_kv = {k.shape[-2]}, cp_size = {cp_size}." + ) flash_attn_fwd = None if not use_fused_attention: @@ -3395,26 +3886,10 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert ( - q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 - ), "The number of attention heads needs to be divisible by CP size!" 
- - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - - if qkv_format in ["bshd", "sbhd"]: - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - batch_dim = seq_dim = qkv_format.index("t") - - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" - assert isinstance(k, q.__class__) and isinstance( v, q.__class__ - ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." - is_input_fp8 = isinstance(q, Float8Tensor) + ), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage." + is_input_fp8 = isinstance(q, QuantizedTensorStorage) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; @@ -3422,62 +3897,99 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers) + dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) + fp8_meta_kwargs = {} if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if is_input_fp8: - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - else: - q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) - q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer - else: - assert False, "FP8 is only supported with Fused Attention!" 
+ assert use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if is_input_fp8: + q_fp8, k_fp8, v_fp8 = q, k, v + elif not fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q, k, v, QKV_quantizer + ) + if not fp8_recipe.mxfp8(): + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: if use_fused_attention: - fp8_meta_kwargs = {} fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # q, k, v: + # FP8DS/FP8CS: torch.uint8 + # MXFP8: torch.float16 or torch.bfloat16 + # F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, - seq_dim, + seq_dim_qkv, cp_size, cp_group, cp_stream, before_attn=True, qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["q", "k", "v"], ) + + # softmax_offset: split h + # [1, h, 1, 1] -> [1, h//cp, 1, 1] if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - out_fp8 = None - out_f16 = None - batch_size = q.shape[batch_dim] + # _part: inputs to attention kernel and saved for backward + # note: they have post a2a shapes + batch_size = q.shape[batch_dim_qkv] q_part, k_part, v_part = q, k, v - out_part = None + out_part, out_fp8, out_f16 = None, None, None + bwd_requires_o_f16 = is_training and ( + not is_bwd_fp8 + or ( + is_bwd_fp8 + and ( + (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + or fp8_recipe.mxfp8() + ) + ) + ) + 
bwd_requires_o_fp8 = ( + is_training + and is_bwd_fp8 + and ( + fp8_recipe.delayed() + or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) + ) + ) if use_fused_attention: if fp8: - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] + if fp8_recipe.mxfp8(): + q_fp8, k_fp8, v_fp8, qkv_layout = combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) + q_part, k_part, v_part = [q_fp8, k_fp8, v_fp8] + else: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, @@ -3492,6 +4004,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, + o_format=o_format, attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, @@ -3504,24 +4017,17 @@ def forward( return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), ) - if isinstance(out_, Float8Tensor): - out_fp8 = out_ - out_ = out_._data - if is_bwd_fp8 and not ( - fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - ): - out_part = out_fp8 - else: - out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) - else: - out_f16 = out_ - out_part = out_ - if ( - fp8 - and is_bwd_fp8 - and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - ): - out_part = O_quantizer(out_) + # construct out_part for backward + out_fp8 = out_ + out_f16 = out_ + if bwd_requires_o_fp8: + if not isinstance(out_, QuantizedTensorStorage): + out_fp8 = O_quantizer(out_) + out_part = out_fp8 + if bwd_requires_o_f16: + if isinstance(out_, QuantizedTensorStorage): + out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) + out_part = out_f16 else: fa_forward_args_thd = get_fa_args( True, @@ -3549,60 +4055,98 @@ def forward( aux_ctx_tensors = 
[softmax_lse, rng_state] out_part = out_ + # a2a: split s and gather h + # [b, s, h//cp, d] -> [b*s//cp, h, d] + # [s, b, h//cp, d] -> [s//cp*b, h, d] + # [t, h//cp, d] -> [t//cp, h, d] + if isinstance(out_, QuantizedTensorStorage): + out_ = out_._data chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, - seq_dim, + seq_dim_o, cp_size, cp_group, cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["out"], ) - if return_max_logit: - max_logit = flash_attn_a2a_communicate_softmax_offset( - *max_logit, 0, cp_size, cp_group, cp_stream, False - ) - - if use_fused_attention: - if qkv_format == "bshd": - # [b*s, h, d] -> [b, s, h, d] - out_ = out_.view(batch_size, -1, *out_.shape[-2:]) - elif qkv_format == "sbhd": - # [s*b, h, d] -> [s, b, h, d] - out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - - if fp8 and use_fused_attention: - if fp8_recipe.float8_current_scaling(): - out_f16 = out_ - if is_output_fp8: - out_fp8 = O_quantizer(out_) + # [b*s//cp, h, d] -> [b, s//cp, h, d] + # [s//cp*b, h, d] -> [s//cp, b, h, d] + # [t//cp, h, d] -> [t//cp, h, d] + if o_format == "bshd": + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) + elif o_format == "sbhd": + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) + + # out_ret: output tensor for forward pass + if fp8: if fp8_recipe.delayed(): out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) - if not is_output_fp8: + if is_output_fp8: + if fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8(): + out_fp8 = O_quantizer(out_) + out_f16 = out_ + else: + if fp8_recipe.delayed(): out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ else: out_f16 = out_ - out_ret = out_fp8 if is_output_fp8 else out_f16 + # all gather max logit + if return_max_logit: + 
max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, cp_size, cp_group, cp_stream, False + ) + + ctx.qkv_layout = qkv_layout + ctx.o_format = o_format + ctx.dqkv_layout = original_qkv_layout + ctx.dqkv_format = qkv_format + ctx.orig_q_shape = orig_q_shape + ctx.orig_k_shape = orig_k_shape + ctx.orig_v_shape = orig_v_shape + ctx.out_part_shape = out_part.shape + ctx.out_ret_shape = out_ret.shape + + # save tensors for backward ctx.fp8 = fp8 and is_bwd_fp8 fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) - if ctx.fp8: - if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: - fp8_tensors = (q_part, k_part, v_part, None) - f16_tensors = (None, None, None, out_part) + if is_training: + if ctx.fp8: + # FP8DS or (FP8CS+not _dpa_fp8_cs_o_in_f16): q/k/v/o all in FP8 + # (FP8CS+_dpa_fp8_cs_o_in_f16) or MXFP8: q/k/v in FP8, o in F16 + if ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or fp8_recipe.mxfp8(): + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) + elif fp8_recipe.delayed() or ( + fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 + ): + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + # FP8DS/CS: convert post-a2a FP8 q/k/v to F16 + # MXFP8: save post-a2a pre-quantization F16 q/k/v + # out_part is already converted to the right precision + if fp8_recipe.mxfp8(): + f16_tensors = (q, k, v, out_part) + ctx.qkv_layout = original_qkv_layout + else: + q_part, k_part, v_part = combine_and_dequantize( + qkv_layout, q_part, k_part, v_part + ) + f16_tensors = (q_part, k_part, v_part, out_part) else: - fp8_tensors = (q_part, k_part, v_part, out_part) - elif fp8: - q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) - f16_tensors = (q_part, k_part, v_part, out_part) - else: - f16_tensors = (q_part, k_part, v_part, out_part) - + # all tensors are already in F16 + f16_tensors = (q_part, k_part, 
v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *f16_tensors, @@ -3614,16 +4158,13 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.out_shape = out_ret.shape - ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale - ctx.qkv_format = qkv_format ctx.attn_mask_type = attn_mask_type ctx.attn_bias_type = attn_bias_type ctx.deterministic = deterministic @@ -3645,11 +4186,13 @@ def forward( ctx.S_quantizer = S_quantizer if ctx.fp8: ctx.QKV_quantizer = QKV_quantizer.copy() - ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() ctx.O_quantizer = O_quantizer.copy() - ctx.O_quantizer.scale = O_quantizer.scale.clone() - ctx.S_quantizer = S_quantizer.copy() - ctx.S_quantizer.scale = S_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() if S_quantizer is not None else None + if not ctx.fp8_recipe.mxfp8(): + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if return_max_logit: return out_ret, max_logit @@ -3677,60 +4220,53 @@ def backward(ctx, dout, *_args): *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_format = ctx.qkv_format - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - causal = "causal" in ctx.attn_mask_type - - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") - else: # qkv_format == "thd" - seq_dim = qkv_format.index("t") - + batch_dim_dqkv, seq_dim_dqkv, _ = get_bsh_dims(ctx.dqkv_format) + _, seq_dim_do, _ = get_bsh_dims(ctx.o_format) bwd_nominal_dtype = ctx.fwd_nominal_dtype - dqkv_te_dtype = None fused_attn_backend = None - dout_fp8 = dout + causal = "causal" in 
ctx.attn_mask_type + + dout_fp8 = None + fp8_meta_kwargs = {} if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorStorage): - dout = ctx.dO_quantizer(dout) - dout_fp8 = dout - dqkv_te_dtype = dout._fp8_dtype + assert ctx.use_fused_attention, "FP8 is only supported with FusedAttention backend!" + fused_attn_backend = FusedAttnBackend["FP8"] + if isinstance(dout, QuantizedTensorStorage): + dout_fp8 = dout + elif not ctx.fp8_recipe.mxfp8(): + dout = ctx.dO_quantizer(dout) + dout_fp8 = dout + if not ctx.fp8_recipe.mxfp8(): dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer - fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer - - else: - assert False, "FP8 is only supported with Fused Attention!" + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer else: if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - - if not ctx.use_fused_attention: - if qkv_format in ["bshd", "sbhd"]: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) - else: - dout = dout.view(*ctx.out_shape) - + dout = dout.view(*ctx.out_ret_shape) + + # dout: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( dout, chunk_ids_for_a2a, - seq_dim, + seq_dim_do, cp_size, ctx.cp_group, ctx.cp_stream, 
before_attn=True, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.o_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + a2a_input_names=["dout"], ) flash_attn_bwd = None @@ -3748,7 +4284,7 @@ def backward(ctx, dout, *_args): fa_backward_kwargs["window_size_right"] = ctx.window_size[1] fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if qkv_format == "thd": + if ctx.o_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3775,12 +4311,21 @@ def backward(ctx, dout, *_args): dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: + d_out_format = ctx.o_format q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ( + ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ) or ctx.fp8_recipe.mxfp8(): out_part = out - dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + if not ctx.fp8_recipe.mxfp8(): + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + else: + # d_out_format = bhsd for both dout (F16) and dout_part (MXFP8) + dout, d_out_format = dpa_utils.permute_to_grouped_tensor(d_out_format, dout) + aux_ctx_tensors.append(dout) + dout_part = ctx.dO_quantizer(dout) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3792,14 +4337,16 @@ def backward(ctx, dout, *_args): out_part, dout_part, bwd_nominal_dtype, - dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout=qkv_layout, + qkv_layout=ctx.qkv_layout, + o_format=ctx.o_format, + d_out_format=d_out_format, + dqkv_layout=ctx.dqkv_layout, attn_mask_type=ctx.attn_mask_type, 
attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, @@ -3808,7 +4355,7 @@ def backward(ctx, dout, *_args): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if isinstance(dq, Float8Tensor): + if all(isinstance(x, QuantizedTensorStorage) for x in [dq, dk, dv]): dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv dq, dk, dv = [x._data for x in [dq, dk, dv]] else: @@ -3817,7 +4364,7 @@ def backward(ctx, dout, *_args): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - qkv_format, + ctx.dqkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3843,24 +4390,33 @@ def backward(ctx, dout, *_args): **fa_backward_kwargs, ) + # dq, dk, dv: + # FP8DS: torch.uint8 + # FP8CS/MXFP8/F16: torch.float16 or torch.bfloat16 + # a2a: gather s and split h + # [b, s//cp, h, d] -> [b, s, h//cp, d] + # [s//cp, b, h, d] -> [s, b, h//cp, d] + # [t//cp, h, d] -> [t, h//cp, d] chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, - seq_dim, + seq_dim_dqkv, cp_size, ctx.cp_group, ctx.cp_stream, before_attn=False, - qkv_format=qkv_format, - cu_seqlens_padded=cu_seqlens_q_padded, + qkv_format=ctx.dqkv_format, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + a2a_input_names=["dq", "dk", "dv"], ) + dq, dk, dv = [ + x.view(y) + for x, y in zip([dq, dk, dv], [ctx.orig_q_shape, ctx.orig_k_shape, ctx.orig_v_shape]) + ] - if qkv_format == "bshd": - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif qkv_format == "sbhd": - dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - + # d_bias, d_softmax_offset d_bias = None d_softmax_offset = None if ctx.use_fused_attention: @@ -3872,9 +4428,14 @@ def backward(ctx, dout, *_args): d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False ) + # convert dq, dk, dv to appropriate types if 
ctx.fp8: - if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: - dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ( + ctx.fp8_recipe.float8_current_scaling() or ctx.fp8_recipe.mxfp8() + ) and ctx.is_input_fp8: + dq, dk, dv, _ = combine_and_quantize( + ctx.dqkv_layout, dq, dk, dv, ctx.dQKV_quantizer + ) if ctx.fp8_recipe.delayed(): dq, dk, dv = [ Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) @@ -3882,7 +4443,7 @@ def backward(ctx, dout, *_args): ] if not ctx.is_input_fp8: dq, dk, dv = combine_and_dequantize( - qkv_layout, + ctx.dqkv_layout, dq, dk, dv, @@ -3890,7 +4451,6 @@ def backward(ctx, dout, *_args): ) nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") - return ( None, dq, @@ -4019,7 +4579,6 @@ def attn_forward_func_with_cp( in Megatron-LM. """ - if cp_comm_type == "a2a+p2p": assert ( isinstance(cp_group, list) and len(cp_group) == 2 @@ -4066,10 +4625,11 @@ def attn_forward_func_with_cp( ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] - assert not enable_mla or cp_comm_type in [ - "p2p", - "a2a+p2p", - ], f"Context parallelism does not support MLA with {cp_comm_type=}!" + # assert not enable_mla or cp_comm_type in [ + # "p2p", + # "a2a+p2p", + # "a2a", + # ], f"Context parallelism does not support MLA with {cp_comm_type=}!" 
if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: @@ -4127,7 +4687,16 @@ def attn_forward_func_with_cp( elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [ + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + fp8, + fp8_meta, + quantizers, + fp8_output, + ] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..d1fd0b0ed0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -30,7 +30,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( @@ -98,19 +98,26 @@ +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | -| | | export NVTE_DPA_FP8_RECIPE="F16" | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS, NVFP4 or MXFP8 to autocast(); | +| /MXFP8 | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass 
FP8DS to autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8DS | Pass MXFP8 to autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear MXFP8; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +125,27 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| 
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| MXFP8 | FP8CS | Pass MXFP8 to autocast(); | +| | | Attention creates a new FP8CS recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear MXFP8, and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | @@ -139,6 +154,18 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS/FP8CS 
| MXFP8 | Pass FP8DS/FP8CS to autocast(); | +| | | Attention creates a new MXFP8 recipe based on fp8_format, fp8_dpa, fp8_mha from | +| | | linear FP8DS/FP8CS | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| MXFP8 | MXFP8 | Pass MXFP8 to autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | MXFP8 | Pass NVFP4 to autocast(); | +| | | Attention MXFP8 reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling" # switch to MXFP8BS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ """ _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} @@ -600,7 +627,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + elif ( + fp8_recipe.float8_current_scaling() or fp8_recipe.mxfp8() + ) and _dpa_fp8_recipe == "DelayedScaling": # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe fake_recipe = DelayedScaling( fp8_format=fp8_recipe.fp8_format, @@ -653,6 +682,25 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) fp8_recipe_dpa = fake_recipe fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.mxfp8() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + 
fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP @@ -673,11 +721,26 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False - if not fp8_recipe_dpa.float8_per_tensor_scaling(): - assert not ( - fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha - ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + elif ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ) and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "MXFP8BlockScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a MXFP8 recipe + fake_recipe = MXFP8BlockScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa # reduce over TP+CP groups; expect fp8_group to be set up so # assume attention uses the same fp8_group as GEMMs @@ -1203,7 +1266,9 @@ def forward( cu_seqlens_kv_padded = None # get qkv's memory layout - if all(isinstance(x, Float8Tensor) 
for x in [query_layer, key_layer, value_layer]): + if all( + isinstance(x, Float8TensorStorage) for x in [query_layer, key_layer, value_layer] + ): ( qkv_layout, query_layer._data, @@ -1365,6 +1430,7 @@ def forward( attention_dropout=self.attention_dropout, context_parallel=context_parallel, cp_comm_type=self.cp_comm_type, + cp_size=cp_size, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..a003226651 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,11 +35,16 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -220,6 +225,8 @@ class AttentionParams: Whether context parallelism is used or not. cp_comm_type : str, default = "p2p" The communication type of context parallelism. + cp_size : int, default = 1 + The group size of context parallelism. deterministic : bool, default = False Whether to run `DotProductAttention` with determinism or not. 
is_training : bool, default = True @@ -261,6 +268,7 @@ class AttentionParams: attention_dropout: float = 0.0 context_parallel: bool = False cp_comm_type: str = "p2p" + cp_size: int = 1 deterministic: bool = False is_training: bool = True fp8: bool = False @@ -338,6 +346,7 @@ def get_attention_backend( attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel cp_comm_type = attention_params.cp_comm_type + cp_size = attention_params.cp_size deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 @@ -359,6 +368,7 @@ def get_attention_backend( "transformer_engine_version": te.__version__, "compute_capability": "sm" + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "cuda_version": torch.version.cuda, "flash_attn_version": ( str(FlashAttentionUtils.version) if FlashAttentionUtils.is_installed @@ -446,31 +456,41 @@ def get_attention_backend( qkv_dtype, ) use_flash_attention_2 = False - if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in ( torch.Tensor, Float8Tensor, - ]: + Float8TensorStorage, + MXFP8Tensor, + MXFP8TensorStorage, + ): if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug( - "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s." + " Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}," + " qkv_type = {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. 
", qkv_dtype, qkv_type, ) use_flash_attention_3 = False if use_fused_attention: logger.debug( - "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " - "qkv_type = {torch.Tensor, Float8Tensor}. ", + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. Supported:" + " qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, qkv_type =" + " {torch.Tensor, Float8Tensor, Float8TensorStorage, MXFP8Tensor," + " MXFP8TensorStorage}. ", qkv_dtype, qkv_type, ) use_fused_attention = False # Filter: Execution type - if fp8 and fp8_meta["recipe"].fp8_dpa: + fp8_recipe = None + if fp8: + fp8_recipe = fp8_meta["recipe"] if fp8_meta is not None else None + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8 and fp8_recipe.fp8_dpa: if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for FP8 attention") use_flash_attention_2 = False @@ -478,6 +498,12 @@ def get_attention_backend( if FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False + if use_flash_attention_3 and not ( + fp8_recipe.delayed() or fp8_recipe.float8_current_scaling() + ): + if FlashAttentionUtils.v3_is_installed: + logger.debug(f"Disabling FlashAttention 3 for {fp8_recipe.__class__.__name__}") + use_flash_attention_3 = False if use_unfused_attention: allow_emulation = ( os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() @@ -485,15 +511,21 @@ def get_attention_backend( if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False - fp8_recipe = fp8_meta["recipe"] - if fp8_meta.get("local_recipes", None) is not None: - fp8_recipe = fp8_meta["local_recipes"][0] + if use_fused_attention and fp8_recipe.delayed(): + if ( + 
device_compute_capability >= (10, 0) + and deterministic + and cudnn_version < (9, 18, 0) + ): + logger.debug( + "Disabling FusedAttention for FP8 delayed scaling on arch >= sm100 with" + " determinism for cuDNN < 9.18.0" + ) + use_fused_attention = False if use_fused_attention and fp8_recipe.float8_current_scaling(): if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False - # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling - # determinism for Blackwell else: if cudnn_version < (9, 14, 0): logger.debug( @@ -503,10 +535,27 @@ def get_attention_backend( else: if deterministic and cudnn_version < (9, 18, 0): logger.debug( - "Disabling FusedAttention for FP8 current scaling requiring determinism" - " with cuDNN < 9.18.0" + "Disabling FusedAttention for FP8 current scaling with determinism" + " for cuDNN < 9.18.0" ) use_fused_attention = False + if use_fused_attention and fp8_recipe.mxfp8(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for MXFP8 on arch < sm100") + use_fused_attention = False + elif fp8_recipe.fp8_mha: + logger.debug("Disabling FusedAttention for MXFP8 with fp8_mha=True") + use_fused_attention = False + else: + if cudnn_version < (9, 21, 0): + logger.debug("Disabling FusedAttention for MXFP8 with cuDNN < 9.21.0") + use_fused_attention = False + elif qkv_format == "thd": + logger.debug("Disabling FusedAttention for MXFP8 with qkv_format = thd") + use_fused_attention = False + if use_fused_attention and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()): + logger.debug(f"Disabling FusedAttention for {fp8_recipe.__class__.__name__}") + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: @@ -539,7 +588,7 @@ def get_attention_backend( if use_flash_attention: use_flash_attention = False logger.debug("Disabling FlashAttention for max_logit") - if fp8 and 
fp8_meta["recipe"].fp8_dpa: + if fp8 and fp8_recipe.fp8_dpa: use_flash_attention = False use_fused_attention = False use_unfused_attention = False @@ -564,8 +613,8 @@ def get_attention_backend( use_flash_attention = False use_fused_attention = False use_unfused_attention = False - if fp8 and fp8_meta["recipe"].fp8_dpa: - if fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_recipe.fp8_dpa: + if fp8_recipe.fp8_mha: logger.debug("Disabling all backends for KV caching with FP8 MHA") use_flash_attention = False use_fused_attention = False @@ -715,30 +764,26 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if softmax_type != "vanilla": logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) use_flash_attention = False - if fp8 and fp8_meta["recipe"].fp8_dpa: - logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) - use_fused_attention = False + if fp8 and fp8_recipe.fp8_dpa: logger.debug( "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type ) use_unfused_attention = False - if qkv_format == "thd": - if cudnn_version < (9, 18, 0): - logger.debug( - "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" - " version < 9.18", - softmax_type, - ) - use_fused_attention = False - if context_parallel: - if cp_comm_type != "a2a": - logger.debug( - "Disabling FusedAttention for context parallelism with softmax_type = %s and" - " cp_comm_type = %s", - softmax_type, - cp_comm_type, - ) - use_fused_attention = False + if qkv_format == "thd" and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False + if context_parallel and cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + 
use_fused_attention = False # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends @@ -759,7 +804,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_unfused_attention = False if context_parallel and (use_flash_attention_2 or use_flash_attention_3): if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed: - if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8 and fp8_recipe.fp8_dpa: logger.debug( "Disabling FlashAttention as it does not support context parallelism with FP8" ) @@ -816,10 +861,50 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt " bias for THD format" ) use_fused_attention = False - elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: + elif fp8 and fp8_recipe.fp8_dpa and qkv_format == "thd": logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" - " MLA attention" + " attention and THD format" + ) + use_fused_attention = False + elif fp8 and fp8_recipe.fp8_dpa and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with FP8" + " attention and bias" + ) + use_fused_attention = False + elif core_attention_bias_type != "no_bias" and cp_comm_type != "p2p": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias" + " and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD" + " format and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif ( + window_size is not None + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and cp_comm_type in ["p2p", "a2a+p2p"] + ): + logger.debug( + "Disabling FusedAttention as it does not support context 
parallelism with sliding" + " window attention and cp_comm_type = %s", + cp_comm_type, + ) + use_fused_attention = False + elif cp_comm_type in ["a2a", "a2a+p2p"] and (num_heads % 2 != 0 or num_gqa_groups % 2 != 0): + logger.debug( + "Disabling FusedAttention as cp_comm_type = %s requires num_heads and" + " num_gqa_groups divisible by 2 (got num_heads = %s, num_gqa_groups = %s)", + cp_comm_type, + num_heads, + num_gqa_groups, ) use_fused_attention = False @@ -872,12 +957,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention for FP8" - ) - use_fused_attention = False - elif attention_dropout != 0.0: + if attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " "without dropout" @@ -981,8 +1061,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention: q_type = TE_DType[qkv_dtype] kv_type = q_type - if fp8 and fp8_meta["recipe"].fp8_dpa: - q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8 and fp8_recipe.fp8_dpa: + q_type = get_fp8_te_dtype(fp8_recipe, fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( is_training, @@ -1012,8 +1092,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and window_size is not None - and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + and (window_size[0] != -1 or window_size[1] not in [-1, 0]) + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] ): logger.debug( "Disabling 
FusedAttention as only sub-backend %s does not support " @@ -1058,15 +1138,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic: - if softmax_type != "vanilla": - logger.debug( - "Disabling FusedAttention for determinism reasons with softmax_type = %s. " - "Sink attention (off-by-one and learnable softmax) requires " - "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", - softmax_type, - ) - use_fused_attention = False - fused_attention_backend = None + # if softmax_type != "vanilla": + # logger.debug( + # "Disabling FusedAttention for determinism reasons with softmax_type = %s. " + # "Sink attention (off-by-one and learnable softmax) requires " + # "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + # softmax_type, + # ) + # use_fused_attention = False + # fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["FP8"] and is_training @@ -2095,28 +2175,45 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers): +def get_attention_quantizers(fp8, fp8_recipe, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] - QKV_quantizer.internal = True + QKV_quantizer.internal = False QKV_quantizer.set_usage(rowwise=True, columnwise=False) - O_quantizer = quantizers["scaling_fwd"][META_O] - O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] - dQKV_quantizer.interal = True - dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.internal = False + O_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + 
dO_quantizer.internal = False dO_quantizer.set_usage(rowwise=True, columnwise=False) - dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] - dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True + dP_quantizer.set_usage(rowwise=True, columnwise=False) + + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.internal = False + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8_recipe.mxfp8(): + QKV_quantizer.columnwise_usage = True + QKV_quantizer.optimize_for_gemm = True + S_quantizer = None + O_quantizer.columnwise_usage = True + + dO_quantizer.columnwise_usage = True + dO_quantizer.optimize_for_gemm = True + dP_quantizer = None + dQKV_quantizer.columnwise_usage = True return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer @@ -2170,18 +2267,87 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + elif isinstance(q, MXFP8Quantizer): + type_str = "MXFP8" + if type_str in ["DS", "CS"]: + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) + else: + print(f"{label} >> {names[i]:14s}: {type_str}") + + +def permute_to_grouped_tensor(src_format, tensor): + """Permute tensor from src_format = {bshd, sbhd, thd} to des_format = {bhsd, htd} for MXFP8 quantization.""" + if src_format in ["bhsd", "htd"]: + return tensor, src_format + des_format = "bhsd" if src_format != "thd" else "htd" + # make tensor contiguous bshd/sbhd/thd + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + # permute bshd/sbhd to bhsd, and thd to htd + dim_s_or_t = src_format.find("s") if "s" in src_format else src_format.find("t") + dim_others = [i for i in 
range(len(tensor.shape)) if i != dim_s_or_t] + new_dims = [*dim_others[:-1], dim_s_or_t, dim_others[-1]] + tensor = tensor.permute(*new_dims).contiguous() + return tensor, des_format def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): """Combine q,k,v based on qkv_layout and quantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate + if isinstance(qkv_quantizer, MXFP8Quantizer): + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + # permute q, k, v to bhsd/htd format + if q_format not in ["bhsd", "htd"]: + q, _ = permute_to_grouped_tensor(q_format, q) + if kv_format not in ["bhsd", "htd"]: + k, _ = permute_to_grouped_tensor(kv_format, k) + v, _ = permute_to_grouped_tensor(kv_format, v) + qkv_layout = "bhsd_bhsd_bhsd" if qkv_format != "thd" else "htd_htd_htd" + # check shapes + original_shapes = [x.shape for x in [q, k, v]] + s_q, d_qk = q.shape[-2:] + s_kv, d_v = v.shape[-2:] + assert s_q % 128 == 0 and s_kv % 128 == 0 and d_qk % 32 == 0 and d_v % 32 == 0, ( + "MXFP8 quantization requires s_q % 128 == 0, s_kv % 128 == 0, d_qk % 32 == 0, d_v % 32" + f" == 0. Found {s_q=}, {s_kv=}, {d_qk=}, {d_v=}." 
+ ) + q, k, v = [x.view(-1, x.shape[-1]) for x in [q, k, v]] + # quantize q, k, v + if d_qk == d_v: + input_tensors = [q, k, v] + num_tensors = len(input_tensors) + shapes = [x.shape for x in input_tensors] + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shapes, + quantizer=qkv_quantizer, + device="cuda", + dtype=q.dtype, + ) + q_fp8, k_fp8, v_fp8 = grouped_tensor.quantize(input_tensors) + else: + input_tensors = [q, k] + num_tensors = len(input_tensors) + shapes = [x.shape for x in input_tensors] + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shapes=shapes, + quantizer=qkv_quantizer, + device="cuda", + dtype=q.dtype, + ) + q_fp8, k_fp8 = grouped_tensor.quantize(input_tensors) + v_fp8 = qkv_quantizer(v) + # view rowwise/columnwise data back to original shapes, not rowwise_scale_inv/columnwise_scale_inv + q_fp8, k_fp8, v_fp8 = [x.view(s) for x, s in zip([q_fp8, k_fp8, v_fp8], original_shapes)] + + return q_fp8, k_fp8, v_fp8, qkv_layout + qkv_layout = qkv_layout.replace("paged_kv_", "") qkv_group = len(qkv_layout.split("_")) src_nominal_dtype = q.dtype + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") @@ -2221,24 +2387,29 @@ def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): for x in [q_data, k_data, v_data] ] - return q_fp8, k_fp8, v_fp8 + return q_fp8, k_fp8, v_fp8, qkv_layout def combine_and_dequantize( qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None ): """Combine q,k,v based on qkv_layout and dequantize them together""" - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_layout = qkv_layout.replace("paged_kv_", "") - qkv_group = len(qkv_layout.split("_")) - if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + if all(isinstance(x, QuantizedTensorStorage) for x in [q_fp8, k_fp8, v_fp8]): src_nominal_dtype = q_fp8.dtype else: assert src_nominal_dtype is 
not None, "The nominal dtype of input tensors is required!" if des_nominal_dtype is None: des_nominal_dtype = src_nominal_dtype + if all(isinstance(x, (MXFP8Tensor, MXFP8TensorStorage)) for x in [q_fp8, k_fp8, v_fp8]): + q, k, v = [x.dequantize(dtype=des_nominal_dtype) for x in [q_fp8, k_fp8, v_fp8]] + return q, k, v + + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_format, q_format, kv_format = get_qkv_format(qkv_layout) + qkv_group = len(qkv_layout.split("_")) q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + # 1: qkv packed, 2: kv packed, 3: qkv separate match qkv_group: case 1: dim = qkv_layout.find("3") diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d95d327c78..9972ddd994 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -795,14 +795,22 @@ def forward( fp8_dpa = fp8_recipe.fp8_dpa fp8_mha = fp8_recipe.fp8_mha float8_current_scaling = fp8_recipe.float8_current_scaling() + mxfp8_scaling = fp8_recipe.mxfp8() else: fp8_dpa = _dpa_fp8_recipe_dpa fp8_mha = _dpa_fp8_recipe_mha float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" - # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe - qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling - # DPA: always produce FP8 output when fp8=True to take advantage of the O amax - dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + mxfp8_scaling = _dpa_fp8_recipe == "MXFP8BlockScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling or MXFP8BlockScaling recipe + qkv_fp8_output = ( + fp8 + and fp8_mha + and rotary_pos_emb is None + and not float8_current_scaling + and not mxfp8_scaling + ) + # DPA: produce FP8 output when fp8=True to take advantage of the O amax except for MXFP8BlockScaling + dpa_fp8_output = fp8 and (fp8_dpa or 
fp8_mha) and not mxfp8_scaling # Proj Gemm: match DPA output except for Float8CurrentScaling proj_fp8_grad = dpa_fp8_output and not float8_current_scaling diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2de4576e05..0ca738cdb8 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -42,6 +42,7 @@ "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, + "bhsd": NVTE_QKV_Format.NVTE_BHSD, } QKVLayout = { @@ -70,6 +71,7 @@ "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, + "bhsd_bhsd_bhsd": NVTE_QKV_Layout.NVTE_BHSD_BHSD_BHSD, } AttnBiasType = { @@ -134,6 +136,7 @@ def fused_attn_fwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -203,6 +206,8 @@ def fused_attn_fwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -302,17 +307,6 @@ def fused_attn_fwd( rng_elts_per_thread = ( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, 
qkv_layout={qkv_layout!r})." - ) - if o_quantizer is None: - raise ValueError( - "o_quantizer is required for FP8 fused attention forward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -326,6 +320,7 @@ def fused_attn_fwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -379,7 +374,6 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, fake_dtype: torch.dtype, - dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, @@ -391,6 +385,9 @@ def fused_attn_bwd( dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", + o_format: str = "sbhd", + d_out_format: str = "sbhd", + dqkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", softmax_type: str = "vanilla", @@ -429,8 +426,6 @@ def fused_attn_bwd( fake_dtype : tex.DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype - dqkv_dtype : tex.DType - data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. 
aux_ctx_tensors = [M, ZInv, rng_state] @@ -457,6 +452,15 @@ def fused_attn_bwd( {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} + o_format : str, default = "sbhd" + format of O; {"sbhd", "bshd", "thd"} + d_out_format : str, default = "sbhd" + format of dO; {"sbhd", "bshd", "thd"} + dqkv_layout : str, default = "sbh3d" + layout of dQ, dK and dV; + {"sb3hd", "sbh3d", "sbhd_sb2hd", "sbhd_sbh2d", "sbhd_sbhd_sbhd", + "bs3hd", "bsh3d", "bshd_bs2hd", "bshd_bsh2d", "bshd_bshd_bshd", + "t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"} attn_bias_type : str, default = "no_bias" type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type : str, default = "padding" @@ -517,29 +521,6 @@ def fused_attn_bwd( f" for backend={fused_attention_backend}." ) - if fused_attention_backend == FusedAttnBackend["FP8"]: - if s_quantizer is None: - raise ValueError( - "s_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dp_quantizer is None: - raise ValueError( - "dp_quantizer is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if dqkv_dtype is None: - raise ValueError( - "dqkv_dtype is required for FP8 fused attention backward" - f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." - ) - if len(aux_ctx_tensors) != 3: - raise ValueError( - "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," - f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" - f" (backend={fused_attention_backend})." 
- ) - output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -547,6 +528,9 @@ def fused_attn_bwd( dropout, fast_zero_fill, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[d_out_format], + QKVLayout[dqkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], @@ -561,7 +545,6 @@ def fused_attn_bwd( o, d_o, fake_dtype, - dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 63a2e86e67..350238af03 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -300,6 +300,14 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. 
+ */ + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); + std::pair create_grouped_tensor( size_t num_tensors, const std::vector& logical_shape, DType dtype, py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..602a09f54b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -84,7 +84,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -98,11 +98,12 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const 
std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c33..f203d02c4b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -91,6 +91,21 @@ std::pair quantizer_helper(py::handle quantizer, !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // MXFP8 + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK(!data.has_value(), + "MXFP8Quantizer::create_tensor() does not take data tensor as input!"); + } } return {std::move(te_T), std::move(py_T)}; } @@ -98,7 +113,7 @@ std::pair quantizer_helper(py::handle quantizer, // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, @@ -134,8 +149,15 @@ std::vector fused_attn_fwd( std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); 
std::vector v_shape = convertShape(te_V.shape()); - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape_tmp = std::vector{q_shape.begin(), q_shape.end()}; + o_shape_tmp[o_shape_tmp.size() - 1] = v_shape[v_shape.size() - 1]; + auto o_shape = std::vector{o_shape_tmp.begin(), o_shape_tmp.end()}; + size_t b = 0, h = 0, s = 0, d = 0, t = 0; + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + nvte_convert_qkv_format(q_format, o_shape_tmp, o_format, o_shape, &b, &h, &s, &d, &t); + if (q_format == NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_q.size(0) - 1; + } const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); @@ -146,9 +168,7 @@ std::vector fused_attn_fwd( TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (o_format == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) { mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -156,7 +176,7 @@ std::vector fused_attn_fwd( } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (o_format == NVTE_QKV_Format::NVTE_THD) { te_O.zero_(at::cuda::getCurrentCUDAStream()); } } else { @@ -235,9 +255,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, 
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -295,9 +315,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), - at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, o_format, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory @@ -310,11 +330,12 @@ std::vector fused_attn_fwd( // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format d_out_format, + NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, const std::vector window_size, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const 
py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -343,25 +364,40 @@ std::vector fused_attn_bwd( std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d_qk = q_shape[q_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - + size_t b = 0, h_q = 0, h_kv = 0, s_q = 0, s_kv = 0, d_qk = 0, d_v = 0, t_q = 0, t_kv = 0; + size_t ndim = q_shape.size(); + std::vector dQ_shape(ndim), dK_shape(ndim), dV_shape(ndim); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + NVTE_QKV_Format dq_format = nvte_get_q_format(dqkv_layout); + NVTE_QKV_Format dkv_format = nvte_get_kv_format(dqkv_layout); + nvte_convert_qkv_format(q_format, q_shape, dq_format, dQ_shape, &b, &h_q, &s_q, &d_qk, &t_q); + nvte_convert_qkv_format(kv_format, k_shape, dkv_format, dK_shape, &b, &h_kv, &s_kv, &d_qk, &t_kv); + nvte_convert_qkv_format(kv_format, v_shape, dkv_format, dV_shape, &b, &h_kv, &s_kv, &d_v, &t_kv); + if (dq_format == NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_q.size(0) - 1; + } else if (dkv_format == NVTE_QKV_Format::NVTE_THD) { + b = cu_seqlens_kv.size(0) - 1; + } at::Tensor dQ, dK, dV, dQKV, dKV; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - std::vector tmp_shape; + DType dqkv_type = fake_dtype_te; + if (!dqkv_quantizer.is_none()) { + dqkv_type = dqkv_quantizer.attr("dtype").cast(); + } auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); if (dqkv_type 
== DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { options = options.dtype(torch::kUInt8); } - if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr()) || + detail::IsMXFP8Quantizers(dqkv_quantizer.ptr())) { options = options.dtype(fake_dtype); } - + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(dqkv_layout); + std::vector tmp_shape; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -378,7 +414,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; + tmp_shape = std::vector{dQ_shape.begin(), dQ_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -392,9 +428,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -407,9 +443,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - tmp_shape = 
std::vector(q_shape.begin(), q_shape.end()); + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; + tmp_shape = std::vector{dK_shape.begin(), dK_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -420,25 +456,26 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + case NVTE_QKV_Layout_Group::NVTE_SD_SD_SD: + tmp_shape = std::vector(dQ_shape.begin(), dQ_shape.end()); dQ = torch::empty(tmp_shape, options); - tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + tmp_shape = std::vector(dK_shape.begin(), dK_shape.end()); dK = torch::empty(tmp_shape, options); - tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + tmp_shape = std::vector(dV_shape.begin(), dV_shape.end()); dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, dQ_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, dK_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, dV_shape, fake_dtype_te, true, dV); // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (set_zero && (nvte_get_qkv_format(dqkv_layout) == 
NVTE_QKV_Format::NVTE_THD)) { if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -451,7 +488,7 @@ std::vector fused_attn_bwd( } } } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { + if (nvte_get_qkv_format(dqkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); @@ -538,9 +575,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -555,9 +592,9 @@ std::vector fused_attn_bwd( &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), - at::cuda::getCurrentCUDAStream()); + attn_scale, p_dropout, qkv_layout, o_format, d_out_format, dqkv_layout, bias_type, + attn_mask_type, 
softmax_type, window_size[0], window_size[1], bottom_right_diagonal, + deterministic, cuda_graph, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..7f026fe1b1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1445,6 +1445,18 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data) { + at::Tensor amax_tensor = + at::zeros({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); + out_cpp.set_amax(amax_tensor.data_ptr(), DType::kFloat32, std::vector{1}); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::create_grouped_tensor( const size_t num_tensors, const std::vector& logical_shape, const DType dtype, py::object quantizer, const std::optional& first_dims,