-
Notifications
You must be signed in to change notification settings - Fork 1
/
backend.h
37 lines (32 loc) · 1.2 KB
/
backend.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <map>
#include <string>
#include <vector>
#include "mlperf_bench.h"
#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
#include "status.h"
namespace mlperf_bench {
class Backend {
public:
Backend();
Status LoadModel(std::string path, std::vector<std::string> outputs);
std::vector<Ort::Value> Run(Ort::Value* inputs, size_t input_count);
template <typename T>
Ort::Value GetTensor(std::vector<int64_t>& shapes, std::vector<T>& data) {
return Ort::Value::CreateTensor<T>(
allocator_info_.GetInfo(), data.data(), data.size(), shapes.data(), shapes.size());
};
ONNXTensorElementDataType GetInputType(size_t idx) { return input_type_[idx]; };
Ort::SessionOptions& GetOpt() { return opt_; };
private:
Ort::Session* session_;
Ort::SessionOptions opt_;
std::vector<std::vector<int64_t>> output_shapes_;
std::vector<char*> output_names_;
std::vector<char*> input_names_;
std::vector<ONNXTensorElementDataType> input_type_;
const Ort::RunOptions run_options_; //(nullptr);
Ort::AllocatorWithDefaultOptions allocator_info_;
OrtAllocator* allocator_;
Ort::Env env_{ ORT_LOGGING_LEVEL_WARNING, "mlperf_bench" };
};
}