Skip to content

Commit

Permalink
Expose cuda_version in the API, use it to avoid pruning of ops initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
bmanga committed Oct 15, 2020
1 parent 2c3b9d0 commit f8b3c7d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
15 changes: 4 additions & 11 deletions torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ PyMODINIT_FUNC PyInit__C(void) {
#endif
#endif

int64_t _cuda_version() {
namespace vision {
// Reports the CUDA version this translation unit was compiled against.
//
// Returns the value of the CUDA_VERSION macro when the build defines
// WITH_CUDA, and -1 when it does not.
int64_t cuda_version() noexcept {
#ifndef WITH_CUDA
  // Built without CUDA support: -1 is the "no CUDA" sentinel.
  return -1;
#else
  return CUDA_VERSION;
#endif
}
} // namespace vision

TORCH_LIBRARY(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
Expand All @@ -53,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) {
m.def("ps_roi_align", &ps_roi_align);
m.def("ps_roi_pool", &ps_roi_pool);
m.def("deform_conv2d", &deform_conv2d);
m.def("_cuda_version", &_cuda_version);
m.def("_cuda_version", &vision::cuda_version);
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
Expand Down Expand Up @@ -83,12 +85,3 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", ROIAlign_autograd);
m.impl("_roi_align_backward", ROIAlign_backward_autograd);
}

namespace vision {
// Exists solely so that other code can reference a symbol that, in turn,
// names the static variable created by the TORCH_LIBRARY macro; mentioning
// that variable here keeps it from being optimized away.
//
// Always returns 0; the value itself carries no meaning.
int RegisterOps() noexcept {
  // Discarded-value reference: the expression is evaluated only for the
  // purpose of naming the static initializer object.
  static_cast<void>(TORCH_LIBRARY_static_init_torchvision);
  return 0;
}
} // namespace vision
10 changes: 7 additions & 3 deletions torchvision/csrc/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
#define VISION_H

#include <torchvision/models/models.h>
#include <cstdint>
#include "macros.h"

namespace vision {
VISION_API int RegisterOps() noexcept;
// Returns CUDA_VERSION when the library is built with WITH_CUDA defined,
// and -1 otherwise (see the definition in vision.cpp).
VISION_API int64_t cuda_version() noexcept;

namespace detail {
VISION_INLINE_VARIABLE int dummy = RegisterOps();
}
// Dummy variable to reference a symbol from vision.cpp.
// This ensures that the torchvision library and the ops registration
// initializers are not pruned.
VISION_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace vision

#endif // VISION_H

0 comments on commit f8b3c7d

Please sign in to comment.