diff --git a/flashinfer/utils.py b/flashinfer/utils.py index d38af827..7164038e 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -20,6 +20,7 @@ from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union import torch +import torch.version from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version @@ -342,12 +343,16 @@ def determine_attention_backend( """ major, _ = get_compute_capability(device) - if major >= 9 and is_fa3_backend_supported( - pos_encoding_mode, - allow_fp16_qk_reductions, - use_custom_mask, - dtype_q, - dtype_kv, + if ( + major >= 9 + and torch.version.cuda >= "12.3" + and is_fa3_backend_supported( + pos_encoding_mode, + allow_fp16_qk_reductions, + use_custom_mask, + dtype_q, + dtype_kv, + ) ): return "fa3" else: