diff --git a/setup.py b/setup.py index 5ad76fd9b8e..a83e555acda 100644 --- a/setup.py +++ b/setup.py @@ -139,9 +139,8 @@ def get_extensions(): ) is_rocm_pytorch = False - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + + if torch.__version__ >= "1.5": from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False