diff --git a/README.rst b/README.rst
index 971919e1..a25d9d5b 100644
--- a/README.rst
+++ b/README.rst
@@ -1,7 +1,7 @@
 .. This file was modified to include portability information to AMDGPU.
-   Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+   Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
@@ -21,7 +21,7 @@ Feature Support Status

 * Activation, cast, fused softmax, layernorm, rmsnorm, transpose, fused rope, fp8 recipe, HipRTC: fully supported
 * GEMM: partially supported with following input/output types: (fp32/fp32), (fp16/fp16), (bf16/bf16), (fp8, bf8/fp16, bf16, fp32)
-* Attention (Flash Attention, Fused Multihead Attention): partially supported: Fused Attention with AOTriton and CK backends
+* Attention (Flash Attention, Fused Multihead Attention): partially supported: Fused Attention with AOTriton and CK backends, FlashAttention-2 with fixed length sequences
 * HipGraph, HipTX: partially supported
 * Tensor Parallelism, Sequence Parallelism, Context Parallelism: supported
diff --git a/ci/_utils.sh b/ci/_utils.sh
index e3d8ac7f..4ffd0fe3 100644
--- a/ci/_utils.sh
+++ b/ci/_utils.sh
@@ -55,17 +55,24 @@ configure_gemm_env() {
 configure_fused_attn_env() {
     case "$1" in
         "auto")
-            unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
+            unset NVTE_FLASH_ATTN NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
             ;;
         "aotriton")
+            export NVTE_FLASH_ATTN=0
             export NVTE_FUSED_ATTN_CK=0
             unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON
             ;;
         "ck")
+            export NVTE_FLASH_ATTN=0
             export NVTE_FUSED_ATTN_AOTRITON=0
             unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_CK
             ;;
+        "flash")
+            export NVTE_FUSED_ATTN=0 NVTE_FUSED_ATTN_CK=0 NVTE_FUSED_ATTN_AOTRITON=0
+            unset NVTE_FLASH_ATTN
+            ;;
         "unfused")
+            export NVTE_FLASH_ATTN=0
             export NVTE_FUSED_ATTN=0
             unset NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
             ;;
diff --git a/ci/pytorch.sh b/ci/pytorch.sh
index 82984537..3e6e32d5 100755
--- a/ci/pytorch.sh
+++ b/ci/pytorch.sh
@@ -91,12 +91,13 @@ pip list | egrep "flash|ml_dtypes|numpy|onnx|torch|transformer_e|typing_ext"

 for _gemm in hipblaslt rocblas; do
     configure_gemm_env $_gemm || continue
-    for _fus_attn in auto ck aotriton unfused; do
+    for _fus_attn in auto flash ck aotriton unfused; do
         configure_fused_attn_env $_fus_attn || continue
-        #Auto - default mode with all Fused attentions backends enabled
-        #CK/AOTriton - only corresponding Fused attention backend is enabled
-        #Unfused - Fused attention is disabled
+        #Auto - default mode with all Flash and Fused attention backends enabled
+        #Flash - only Flash attention is enabled (Fused attention disabled)
+        #CK/AOTriton - no Flash attention and only corresponding Fused attention backend is enabled
+        #Unfused - Flash and Fused attentions are disabled
         #Level 1 - run hipBlasLt in auto and unfused modes, rocBlas in auto mode
         #Level 3 - run hipBlasLt in all but unfused modes, rocBlas in auto and unfused modes
         if [ $TEST_LEVEL -ge 3 ]; then
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 3a9d37b3..9b20836f 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -136,7 +136,7 @@ def _get_attention_backends(
 ) -> Tuple[List, List]:
     """Check if what attention backends support a model configuration"""

-    os.environ["NVTE_FLASH_ATTN"] = "1" if not IS_HIP_EXTENSION else "0"
+    os.environ["NVTE_FLASH_ATTN"] = "1"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     os.environ["NVTE_UNFUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index bf177e38..bed9a9f4 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -12,7 +12,7 @@ from test_fused_attn import ModelConfig
 from transformer_engine.pytorch.attention import (
     _flash_attn_2_plus,
-    _flash_attn_2_3_plus,
+    _flash_attn_2_6_plus,
 )
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
@@ -74,6 +74,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
+    if IS_HIP_EXTENSION and qkv_format == "thd":
+        pytest.skip("CP tests do not support THD format on ROCm yet!")

     subprocess.run(
         get_bash_arguments(
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 50008fe0..d1cb4b3f 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1646,6 +1646,8 @@ def test_gpt_cuda_graph(dtype, bs, model):
     if not use_hipblaslt():
         pytest.skip("CUDA graph capture not supported with rocBLAS path")
     if dtype not in (torch.float32,):
+        if int(os.environ.get("NVTE_FLASH_ATTN", "1")) != 0:
+            pytest.skip(f"ROCm flash attention does not support cuda graph with {dtype}")
         use_aotriton, use_ck = rocm_fused_attn_backend()
         if use_aotriton and not use_ck:
             pytest.skip(f"AOTriton attention backend does not support cuda graph with {dtype}")
@@ -1808,7 +1810,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         device="cuda",
         attn_input_format="bshd",
     )
-
+    #TODO: remove this guard once ROCm fused attn supports variable sequence length features
     if not IS_HIP_EXTENSION:
         torch.manual_seed(0)
@@ -1838,7 +1840,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         for (n1, p1), (n2, p2) in zip(
             block_bshd.named_parameters(), block_sbhd.named_parameters()
         ):
-            assert torch.all(torch.eq(p1, p2)), f"{n1} and {n2} not identical"
+            assert torch.all(torch.eq(p1, p2)), f"{n1} and {n2} not identical"

     x_sbhd = torch.randn(
         (config.seq_len, bs, config.hidden_size),
@@ -1873,7 +1875,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         y_bshd,
         y_sbhd.transpose(0, 1).contiguous(),
     )
-
+    # TODO: remove once ROCm fused attn supports variable seqlen
     if not IS_HIP_EXTENSION:
         # THD is not supported in float32 and on GPUs older than Ampere, skip the test here
@@ -1903,6 +1905,10 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
 @pytest.mark.parametrize("module", module_inference)
 @pytest.mark.parametrize("backend", backends_inference)
 def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
+    if ((backend == "FlashAttention" and os.getenv("NVTE_FLASH_ATTN", "1") == "0") or
+        (backend == "FusedAttention" and os.getenv("NVTE_FUSED_ATTN", "1") == "0")):
+        pytest.skip(f"{backend} is disabled")
+
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 8ae7a80e..0b0cd6a9 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2022-2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -90,13 +90,8 @@ from transformer_engine.pytorch.graph import is_graph_capturing

-#TODO: add back once rocm TE support flash-attn
-if not IS_HIP_EXTENSION:
-    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
-    _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
-else:
-    _NVTE_FLASH_ATTN = 0
-    _flash_attn_version = PkgVersion("0.0.1")
+_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
 _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
 _flash_attn_version_required = PkgVersion("2.0.6")
@@ -107,33 +102,35 @@
 _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
 _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
 _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
+_flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6")
 _flash_attn_3_plus = False
 _use_flash_attn_3 = False
-try:
-    _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
-    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
-except PackageNotFoundError:
-    if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
-        warnings.warn(
-            "To use flash-attn v3, please use the following commands to install: \n"
-            """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
-            """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
-            """(3) mkdir -p $python_path/flashattn_hopper \n"""
-            """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
+if not IS_HIP_EXTENSION:
+    try:
+        _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
+        _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
+    except PackageNotFoundError:
+        if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
+            warnings.warn(
+                "To use flash-attn v3, please use the following commands to install: \n"
+                """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
+                """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
+                """(3) mkdir -p $python_path/flashattn_hopper \n"""
+                """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
+            )
+    else:
+        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flashattn_hopper.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
+            _flash_attn_forward as _flash_attn_forward_v3,
+        )
+        from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
+            _flash_attn_backward as _flash_attn_backward_v3,
         )
-else:
-    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flashattn_hopper.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
-    )
-    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
-        _flash_attn_forward as _flash_attn_forward_v3,
-    )
-    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
-        _flash_attn_backward as _flash_attn_backward_v3,
-    )
-    _use_flash_attn_3 = True
+        _use_flash_attn_3 = True

 if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
@@ -334,11 +331,7 @@ def get_attention_backend(

     # Filter: Environment variables
     global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
-    # TODO: enable flash attn package in rocm TE
-    if IS_HIP_EXTENSION:
-        _NVTE_FLASH_ATTN = 0
-    else:
-        _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+    _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
     _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
     _NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
     use_flash_attention = _NVTE_FLASH_ATTN
@@ -635,13 +628,6 @@ def get_attention_backend(
                 "with causal mask, no dropout, and qkv_format = bshd/sbhd"
             )
             use_fused_attention = False
-        # ROCm TE can support generic sliding window with dropout
-        elif IS_HIP_EXTENSION and qkv_format == "thd":
-            logger.debug(
-                "Disabling ROCm FusedAttention as it only supports sliding window attention "
-                "with qkv_format = bshd/sbhd"
-            )
-            use_fused_attention = False
         elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
             "no_mask",
             "padding",
@@ -860,7 +846,7 @@ def get_attention_backend(

     # Select FusedAttention for performance
     if (
-        use_flash_attention
+        use_flash_attention and (not IS_HIP_EXTENSION)
         and use_fused_attention
         and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
     ):
@@ -877,6 +863,7 @@ def get_attention_backend(
     if (
         use_flash_attention
         and use_fused_attention
+        and not IS_HIP_EXTENSION
         and fused_attention_backend == FusedAttnBackend["FP8"]
         and _use_flash_attn_3
     ):
@@ -2083,6 +2070,14 @@ def forward(
                             **fa_optional_forward_kwargs,
                         )

+                # Depending on the flash-attn version, softmax_lse shape may be
+                # either (batch_size, nheads, seqlen) or (nheads, total_q_seqlen).
+                # Here we use the former format.
+                if not use_fused_attention and _flash_attn_2_6_plus:
+                    softmax_lse_per_step[i] = softmax_lse_per_step[i].view(
+                        softmax_lse_per_step[i].shape[0], cu_seqlens_q.numel() - 1, -1
+                    ).movedim(0, 1)
+
             if i > 0:
                 # wait until fwd restuls correction of last step is done
                 if i > 1:
@@ -2301,13 +2296,21 @@ def backward(ctx, dout):
             softmax_lse_ = softmax_lse.view(
                 *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
             )
-            softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
+            softmax_lse_ = softmax_lse_[..., 1, :]
             if ctx.use_fused_attention:
                 # [b, np, sq//2] -> [b, np, sq//2, 1]
                 softmax_lse_.unsqueeze_(-1)
+            elif _flash_attn_2_6_plus:
+                # [b, np, sq//2] -> [np, b*sq//2]
+                softmax_lse_ = softmax_lse_.movedim(1,0).reshape(softmax_lse_.shape[1], -1)
+            softmax_lse_ = softmax_lse_.contiguous()
+
         if ctx.use_fused_attention:
             # [b, np, sq] -> [b, np, sq, 1]
             softmax_lse.unsqueeze_(-1)
+        elif _flash_attn_2_6_plus:
+            # [b, np, sq] -> [np, b*sq]
+            softmax_lse = softmax_lse.movedim(1,0).reshape(softmax_lse.shape[1], -1).contiguous()

         if ctx.fp8:
             if ctx.use_fused_attention:
@@ -2370,6 +2373,8 @@ def backward(ctx, dout):
             fa_optional_backward_kwargs["alibi_slopes"] = None
         if _flash_attn_2_4_1_plus:
             fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
+        if _flash_attn_2_6_plus:
+            fa_optional_backward_kwargs["softcap"] = 0.0

         for i in range(cp_size):
             # wait until KV is received
@@ -3305,6 +3310,8 @@ def backward(ctx, dout):
             fa_optional_backward_kwargs["alibi_slopes"] = None
         if _flash_attn_2_4_1_plus:
             fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
+        if _flash_attn_2_6_plus:
+            fa_optional_backward_kwargs["softcap"] = 0.0

         for i in range(len(local_seq_chunk_ids) + 1):
             if i < len(local_seq_chunk_ids):
@@ -3878,6 +3885,8 @@ def backward(ctx, dout):
             fa_optional_backward_kwargs["alibi_slopes"] = None
         if _flash_attn_2_4_1_plus:
             fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
+        if _flash_attn_2_6_plus:
+            fa_optional_backward_kwargs["softcap"] = 0.0

         if ctx.use_fused_attention:
             dq, dk, dv, _ = fused_attn_bwd(
@@ -4868,10 +4877,7 @@ def __init__(
         deterministic: bool = False,
     ) -> None:
         super().__init__()
-
-        # TODO: enable after flash attn package supported in ROCm TE
-        if IS_HIP_EXTENSION:
-            return
+
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index ab18c2a6..5872e6ab 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -134,10 +134,8 @@ std::vector<at::Tensor> fused_attn_bwd(
     const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
     c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
-#ifndef USE_ROCM
 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
-#endif

 /***************************************************************************************************
  * GEMM
 **************************************************************************************************/
@@ -426,19 +424,18 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
  * Miscellaneous
 **************************************************************************************************/

-//TODO: support user buffer for ROCm
 #ifndef USE_ROCM
 size_t get_cublasLt_version();
 size_t get_cudnn_version();
 #endif

+//TODO: support user buffer for ROCm
 void placeholder();

 /***************************************************************************************************
  * Support THD format for Context Parallel
 **************************************************************************************************/

-#ifndef USE_ROCM
 at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                 int half_idx);
@@ -458,7 +455,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,

 at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank);
-#endif

 /***************************************************************************************************
  * multi_tensor_* kernels
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index 823d67fa..cc4572e5 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1250,7 +1250,6 @@ std::vector<at::Tensor> fused_attn_bwd(
   return {dQ, dK, dV, dBias};
 }

-#ifndef USE_ROCM
 namespace flash_attention {

 constexpr int warp_size = 32;
@@ -1943,5 +1942,3 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t

   return output;
 }
-
-#endif //ifndef USE_ROCM
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 45d99141..5d2da88a 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -148,12 +148,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard<at::cuda::OptionalCUDAGuard>());
   m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard<at::cuda::OptionalCUDAGuard>());
   m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard<at::cuda::OptionalCUDAGuard>());
-#ifndef USE_ROCM
   m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
        py::call_guard<at::cuda::OptionalCUDAGuard>());
   m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
        py::call_guard<at::cuda::OptionalCUDAGuard>());
-#endif
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
        py::call_guard<at::cuda::OptionalCUDAGuard>());
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
@@ -180,7 +178,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #endif
   m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);

-#ifndef USE_ROCM
   // Support THD format for Context Parallel
   m.def("thd_read_half_tensor", &thd_read_half_tensor,
        "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
@@ -199,7 +196,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
        "Generate partitioned indices for inputs in THD format",
        py::call_guard<at::cuda::OptionalCUDAGuard>());
-#endif

   // multi-tensor functions
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
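A minimal sketch (not part of the patch) of exercising the new "flash" CI mode from ci/_utils.sh by hand on a ROCm build with this change applied. It assumes flash-attn >= 2.6 is installed and a GPU is visible; the tensor shapes and the DotProductAttention arguments are illustrative only.

    # Mirror ci/_utils.sh "flash" mode: disable the fused backends, leave flash attention enabled.
    import os
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN_CK"] = "0"
    os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
    os.environ.pop("NVTE_FLASH_ATTN", None)  # unset -> flash attention defaults to enabled ("1")

    import torch
    from transformer_engine.pytorch import DotProductAttention

    # Fixed-length sequences in the default "sbhd" layout: [seq, batch, heads, head_dim].
    q, k, v = (torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda") for _ in range(3))
    attn = DotProductAttention(num_attention_heads=16, kv_channels=64, attn_mask_type="causal")
    out = attn(q, k, v)
    print(out.shape)  # expected: [128, 2, 16 * 64]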
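A small, self-contained sketch of the softmax_lse layout handling that the context-parallel changes above add for flash-attn >= 2.6. The names and sizes are local to this example (the patch operates on softmax_lse_per_step / softmax_lse), and it only demonstrates the shape round trip for fixed-length sequences, not the numerics.

    import torch

    batch, nheads, seqlen = 2, 16, 128
    cu_seqlens_q = torch.arange(0, (batch + 1) * seqlen, seqlen)  # fixed-length sequences

    # flash-attn >= 2.6 varlen kernels return lse with shape (nheads, total_q_tokens).
    lse_flash = torch.randn(nheads, batch * seqlen)

    # Forward direction in the patch: reshape to (batch, nheads, seqlen) for the CP correction code.
    lse_bns = lse_flash.view(nheads, cu_seqlens_q.numel() - 1, -1).movedim(0, 1)
    assert lse_bns.shape == (batch, nheads, seqlen)

    # Backward direction: convert back to (nheads, batch * seqlen) before calling the kernel.
    lse_back = lse_bns.movedim(1, 0).reshape(lse_bns.shape[1], -1).contiguous()
    assert torch.equal(lse_back, lse_flash)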