Skip to content

Commit

Permalink
Add ops-cpp target to torchvision (#3350)
Browse files Browse the repository at this point in the history
Summary:
This diff adds a new target to torchvision which enables users to use torchvision ops from C++.

For now, the `cpp_library` is not used by the `python_cpp_library`. We should instead refactor the logic in torchvision to directly use `cpp_library` instead.

There is currently an inconsistency between fbcode and OSS users. OSS users can import torchvision via
```
#include <torchvision/vision.h>
```
while fbcode users need to do
```
#include <torchvision/csrc/vision.h>
```
It would be good to fix this discrepancy in the future.

I didn't directly use `test_frcnn_tracing.cpp` due to complications for getting the `.pt` file in a way that works for both OSS and fbcode, so instead we added a self-contained test that should validate that the torchvision ops are properly registered and visible to JIT

Reviewed By: datumbox

Differential Revision: D26225669

fbshipit-source-id: 5dd9fb98dd58e854f95806e4860d02f54fc04ea4

Co-authored-by: Francisco Massa <fmassa@fb.com>
  • Loading branch information
NicolasHug and fmassa authored Feb 4, 2021
1 parent 88d18d6 commit aa26498
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/cpp/test_custom_operators.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2004-present Facebook. All Rights Reserved.

#include <gtest/gtest.h>
#include <torch/script.h>
#include <torch/torch.h>

// FIXME: the include path differs from OSS due to the extra csrc
#include <torchvision/csrc/ops/nms.h>

TEST(test_custom_operators, nms) {
// make sure that the torchvision ops are visible to the jit interpreter
auto& ops = torch::jit::getAllOperatorsFor(torch::jit::Symbol::fromQualString("torchvision::nms"));
ASSERT_EQ(ops.size(), 1);

auto& op = ops.front();
ASSERT_EQ(op->schema().name(), "torchvision::nms");

torch::jit::Stack stack;
at::Tensor boxes = at::rand({50, 4}), scores = at::rand({50});
double thresh = 0.7;

torch::jit::push(stack, boxes, scores, thresh);
op->getOperation()(&stack);
at::Tensor output_jit;
torch::jit::pop(stack, output_jit);

at::Tensor output = vision::ops::nms(boxes, scores, thresh);
ASSERT_TRUE(output_jit.allclose(output));

}

0 comments on commit aa26498

Please sign in to comment.