From 411b1f8c0bd6cc1b6c028f770e137e214fe4e92e Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Tue, 17 Aug 2021 16:25:52 +0000 Subject: [PATCH] Make torch version check numeric --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4cc3d0698a4..d046098a6ef 100644 --- a/setup.py +++ b/setup.py @@ -145,7 +145,9 @@ def get_extensions(): ) is_rocm_pytorch = False - if torch.__version__ >= '1.5': + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False