From bcf7a3ee0d919eca45d2f07241479b5776975bc3 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 19 Dec 2024 23:29:49 -0800
Subject: [PATCH] bugfix: bug fix on `determine_attention_backend` condition
 (#688)

Should only enable fa3 for cuda 12.3+
---
 flashinfer/utils.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index d38af827..7164038e 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -20,6 +20,7 @@
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
+import torch.version
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
@@ -342,12 +343,16 @@ def determine_attention_backend(
     """
     major, _ = get_compute_capability(device)
 
-    if major >= 9 and is_fa3_backend_supported(
-        pos_encoding_mode,
-        allow_fp16_qk_reductions,
-        use_custom_mask,
-        dtype_q,
-        dtype_kv,
+    if (
+        major >= 9
+        and torch.version.cuda >= "12.3"
+        and is_fa3_backend_supported(
+            pos_encoding_mode,
+            allow_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
     ):
         return "fa3"
     else:
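
For reference, a minimal standalone sketch of the gating logic this patch introduces: FA3 is only eligible on SM90+ devices when PyTorch was built against CUDA 12.3 or newer. The function name `fa3_eligible` and the numeric version parsing are illustrative assumptions, not part of flashinfer; the patch itself compares `torch.version.cuda` lexicographically as a string.

```python
import torch


def fa3_eligible(device: torch.device) -> bool:
    """Illustrative sketch (not flashinfer API): True if the FA3 backend
    could be enabled on `device` under the condition this patch adds."""
    cuda = torch.version.cuda  # None for CPU-only or ROCm builds of PyTorch
    if cuda is None:
        return False
    # Compute capability major version, e.g. 9 for Hopper (SM90).
    major, _ = torch.cuda.get_device_capability(device)
    # Compare version components numerically so that e.g. "12.10" >= "12.3".
    cuda_tuple = tuple(int(x) for x in cuda.split("."))
    return major >= 9 and cuda_tuple >= (12, 3)
```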