From f8b3c7d60c2af941e4a60b9ca314d218691ee127 Mon Sep 17 00:00:00 2001 From: bruno Date: Thu, 15 Oct 2020 11:12:40 +0200 Subject: [PATCH] Expose cuda_version in the API, use it to avoid pruning of ops initializer --- torchvision/csrc/vision.cpp | 15 ++++----------- torchvision/csrc/vision.h | 10 +++++++--- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 9bac5d20526..75e65d67661 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -34,13 +34,15 @@ PyMODINIT_FUNC PyInit__C(void) { #endif #endif -int64_t _cuda_version() { +namespace vision { +int64_t cuda_version() noexcept { #ifdef WITH_CUDA return CUDA_VERSION; #else return -1; #endif } +} // namespace vision TORCH_LIBRARY(torchvision, m) { m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); @@ -53,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) { m.def("ps_roi_align", &ps_roi_align); m.def("ps_roi_pool", &ps_roi_pool); m.def("deform_conv2d", &deform_conv2d); - m.def("_cuda_version", &_cuda_version); + m.def("_cuda_version", &vision::cuda_version); } TORCH_LIBRARY_IMPL(torchvision, CPU, m) { @@ -83,12 +85,3 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("roi_align", ROIAlign_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd); } - -namespace vision { -// This function is needed to reference the static variable created by the -// TORCH_LIBRARY macro so that it is not optimized away. -int RegisterOps() noexcept { - (void)TORCH_LIBRARY_static_init_torchvision; - return 0; -} -} // namespace vision diff --git a/torchvision/csrc/vision.h b/torchvision/csrc/vision.h index 9362cd069e4..50bebab1fb1 100644 --- a/torchvision/csrc/vision.h +++ b/torchvision/csrc/vision.h @@ -2,14 +2,18 @@ #define VISION_H #include +#include #include "macros.h" namespace vision { -VISION_API int RegisterOps() noexcept; +VISION_API int64_t cuda_version() noexcept; namespace detail { -VISION_INLINE_VARIABLE int dummy = RegisterOps(); -} +// Dummy variable to reference a symbol from vision.cpp. +// This ensures that the torchvision library and the ops registration +// initializers are not pruned. +VISION_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); +} // namespace detail } // namespace vision #endif // VISION_H