Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix CUDNN detection for CMake build
Browse files Browse the repository at this point in the history
- Use FindCUDNN.cmake instead of the previous macro
- Previous macro did not detect CUDNN correctly on my system (Deep Learning AMI
  Ubuntu 18.04)
- We now explicitly fail if the user does not provide CUDNN and hasn't manually
  set -DUSE_CUDNN=0
  • Loading branch information
leezu committed Dec 9, 2019
1 parent 3b8fdac commit 5df2368
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
14 changes: 6 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,12 @@ add_subdirectory(${GTEST_ROOT})
find_package(GTest REQUIRED)

# cudnn detection
if(USE_CUDNN AND USE_CUDA)
detect_cuDNN()
if(HAVE_CUDNN)
add_definitions(-DUSE_CUDNN)
include_directories(SYSTEM ${CUDNN_INCLUDE})
list(APPEND mxnet_LINKER_LIBS ${CUDNN_LIBRARY})
add_definitions(-DMSHADOW_USE_CUDNN=1)
endif()
if(USE_CUDNN)
find_package(CUDNN)
add_definitions(-DUSE_CUDNN)
include_directories(SYSTEM ${CUDNN_INCLUDE})
list(APPEND mxnet_LINKER_LIBS ${CUDNN_LIBRARY})
add_definitions(-DMSHADOW_USE_CUDNN=1)
endif()

if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dmlc-core/cmake)
Expand Down
28 changes: 0 additions & 28 deletions cmake/FirstClassLangCuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,6 @@ if(USE_CXX14_IF_AVAILABLE)
check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
endif()

################################################################################################
# Short command for cuDNN detection. Believe it soon will be a part of CUDA toolkit distribution.
# That's why not FindcuDNN.cmake file, but just the macro
# Usage:
# detect_cuDNN()
function(detect_cuDNN)
set(CUDNN_ROOT "" CACHE PATH "CUDNN root folder")

find_path(CUDNN_INCLUDE cudnn.h
PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT}
DOC "Path to cuDNN include directory." )


find_library(CUDNN_LIBRARY NAMES libcudnn.so cudnn.lib # libcudnn_static.a
PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT} ${CUDNN_INCLUDE}
PATH_SUFFIXES lib lib/x64
DOC "Path to cuDNN library.")

if(CUDNN_INCLUDE AND CUDNN_LIBRARY)
set(HAVE_CUDNN TRUE PARENT_SCOPE)
set(CUDNN_FOUND TRUE PARENT_SCOPE)

mark_as_advanced(CUDNN_INCLUDE CUDNN_LIBRARY CUDNN_ROOT)
message(STATUS "Found cuDNN (include: ${CUDNN_INCLUDE}, library: ${CUDNN_LIBRARY})")
endif()
endfunction()



################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
Expand Down
33 changes: 33 additions & 0 deletions cmake/Modules/FindCUDNN.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

include(FindPackageHandleStandardArgs)

set(CUDNN_ROOT "/usr/local/cuda/include" CACHE PATH "cuDNN root folder")

find_path(CUDNN_INCLUDE cudnn.h
PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT}
DOC "Path to cuDNN include directory." )

find_library(CUDNN_LIBRARY NAMES libcudnn.so cudnn.lib # libcudnn_static.a
PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT} ${CUDNN_INCLUDE}
PATH_SUFFIXES lib lib/x64 cuda/lib cuda/lib64 lib/x64
DOC "Path to cuDNN library.")

find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY CUDNN_INCLUDE)

mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE CUDNN_LIBRARY)

0 comments on commit 5df2368

Please sign in to comment.