Ipanfilo/fa support #96

Open: wants to merge 11 commits into base: dev
4 changes: 2 additions & 2 deletions README.rst
@@ -1,7 +1,7 @@
..
This file was modified to include portability information to AMDGPU.

Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

@@ -21,7 +21,7 @@ Feature Support Status

* Activation, cast, fused softmax, layernorm, rmsnorm, transpose, fused rope, fp8 recipe, HipRTC: fully supported
* GEMM: partially supported with following input/output types: (fp32/fp32), (fp16/fp16), (bf16/bf16), (fp8, bf8/fp16, bf16, fp32)
* Attention (Flash Attention, Fused Multihead Attention): partially supported: Fused Attention with AOTriton and CK backends
* Attention (Flash Attention, Fused Multihead Attention): partially supported: Fused Attention with AOTriton and CK backends, FlashAttention-2 with fixed length sequences
* HipGraph, HipTX: partially supported
* Tensor Parallelism, Sequence Parallelism, Context Parallelism: supported
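For example, backend choice can be steered at run time with the NVTE_* environment variables used throughout this change set; a minimal sketch (illustrative, not part of this PR):

```python
import os

# Prefer FlashAttention-2 (fixed-length sequences only on ROCm for now);
# the fused-attention backends remain available as fallbacks.
os.environ["NVTE_FLASH_ATTN"] = "1"

# Or disable it and rely on the fused-attention backends instead:
# os.environ["NVTE_FLASH_ATTN"] = "0"
# os.environ["NVTE_FUSED_ATTN"] = "1"
```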

9 changes: 8 additions & 1 deletion ci/_utils.sh
@@ -55,17 +55,24 @@ configure_gemm_env() {
configure_fused_attn_env() {
case "$1" in
"auto")
unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
unset NVTE_FLASH_ATTN NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
;;
"aotriton")
export NVTE_FLASH_ATTN=0
export NVTE_FUSED_ATTN_CK=0
unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_AOTRITON
;;
"ck")
export NVTE_FLASH_ATTN=0
export NVTE_FUSED_ATTN_AOTRITON=0
unset NVTE_FUSED_ATTN NVTE_FUSED_ATTN_CK
;;
"flash")
export NVTE_FUSED_ATTN=0 NVTE_FUSED_ATTN_CK=0 NVTE_FUSED_ATTN_AOTRITON=0
unset NVTE_FLASH_ATTN
;;
"unfused")
export NVTE_FLASH_ATTN=0
export NVTE_FUSED_ATTN=0
unset NVTE_FUSED_ATTN_AOTRITON NVTE_FUSED_ATTN_CK
;;
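For reference, the mode-to-environment mapping implemented by configure_fused_attn_env above (and exercised by ci/pytorch.sh below) amounts to the following Python-side sketch, where None means the variable is left unset:

```python
# Sketch of the CI attention modes, derived from the case block above.
ATTN_MODES = {
    "auto":     {"NVTE_FLASH_ATTN": None, "NVTE_FUSED_ATTN": None,
                 "NVTE_FUSED_ATTN_AOTRITON": None, "NVTE_FUSED_ATTN_CK": None},
    "aotriton": {"NVTE_FLASH_ATTN": "0",  "NVTE_FUSED_ATTN": None,
                 "NVTE_FUSED_ATTN_AOTRITON": None, "NVTE_FUSED_ATTN_CK": "0"},
    "ck":       {"NVTE_FLASH_ATTN": "0",  "NVTE_FUSED_ATTN": None,
                 "NVTE_FUSED_ATTN_AOTRITON": "0",  "NVTE_FUSED_ATTN_CK": None},
    "flash":    {"NVTE_FLASH_ATTN": None, "NVTE_FUSED_ATTN": "0",
                 "NVTE_FUSED_ATTN_AOTRITON": "0",  "NVTE_FUSED_ATTN_CK": "0"},
    "unfused":  {"NVTE_FLASH_ATTN": "0",  "NVTE_FUSED_ATTN": "0",
                 "NVTE_FUSED_ATTN_AOTRITON": None, "NVTE_FUSED_ATTN_CK": None},
}
```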
9 changes: 5 additions & 4 deletions ci/pytorch.sh
@@ -91,12 +91,13 @@ pip list | egrep "flash|ml_dtypes|numpy|onnx|torch|transformer_e|typing_ext"
for _gemm in hipblaslt rocblas; do
configure_gemm_env $_gemm || continue

for _fus_attn in auto ck aotriton unfused; do
for _fus_attn in auto flash ck aotriton unfused; do
configure_fused_attn_env $_fus_attn || continue

#Auto - default mode with all Fused attentions backends enabled
#CK/AOTriton - only corresponding Fused attention backend is enabled
#Unfused - Fused attention is disabled
#Auto - default mode with all Flash and Fused attention backends enabled
#Flash - only Flash attention is enabled; Fused attention is disabled
#CK/AOTriton - Flash attention is disabled and only the corresponding Fused attention backend is enabled
#Unfused - both Flash and Fused attention are disabled
#Level 1 - run hipBlasLt in auto and unfused modes, rocBlas in auto mode
#Level 3 - run hipBlasLt in all but unfused modes, rocBlas in auto and unfused modes
if [ $TEST_LEVEL -ge 3 ]; then
2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -136,7 +136,7 @@ def _get_attention_backends(
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""

os.environ["NVTE_FLASH_ATTN"] = "1" if not IS_HIP_EXTENSION else "0"
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
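A minimal sketch of the same pattern for forcing a single backend in a test; the cache-invalidation flag mirrors the lines above, while the helper name, defaults, and import are illustrative assumptions:

```python
import os
from transformer_engine.pytorch.attention import _attention_backends  # assumed import path

def force_backend(flash="1", fused="0", unfused="0"):
    # Hypothetical helper: set the NVTE_* switches, then mark the cached
    # backend selection as stale so it is re-evaluated on the next call.
    os.environ["NVTE_FLASH_ATTN"] = flash
    os.environ["NVTE_FUSED_ATTN"] = fused
    os.environ["NVTE_UNFUSED_ATTN"] = unfused
    _attention_backends["backend_selection_requires_update"] = True
```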
6 changes: 4 additions & 2 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -12,7 +12,7 @@
from test_fused_attn import ModelConfig
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
_flash_attn_2_6_plus,
Collaborator: It seems that this is not used in this file?

Contributor Author: Oh, yes. It is left over from the pre-IFU 1.11 commit, where THD was only skipped with flash-attn 2.6+.

Collaborator: Oh, also, I think we need to merge the updates from the dev branch (now with IFU 1.11) to make sure there is no breakage from that.

Contributor Author: It already has those changes.
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
@@ -74,6 +74,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if IS_HIP_EXTENSION and qkv_format == "thd":
pytest.skip("CP tests do not support THD format on ROCm yet!")

subprocess.run(
get_bash_arguments(
12 changes: 9 additions & 3 deletions tests/pytorch/test_numerics.py
@@ -1646,6 +1646,8 @@ def test_gpt_cuda_graph(dtype, bs, model):
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
if dtype not in (torch.float32,):
if int(os.environ.get("NVTE_FLASH_ATTN", "1")) != 0:
pytest.skip(f"ROCm flash attention does not support cuda graph with {dtype}")
use_aotriton, use_ck = rocm_fused_attn_backend()
if use_aotriton and not use_ck:
pytest.skip(f"AOTriton attention backend does not support cuda graph with {dtype}")
@@ -1808,7 +1810,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
device="cuda",
attn_input_format="bshd",
)

#TODO: remove this guard once ROCm fused attn supports variable sequence length features
if not IS_HIP_EXTENSION:
torch.manual_seed(0)
@@ -1838,7 +1840,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
for (n1, p1), (n2, p2) in zip(
block_bshd.named_parameters(), block_sbhd.named_parameters()
):
assert torch.all(torch.eq(p1, p2)), f"{n1} and {n2} not identical"
assert torch.all(torch.eq(p1, p2)), f"{n1} and {n2} not identical"

x_sbhd = torch.randn(
(config.seq_len, bs, config.hidden_size),
@@ -1873,7 +1875,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_bshd,
y_sbhd.transpose(0, 1).contiguous(),
)

# TODO: remove this guard once ROCm fused attn supports variable seqlen
if not IS_HIP_EXTENSION:
# THD is not supported in float32 and on GPUs older than Ampere, skip the test here
@@ -1903,6 +1905,10 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
if ((backend == "FlashAttention" and os.getenv("NVTE_FLASH_ATTN", "1") == "0") or
(backend == "FusedAttention" and os.getenv("NVTE_FUSED_ATTN", "1") == "0")):
pytest.skip(f"{backend} is disabled")

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

104 changes: 55 additions & 49 deletions transformer_engine/pytorch/attention.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2022-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -90,13 +90,8 @@
from transformer_engine.pytorch.graph import is_graph_capturing


#TODO: add back once rocm TE support flash-attn
if not IS_HIP_EXTENSION:
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
else:
_NVTE_FLASH_ATTN = 0
_flash_attn_version = PkgVersion("0.0.1")
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
_flash_attn_version_required = PkgVersion("2.0.6")
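For reference, a minimal sketch of the version gating used throughout this file, assuming the usual helpers behind get_pkg_version and PkgVersion (importlib.metadata and packaging, respectively):

```python
from importlib.metadata import version as get_pkg_version  # assumed helper
from packaging.version import Version as PkgVersion        # assumed helper

_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6")
assert _flash_attn_version >= PkgVersion("2.0.6"), "flash-attn >= 2.0.6 is required"
```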
@@ -107,33 +102,35 @@
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_2_6_plus = _flash_attn_version >= PkgVersion("2.6")
_flash_attn_3_plus = False
_use_flash_attn_3 = False
try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
if not IS_HIP_EXTENSION:
try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError:
if get_device_compute_capability() == (9, 0) and _NVTE_FLASH_ATTN:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)

_use_flash_attn_3 = True
_use_flash_attn_3 = True

if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
@@ -334,11 +331,7 @@ def get_attention_backend(

# Filter: Environment variables
global _NVTE_FLASH_ATTN, _NVTE_FUSED_ATTN, _NVTE_UNFUSED_ATTN
# TODO: enable flash attn package in rocm TE
if IS_HIP_EXTENSION:
_NVTE_FLASH_ATTN = 0
else:
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
use_flash_attention = _NVTE_FLASH_ATTN
@@ -635,13 +628,6 @@ def get_attention_backend(
"with causal mask, no dropout, and qkv_format = bshd/sbhd"
)
use_fused_attention = False
# ROCm TE can support generic sliding window with dropout
elif IS_HIP_EXTENSION and qkv_format == "thd":
logger.debug(
"Disabling ROCm FusedAttention as it only supports sliding window attention "
"with qkv_format = bshd/sbhd"
)
use_fused_attention = False
elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [
"no_mask",
"padding",
@@ -860,7 +846,7 @@ def get_attention_backend(

# Select FusedAttention for performance
if (
use_flash_attention
use_flash_attention and (not IS_HIP_EXTENSION)
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
@@ -877,6 +863,7 @@
if (
use_flash_attention
and use_fused_attention
and not IS_HIP_EXTENSION
and fused_attention_backend == FusedAttnBackend["FP8"]
and _use_flash_attn_3
):
@@ -2083,6 +2070,14 @@ def forward(
**fa_optional_forward_kwargs,
)

# Depending on the flash-attn version, softmax_lse shape may be
# either (batch_size, nheads, seqlen) or (nheads, total_q_seqlen).
# Here we use the former format.
if not use_fused_attention and _flash_attn_2_6_plus:
softmax_lse_per_step[i] = softmax_lse_per_step[i].view(
softmax_lse_per_step[i].shape[0], cu_seqlens_q.numel() - 1, -1
).movedim(0, 1)

if i > 0:
# wait until fwd results correction of last step is done
if i > 1:
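A minimal, self-contained sketch of the layout conversion added above, with illustrative sizes (assumes fixed-length sequences, so total_q_seqlen == batch_size * seqlen; the backward pass below applies the same movedim/reshape pattern in reverse):

```python
import torch

nheads, batch, seqlen = 8, 4, 128
# flash-attn >= 2.6 varlen layout: (nheads, total_q_seqlen)
lse_26 = torch.randn(nheads, batch * seqlen)
# convert to the (batch_size, nheads, seqlen) layout used by the correction code
lse = lse_26.view(nheads, batch, seqlen).movedim(0, 1)
assert lse.shape == (batch, nheads, seqlen)
# and back: (batch_size, nheads, seqlen) -> (nheads, batch_size * seqlen)
lse_back = lse.movedim(1, 0).reshape(nheads, -1).contiguous()
assert torch.equal(lse_back, lse_26)
```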
@@ -2301,13 +2296,21 @@ def backward(ctx, dout):
softmax_lse_ = softmax_lse.view(
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2
)
softmax_lse_ = softmax_lse_[..., 1, :].contiguous()
softmax_lse_ = softmax_lse_[..., 1, :]
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
elif _flash_attn_2_6_plus:
# [b, np, sq//2] -> [np, b*sq//2]
softmax_lse_ = softmax_lse_.movedim(1,0).reshape(softmax_lse_.shape[1], -1)
softmax_lse_ = softmax_lse_.contiguous()

if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
elif _flash_attn_2_6_plus:
# [b, np, sq] -> [np, b*sq]
softmax_lse = softmax_lse.movedim(1,0).reshape(softmax_lse.shape[1], -1).contiguous()

if ctx.fp8:
if ctx.use_fused_attention:
@@ -2370,6 +2373,8 @@ def backward(ctx, dout):
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_plus:
fa_optional_backward_kwargs["softcap"] = 0.0

for i in range(cp_size):
# wait until KV is received
@@ -3305,6 +3310,8 @@ def backward(ctx, dout):
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_plus:
fa_optional_backward_kwargs["softcap"] = 0.0

for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
@@ -3878,6 +3885,8 @@ def backward(ctx, dout):
fa_optional_backward_kwargs["alibi_slopes"] = None
if _flash_attn_2_4_1_plus:
fa_optional_backward_kwargs["deterministic"] = ctx.deterministic
if _flash_attn_2_6_plus:
fa_optional_backward_kwargs["softcap"] = 0.0

if ctx.use_fused_attention:
dq, dk, dv, _ = fused_attn_bwd(
@@ -4868,10 +4877,7 @@ def __init__(
deterministic: bool = False,
) -> None:
super().__init__()

# TODO: enable after flash attn package supported in ROCm TE
if IS_HIP_EXTENSION:
return

assert (
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
6 changes: 1 addition & 5 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -134,10 +134,8 @@ std::vector<at::Tensor> fused_attn_bwd(
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);

#ifndef USE_ROCM
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
#endif

/***************************************************************************************************
* GEMM
@@ -426,19 +424,18 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
* Miscellaneous
**************************************************************************************************/

//TODO: support user buffer for ROCm
#ifndef USE_ROCM
size_t get_cublasLt_version();
size_t get_cudnn_version();
#endif

//TODO: support user buffer for ROCm
void placeholder();

/***************************************************************************************************
* Support THD format for Context Parallel
**************************************************************************************************/

#ifndef USE_ROCM
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
int half_idx);

Expand All @@ -458,7 +455,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank);
#endif

/***************************************************************************************************
* multi_tensor_* kernels
3 changes: 0 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1250,7 +1250,6 @@ std::vector<at::Tensor> fused_attn_bwd(
return {dQ, dK, dV, dBias};
}

#ifndef USE_ROCM
namespace flash_attention {

constexpr int warp_size = 32;
@@ -1943,5 +1942,3 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t

return output;
}

#endif //ifndef USE_ROCM