test_frcnn_tracing.cpp

#include <ATen/ATen.h>
#include <iostream>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/ROIAlign.h>
#include <torchvision/cpu/vision_cpu.h>
#include <torchvision/nms.h>

#ifdef _WIN32
// Windows only
// This is necessary until operators are automatically registered on include
static auto _nms = &nms_cpu;
#endif
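
// The model file loaded below is assumed to be a TorchScript export of
// torchvision.models.detection.fasterrcnn_resnet50_fpn, e.g. produced in
// Python with torch.jit.script(model).save("fasterrcnn_resnet50_fpn.pt").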
int main() {
  torch::DeviceType device_type;
  device_type = torch::kCPU;

  torch::jit::script::Module module;
  try {
    std::cout << "Loading model\n";
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load("fasterrcnn_resnet50_fpn.pt");
    std::cout << "Model loaded\n";
  } catch (const torch::Error& e) {
    std::cout << "error loading the model\n";
    return -1;
  } catch (const std::exception& e) {
    std::cout << "Other error: " << e.what() << "\n";
    return -1;
  }

  // TorchScript models require a List[IValue] as input
  std::vector<torch::jit::IValue> inputs;

  // Faster RCNN accepts a List[Tensor] as main input
  std::vector<torch::Tensor> images;
  images.push_back(torch::rand({3, 256, 275}));
  images.push_back(torch::rand({3, 256, 275}));

  inputs.push_back(images);
  auto output = module.forward(inputs);

  std::cout << "ok\n";
  std::cout << "output: " << output << "\n";

  if (torch::cuda::is_available()) {
    // Move traced model to GPU
    module.to(torch::kCUDA);

    // Add GPU inputs
    images.clear();
    inputs.clear();

    torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
    images.push_back(torch::rand({3, 256, 275}, options));
    images.push_back(torch::rand({3, 256, 275}, options));

    inputs.push_back(images);
    auto output = module.forward(inputs);

    std::cout << "ok\n";
    std::cout << "output: " << output << "\n";
  }
}