Skip to content

Commit

Permalink
version check against PyTorch's CUDA version
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed May 22, 2019
1 parent 11da8e8 commit be37608
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def write_version_file():
with open(version_path, 'w') as f:
f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha)))
f.write("from torchvision import _C\n")
f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
f.write(" cuda = _C.CUDA_VERSION\n")


write_version_file()
Expand Down
28 changes: 28 additions & 0 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,31 @@ def get_image_backend():
Gets the name of the package used to load images
"""
return _image_backend


def _check_cuda_matches():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
import torch
from torchvision import _C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))


_check_cuda_matches()
7 changes: 7 additions & 0 deletions torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
#include "ROIPool.h"
#include "nms.h"

#ifdef WITH_CUDA
#include <cuda.h>
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
#ifdef WITH_CUDA
m.attr("CUDA_VERSION") = CUDA_VERSION;
#endif
}

0 comments on commit be37608

Please sign in to comment.