-
Notifications
You must be signed in to change notification settings - Fork 0
/
TorchExtension.cmake
107 lines (96 loc) · 4.41 KB
/
TorchExtension.cmake
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# @file TorchExtension.cmake
# @author Zizheng Guo
# @brief Use CMake to compile PyTorch extensions
add_subdirectory(thirdparty/pybind11)

# Query the Python environment for three facts about the installed torch
# package, printed one per line:
#   1. installation directory (torch.__path__[0])
#   2. whether torch reports CUDA as available, printed as 0 or 1
#   3. the full version string
# NOTE(review): PYTHON_EXECUTABLE is assumed to be provided by the including
# project or on the command line -- confirm.
execute_process(
  COMMAND ${PYTHON_EXECUTABLE} -c
    "import torch; print(torch.__path__[0]); print(int(torch.cuda.is_available())); print(torch.__version__);"
  RESULT_VARIABLE TORCH_QUERY_RESULT
  OUTPUT_VARIABLE TORCH_OUTPUT
  OUTPUT_STRIP_TRAILING_WHITESPACE)

# Fail early with a readable message instead of the cryptic string()/list()
# argument errors that an empty TORCH_OUTPUT would otherwise trigger.
# RESULT_VARIABLE may also hold an error string (e.g. interpreter not found),
# which is not a number and therefore never EQUAL 0.
if (NOT "${TORCH_QUERY_RESULT}" EQUAL 0 OR "${TORCH_OUTPUT}" STREQUAL "")
  message(FATAL_ERROR
    "Failed to query PyTorch with '${PYTHON_EXECUTABLE}' (result: ${TORCH_QUERY_RESULT}). "
    "Make sure PYTHON_EXECUTABLE points to an interpreter with torch installed.")
endif()

# Split the output into a list, one element per printed line. Python's
# text-mode stdout may terminate lines with "\r\n", so accept both endings.
string(REGEX REPLACE "\r?\n" ";" TORCH_OUTPUT_LIST "${TORCH_OUTPUT}")
list(LENGTH TORCH_OUTPUT_LIST TORCH_OUTPUT_LENGTH)
if (TORCH_OUTPUT_LENGTH LESS 3)
  message(FATAL_ERROR "Unexpected output while querying PyTorch: ${TORCH_OUTPUT}")
endif()
list(GET TORCH_OUTPUT_LIST 0 TORCH_INSTALL_PREFIX)
list(GET TORCH_OUTPUT_LIST 1 TORCH_ENABLE_CUDA)
list(GET TORCH_OUTPUT_LIST 2 TORCH_VERSION)

# Only the major and minor components are used; anything after the second
# dot (patch number, local suffix) is ignored.
string(REPLACE "." ";" TORCH_VERSION_LIST "${TORCH_VERSION}")
list(GET TORCH_VERSION_LIST 0 TORCH_MAJOR_VERSION)
list(GET TORCH_VERSION_LIST 1 TORCH_MINOR_VERSION)
message(STATUS "TORCH_INSTALL_PREFIX=${TORCH_INSTALL_PREFIX}")
message(STATUS "TORCH_VERSION=${TORCH_MAJOR_VERSION}.${TORCH_MINOR_VERSION}")

# Minimum supported torch release.
if ("${TORCH_MAJOR_VERSION}.${TORCH_MINOR_VERSION}" VERSION_LESS 1.6)
  message(SEND_ERROR "require PyTorch version >=1.6")
endif()
# Use CUDA only when torch reports it as available AND a CUDA toolkit
# (>= 9.0) can be found on this machine. find_package(CUDA) is the legacy
# FindCUDA module, which also provides cuda_add_library().
if (TORCH_ENABLE_CUDA)
  find_package(CUDA 9.0)
  if (NOT CUDA_FOUND)
    # Keep the cache entry mirroring the effective setting, as before.
    set(TORCH_ENABLE_CUDA 0 CACHE BOOL "Whether enable CUDA" FORCE)
    # TORCH_ENABLE_CUDA is also a normal variable (set from the Python query).
    # Under policy CMP0126 (CMake >= 3.21) set(... CACHE ... FORCE) no longer
    # removes that normal variable, so without this line the value "1" would
    # survive and CUDA code paths would be taken without a CUDA toolkit.
    set(TORCH_ENABLE_CUDA 0)
  endif()
endif()
message(STATUS "TORCH_ENABLE_CUDA=${TORCH_ENABLE_CUDA}")
# Locate the libtorch libraries shipped inside the queried torch package.
# HINTS (unlike PATHS) is searched before the system default locations, so a
# different libtorch/libc10 installed elsewhere on the system cannot shadow
# the one that belongs to this torch package.
find_library(TORCH_PYTHON_LIBRARY torch_python HINTS "${TORCH_INSTALL_PREFIX}/lib" REQUIRED)
find_library(TORCH_LIBRARY torch HINTS "${TORCH_INSTALL_PREFIX}/lib" REQUIRED)
find_library(C10_LIBRARY c10 HINTS "${TORCH_INSTALL_PREFIX}/lib" REQUIRED)
find_library(TORCH_CPU_LIBRARY torch_cpu HINTS "${TORCH_INSTALL_PREFIX}/lib" REQUIRED)
# CUDA-only libraries: optional lookups, validated below when CUDA is enabled.
find_library(C10_CUDA_LIBRARY c10_cuda HINTS "${TORCH_INSTALL_PREFIX}/lib")
find_library(TORCH_CUDA_LIBRARY torch_cuda HINTS "${TORCH_INSTALL_PREFIX}/lib")

# The header directory layout differs between torch releases.
if (EXISTS "${TORCH_INSTALL_PREFIX}/include")
  # torch version 1.4+
  set(TORCH_HEADER_PREFIX "${TORCH_INSTALL_PREFIX}/include")
elseif (EXISTS "${TORCH_INSTALL_PREFIX}/lib/include")
  # torch version 1.0
  set(TORCH_HEADER_PREFIX "${TORCH_INSTALL_PREFIX}/lib/include")
else()
  # Without this the include paths below would silently degrade to "" and
  # "/torch/csrc/api/include".
  message(FATAL_ERROR "Cannot find PyTorch headers under ${TORCH_INSTALL_PREFIX}")
endif()
set(TORCH_INCLUDE_DIRS
  "${TORCH_HEADER_PREFIX}"
  "${TORCH_HEADER_PREFIX}/torch/csrc/api/include")

# Libraries every consumer of `torch` must also link.
set(LINK_LIBS ${C10_LIBRARY} ${TORCH_CPU_LIBRARY})
if (TORCH_ENABLE_CUDA)
  # Fail clearly instead of handing "<VAR>-NOTFOUND" to the linker.
  if (NOT C10_CUDA_LIBRARY OR NOT TORCH_CUDA_LIBRARY)
    message(FATAL_ERROR
      "CUDA is enabled but c10_cuda/torch_cuda were not found under "
      "${TORCH_INSTALL_PREFIX}/lib")
  endif()
  list(APPEND LINK_LIBS ${C10_CUDA_LIBRARY} ${TORCH_CUDA_LIBRARY})
endif()

# Wrap libtorch in an imported target carrying its usage requirements.
# find_library() may return a shared library here, so the imported type is
# left UNKNOWN rather than claiming STATIC; CMake links it by full path.
# NOTE(review): CMAKE_CXX_ABI is not a built-in CMake variable; it is
# presumably set by the including project to match torch's
# _GLIBCXX_USE_CXX11_ABI -- confirm. If it is unset the define is empty.
add_library(torch UNKNOWN IMPORTED)
set_target_properties(torch PROPERTIES
  IMPORTED_LOCATION "${TORCH_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
  INTERFACE_LINK_LIBRARIES "${LINK_LIBS}"
  INTERFACE_COMPILE_OPTIONS "-D_GLIBCXX_USE_CXX11_ABI=${CMAKE_CXX_ABI}"
)
# add_torch_extension(<target_name> <sources>...
#                     [EXTRA_INCLUDE_DIRS <dir>...]
#                     [EXTRA_LINK_LIBRARIES <lib>...]
#                     [EXTRA_DEFINITIONS <def>...])
#
# Build a position-independent STATIC library (not an importable Python
# module) from <sources>, linked against the imported `torch` target and
# pybind11::module.
#  - If TORCH_ENABLE_CUDA is true, the library is created with FindCUDA's
#    cuda_add_library() so .cu sources are compiled as CUDA.
#  - Otherwise .cu/.cuh files are removed from <sources> and a plain
#    add_library() is used.
# Private compile definitions added to every source:
#   TORCH_EXTENSION_NAME=<target_name>, TORCH_MAJOR_VERSION,
#   TORCH_MINOR_VERSION, ENABLE_CUDA=<0|1>, then EXTRA_DEFINITIONS.
function(add_torch_extension target_name)
  set(multiValueArgs EXTRA_INCLUDE_DIRS EXTRA_LINK_LIBRARIES EXTRA_DEFINITIONS)
  # The ${ARGN} form is kept deliberately: it flattens quoted list arguments
  # such as "${SRCS}" into individual sources, which callers may rely on.
  cmake_parse_arguments(ARG "" "" "${multiValueArgs}" ${ARGN})
  if (TORCH_ENABLE_CUDA)
    cuda_add_library(${target_name} STATIC ${ARG_UNPARSED_ARGUMENTS})
  else()
    # Remove CUDA sources and headers. Anchor on the file extension so only
    # "*.cu" and "*.cuh" are dropped; the previous ".*cu$" matched any name
    # that merely ended in "cu".
    list(FILTER ARG_UNPARSED_ARGUMENTS EXCLUDE REGEX "\\.cuh?$")
    add_library(${target_name} STATIC ${ARG_UNPARSED_ARGUMENTS})
  endif()
  target_include_directories(${target_name} PRIVATE ${ARG_EXTRA_INCLUDE_DIRS})
  # Keyword-less signature on purpose: cuda_add_library() links with the plain
  # signature by default (CUDA_LINK_LIBRARIES_KEYWORD is empty), and CMake
  # forbids mixing plain and keyword forms on the same target.
  target_link_libraries(${target_name} ${ARG_EXTRA_LINK_LIBRARIES} torch pybind11::module)
  target_compile_definitions(${target_name} PRIVATE
    TORCH_EXTENSION_NAME=${target_name}
    TORCH_MAJOR_VERSION=${TORCH_MAJOR_VERSION}
    TORCH_MINOR_VERSION=${TORCH_MINOR_VERSION}
    ENABLE_CUDA=${TORCH_ENABLE_CUDA}
    ${ARG_EXTRA_DEFINITIONS})
  # PIC so the static library can be linked into a shared Python module;
  # hide symbols by default.
  set_target_properties(${target_name} PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    CXX_VISIBILITY_PRESET "hidden"
    CUDA_VISIBILITY_PRESET "hidden"
  )
endfunction()
# add_pytorch_extension(<target_name> <sources>...
#                       [EXTRA_INCLUDE_DIRS <dir>...]
#                       [EXTRA_LINK_LIBRARIES <lib>...]
#                       [EXTRA_DEFINITIONS <def>...])
#
# Build an importable Python extension module named <target_name> via
# pybind11_add_module(MODULE ...). All non-keyword arguments are sources.
# The module links PRIVATE against EXTRA_LINK_LIBRARIES, the imported `torch`
# target and TORCH_PYTHON_LIBRARY, and receives these private definitions:
#   TORCH_EXTENSION_NAME=<target_name>, TORCH_MAJOR_VERSION,
#   TORCH_MINOR_VERSION, ENABLE_CUDA, then EXTRA_DEFINITIONS.
function(add_pytorch_extension target_name)
  cmake_parse_arguments(pyext
    ""
    ""
    "EXTRA_INCLUDE_DIRS;EXTRA_LINK_LIBRARIES;EXTRA_DEFINITIONS"
    ${ARGN})

  # Whatever is left over after the keyword sections are the source files.
  pybind11_add_module(${target_name} MODULE ${pyext_UNPARSED_ARGUMENTS})

  # Definitions shared with the version/CUDA switches used in the sources.
  set(ext_defines
    TORCH_EXTENSION_NAME=${target_name}
    TORCH_MAJOR_VERSION=${TORCH_MAJOR_VERSION}
    TORCH_MINOR_VERSION=${TORCH_MINOR_VERSION}
    ENABLE_CUDA=${TORCH_ENABLE_CUDA}
    ${pyext_EXTRA_DEFINITIONS})

  target_include_directories(${target_name} PRIVATE ${pyext_EXTRA_INCLUDE_DIRS})
  target_link_libraries(${target_name}
    PRIVATE ${pyext_EXTRA_LINK_LIBRARIES} torch ${TORCH_PYTHON_LIBRARY})
  target_compile_definitions(${target_name} PRIVATE ${ext_defines})
endfunction()