Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure torchvision operators are added in C++ #2798

Merged
merged 5 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ so make sure that it is also available to cmake via the ``CMAKE_PREFIX_PATH``.

For an example setup, take a look at ``examples/cpp/hello_world``.

TorchVision Operators
---------------------
In order to get the torchvision operators registered with torch (eg. for the JIT), all you need to do is to ensure that you
:code:`#include <torchvision/vision.h>` in your project.

Documentation
=============
You can find the API documentation on the pytorch website: https://pytorch.org/docs/stable/torchvision/index.html
Expand Down
11 changes: 1 addition & 10 deletions torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
#pragma once
#include <torch/extension.h>

#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
#include "../macros.h"

VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
Expand Down
11 changes: 1 addition & 10 deletions torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
#pragma once
#include <torch/extension.h>

#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif
#include "../macros.h"

VISION_API at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
Expand Down
24 changes: 24 additions & 0 deletions torchvision/csrc/macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef TORCHVISION_MACROS_H
#define TORCHVISION_MACROS_H

#ifdef _WIN32
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport)
#else
#define VISION_API __declspec(dllimport)
#endif
#else
#define VISION_API
#endif

#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define VISION_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define VISION_INLINE_VARIABLE __declspec(selectany)
#else
#define VISION_INLINE_VARIABLE __attribute__((weak))
#endif
#endif

#endif // TORCHVISION_MACROS_H
6 changes: 4 additions & 2 deletions torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ PyMODINIT_FUNC PyInit__C(void) {
#endif
#endif

int64_t _cuda_version() {
namespace vision {
int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
return CUDA_VERSION;
#else
return -1;
#endif
}
} // namespace vision

TORCH_LIBRARY(torchvision, m) {
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
Expand All @@ -53,7 +55,7 @@ TORCH_LIBRARY(torchvision, m) {
m.def("ps_roi_align", &ps_roi_align);
m.def("ps_roi_pool", &ps_roi_pool);
m.def("deform_conv2d", &deform_conv2d);
m.def("_cuda_version", &_cuda_version);
m.def("_cuda_version", &vision::cuda_version);
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
Expand Down
13 changes: 13 additions & 0 deletions torchvision/csrc/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,18 @@
#define VISION_H

#include <torchvision/models/models.h>
#include <cstdint>
#include "macros.h"

namespace vision {
VISION_API int64_t cuda_version() noexcept;

namespace detail {
// Dummy variable to reference a symbol from vision.cpp.
// This ensures that the torchvision library and the ops registration
// initializers are not pruned.
VISION_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will be better to store a reference to the function pointer, rather than doing a full static initializer (which will bang on cuda version even though there's no reason to do so.)

} // namespace detail
} // namespace vision

#endif // VISION_H