diff --git a/CMakeLists.txt b/CMakeLists.txt index b590ed795fe1..ec757914cef4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -292,43 +292,6 @@ endif() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) -if(USE_CUDA) - find_package(CUDA REQUIRED) - add_definitions(-DMSHADOW_USE_CUDA=1) - if(NOT CUDA_TOOLSET) - set(CUDA_TOOLSET "${CUDA_VERSION_STRING}") - endif() - if(USE_NCCL) - find_package(NCCL) - if(NCCL_FOUND) - include_directories(${NCCL_INCLUDE_DIRS}) - list(APPEND mxnet_LINKER_LIBS ${NCCL_LIBRARIES}) - else() - message(WARNING "Could not find NCCL libraries") - endif() - endif() - if(UNIX) - find_package(NVTX) - if(NVTX_FOUND) - include_directories(${NVTX_INCLUDE_DIRS}) - list(APPEND mxnet_LINKER_LIBS ${NVTX_LIBRARIES}) - add_definitions(-DMXNET_USE_NVTX=1) - else() - message(WARNING "Could not find NVTX libraries") - endif() - endif() - - include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -else() - add_definitions(-DMSHADOW_USE_CUDA=0) -endif() - -if(NCCL_FOUND) - add_definitions(-DMXNET_USE_NCCL=1) -else() - add_definitions(-DMXNET_USE_NCCL=0) -endif() - if (USE_INT64_TENSOR_SIZE) message(STATUS "Using 64-bit integer for tensor size") add_definitions(-DMSHADOW_INT64_TENSOR_SIZE=1) @@ -618,8 +581,12 @@ if(MSVC) endif() if(USE_CUDA) + # CUDA_SELECT_NVCC_ARCH_FLAGS is not deprecated, though part of deprecated + # FindCUDA https://gitlab.kitware.com/cmake/cmake/issues/19199 + include(${CMAKE_ROOT}/Modules/FindCUDA/select_compute_arch.cmake) CUDA_SELECT_NVCC_ARCH_FLAGS(NVCC_FLAGS_ARCH ${MXNET_CUDA_ARCH}) LIST(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_ARCH}) + list(APPEND mxnet_LINKER_LIBS cublas cufft cusolver curand) if(ENABLE_CUDA_RTC) list(APPEND mxnet_LINKER_LIBS nvrtc cuda) @@ -627,6 +594,31 @@ if(USE_CUDA) endif() list(APPEND SOURCE ${CUDA}) add_definitions(-DMXNET_USE_CUDA=1) + add_definitions(-DMSHADOW_USE_CUDA=1) + + if(USE_NCCL) + find_package(NCCL) + if(NCCL_FOUND) + include_directories(${NCCL_INCLUDE_DIRS}) + list(APPEND mxnet_LINKER_LIBS ${NCCL_LIBRARIES}) + add_definitions(-DMXNET_USE_NCCL=1) + else() + add_definitions(-DMXNET_USE_NCCL=0) + message(WARNING "Could not find NCCL libraries") + endif() + endif() + if(UNIX) + find_package(NVTX) + if(NVTX_FOUND) + include_directories(${NVTX_INCLUDE_DIRS}) + list(APPEND mxnet_LINKER_LIBS ${NVTX_LIBRARIES}) + add_definitions(-DMXNET_USE_NVTX=1) + else() + message(WARNING "Could not find NVTX libraries") + endif() + endif() + + include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64) endif()