Skip to content

Commit

Permalink
Remove cpp extensions in favor of torch ops (#1348)
Browse files Browse the repository at this point in the history
* Remove C++ extensions in favor of custom ops

* Remove unused custom_ops.cpp file

* Rename _custom_ops.py

* Reorganize functions

* Minor improvements and fixes

* Fix lint

* Fully scriptable ops

* Import types used by annotations
  • Loading branch information
fmassa authored Sep 18, 2019
1 parent 0dd5588 commit f677ea3
Show file tree
Hide file tree
Showing 14 changed files with 230 additions and 239 deletions.
22 changes: 3 additions & 19 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def write_version_file():
with open(version_path, 'w') as f:
f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha)))
f.write("from torchvision import _C\n")
f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
f.write(" cuda = _C.CUDA_VERSION\n")
f.write("from torchvision.extension import _check_cuda_version\n")
f.write("if _check_cuda_version() > 0:\n")
f.write(" cuda = _check_cuda_version()\n")


write_version_file()
Expand Down Expand Up @@ -96,21 +96,12 @@ def get_extensions():
source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models

custom_ops_sources = [os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
os.path.join(extensions_dir, "cpu", "nms_cpu.cpp"),
os.path.join(extensions_dir, "cpu", "ROIAlign_cpu.cpp"),
os.path.join(extensions_dir, "cpu", "ROIPool_cpu.cpp")]
custom_ops_sources_cuda = [os.path.join(extensions_dir, "cuda", "nms_cuda.cu"),
os.path.join(extensions_dir, "cuda", "ROIAlign_cuda.cu"),
os.path.join(extensions_dir, "cuda", "ROIPool_cuda.cu")]

define_macros = []

extra_compile_args = {}
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
extension = CUDAExtension
sources += source_cuda
custom_ops_sources += custom_ops_sources_cuda
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
if nvcc_flags == '':
Expand Down Expand Up @@ -148,13 +139,6 @@ def get_extensions():
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
extension(
"torchvision._custom_ops",
sources=custom_ops_sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
]

return ext_modules
Expand Down
8 changes: 4 additions & 4 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def func(input):

@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
return ops.roi_pool(input, rois, 5, 1.0)[0]

assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'

Expand Down Expand Up @@ -282,7 +282,7 @@ def func(input):

@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
return ops.roi_pool(input, rois, 5, 1.0)[0]

assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'

Expand Down Expand Up @@ -442,7 +442,7 @@ def func(input):

@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
return ops.roi_align(input, rois, 5, 0.5, 1)[0]

assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'

Expand Down Expand Up @@ -482,7 +482,7 @@ def func(input):

@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
return ops.roi_align(input, rois, 5, 0.5, 1)[0]

assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'

Expand Down
2 changes: 2 additions & 0 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torchvision import utils
from torchvision import io

from .extension import _HAS_OPS

try:
from .version import __version__ # noqa: F401
except ImportError:
Expand Down
71 changes: 71 additions & 0 deletions torchvision/csrc/ROIAlign.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,74 @@ at::Tensor ROIAlign_backward(
width,
sampling_ratio);
}

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
auto result = ROIAlign_forward(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
return {result};
}

static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIAlign_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt());
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
}
};

Tensor roi_align(
const Tensor& input,
const Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
return ROIAlignFunction::apply(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio)[0];
}
64 changes: 63 additions & 1 deletion torchvision/csrc/ROIPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,66 @@ at::Tensor ROIPool_backward(
channels,
height,
width);
}
}

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
auto result = ROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}

static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIPool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};

std::tuple<Tensor, Tensor> roi_pool(
const Tensor& input,
const Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::tuple<Tensor, Tensor>(result[0], result[1]);
}
159 changes: 0 additions & 159 deletions torchvision/csrc/custom_ops/custom_ops.cpp

This file was deleted.

Loading

0 comments on commit f677ea3

Please sign in to comment.