smart ptr
laitassou committed Apr 22, 2020
1 parent 4195bef commit 45c3abd
Showing 4 changed files with 64 additions and 91 deletions.
44 changes: 29 additions & 15 deletions service/include/Model.h
@@ -1,6 +1,8 @@
//
// Model
//
/*
* class Model.h
* user: laitassou
* description:
*/

#pragma once

@@ -17,11 +19,29 @@

class Tensor;

//using GraphResourcePtr = std::unique_ptr<TF_Graph> ;


class Model {
public:
struct GraphCreate {
TF_Graph * operator()() { return TF_NewGraph();}
};
struct GraphDeleter {
void operator()(TF_Graph* b) { TF_DeleteGraph(b);}
};

struct SessionDeleter {
void operator()(TF_Session* sess, TF_Status * status ) { TF_DeleteSession(sess,status);}

};

struct StatusDeleter {
void operator()(TF_Status* status) { TF_DeleteStatus(status);}
};


using unique_graph_ptr = std::unique_ptr<TF_Graph, GraphDeleter>;
using unique_session_ptr = std::unique_ptr<TF_Session, SessionDeleter>;
using unique_status_ptr = std::unique_ptr<TF_Status, StatusDeleter >;
explicit Model(const std::string&);

// Rule of five: moving is straightforward since the pointers can be transferred; copying is not, as there is no obvious way to duplicate the underlying TF resources
@@ -51,17 +71,11 @@ class Model {
void run(const std::vector<Tensor*>& inputs, Tensor* output);
void run(Tensor* input, Tensor* output);

struct GraphCreate {
TF_Graph * operator()() { return TF_NewGraph(); }
};
struct GraphDeleter {
void operator()(TF_Graph* b) { TF_DeleteGraph(b); }
};

private:
std::unique_ptr<TF_Graph, GraphDeleter> graph;
TF_Session* session;
TF_Status* status;
unique_graph_ptr _graph;
//unique_session_ptr _session;
TF_Session * _session;
unique_status_ptr _status;

// Read a file from a string
static TF_Buffer* read(const std::string&);
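The new Model.h owns TF_Graph and TF_Status through std::unique_ptr with custom deleters, while TF_Session stays a raw pointer because TF_DeleteSession also needs a TF_Status (the two-argument SessionDeleter above does not match the single-argument call std::unique_ptr makes). A minimal sketch, assuming a stateful deleter that keeps a non-owning pointer to the status object, of how the session could be held in a unique_ptr as well:

// Sketch only, not part of this commit: a stateful deleter for TF_Session.
#include <memory>
#include <tensorflow/c/c_api.h>  // TF C API header; install path may differ

struct StatusDeleter {
    void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using unique_status_ptr = std::unique_ptr<TF_Status, StatusDeleter>;

struct SessionDeleter {
    TF_Status* status;  // not owned; must outlive the session
    void operator()(TF_Session* sess) const { TF_DeleteSession(sess, status); }
};
using unique_session_ptr = std::unique_ptr<TF_Session, SessionDeleter>;

// Usage sketch:
//   unique_status_ptr status{TF_NewStatus()};
//   TF_Graph* graph = TF_NewGraph();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   unique_session_ptr session{TF_NewSession(graph, opts, status.get()),
//                              SessionDeleter{status.get()}};
//   TF_DeleteSessionOptions(opts);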
6 changes: 6 additions & 0 deletions service/include/Tensor.h
@@ -1,3 +1,9 @@
/*
* class Tensor.h
* user: laitassou
* description:
*/

#pragma once

#include <vector>
47 changes: 24 additions & 23 deletions service/src/Model.cpp
@@ -1,16 +1,17 @@
#include "../include/Model.h"

Model::Model(const std::string& model_filename):graph(TF_NewGraph()){

this->status = TF_NewStatus();
//this->graph = std::make_unique<TF_Graph>(GraphCreate(), GraphDeleter); //TF_NewGraph();
Model::Model(const std::string& model_filename):_graph(TF_NewGraph()),_status( TF_NewStatus()){

// Create the session.
TF_SessionOptions* sess_opts = TF_NewSessionOptions();

this->session = TF_NewSession(graph.get(), sess_opts, this->status);
_session = TF_NewSession(_graph.get(), sess_opts, _status.get());

//_session.reset ( TF_NewSession(_graph.get(), sess_opts, _status.get()));
TF_DeleteSessionOptions(sess_opts);



// Check the status
this->status_check(true);

@@ -23,7 +24,7 @@ Model::Model(const std::string& model_filename):graph(TF_NewGraph()){
this->error_check(def != nullptr, "An error occurred reading the model");

TF_ImportGraphDefOptions* graph_opts = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(graph.get(), def, graph_opts, this->status);
TF_GraphImportGraphDef(_graph.get(), def, graph_opts, _status.get());
TF_DeleteImportGraphDefOptions(graph_opts);
TF_DeleteBuffer(def);

@@ -32,19 +33,19 @@ Model::Model(const std::string& model_filename):graph(TF_NewGraph()){
}

Model::~Model() {
TF_DeleteSession(this->session, this->status);
//TF_DeleteSession(_session, _status);
///TF_DeleteGraph(this->graph);
this->status_check(true);
TF_DeleteStatus(this->status);
//TF_DeleteStatus(_status);
}


void Model::init() {
TF_Operation* init_op[1] = {TF_GraphOperationByName(this->graph.get(), "init")};
TF_Operation* init_op[1] = {TF_GraphOperationByName(_graph.get(), "init")};

this->error_check(init_op[0]!= nullptr, "Error: No operation named \"init\" exists");

TF_SessionRun(this->session, nullptr, nullptr, nullptr, 0, nullptr, nullptr, 0, init_op, 1, nullptr, this->status);
TF_SessionRun(_session, nullptr, nullptr, nullptr, 0, nullptr, nullptr, 0, init_op, 1, nullptr, _status.get());
this->status_check(true);
}

@@ -54,10 +55,10 @@ void Model::save(const std::string &ckpt) {
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
char* data = static_cast<char *>(TF_TensorData(t));
for (int i=0; i<8; i++) {data[i]=0;}
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, _status.get());

memset(data, 0, 8); // 8-byte offset of first string.
TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 8), size - 8, status);
TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 8), size - 8, _status.get());

// Check errors
if (!this->status_check(false)) {
@@ -67,19 +68,19 @@ }
}

TF_Output output_file;
output_file.oper = TF_GraphOperationByName(this->graph.get(), "save/Const");
output_file.oper = TF_GraphOperationByName(_graph.get(), "save/Const");
output_file.index = 0;
TF_Output inputs[1] = {output_file};

TF_Tensor* input_values[1] = {t};
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph.get(), "save/control_dependency")};
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(_graph.get(), "save/control_dependency")};
if (!restore_op[0]) {
TF_DeleteTensor(t);
this->error_check(false, "Error: No operation named \"save/control_dependency\" exists");
}


TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, this->status);
TF_SessionRun(_session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, _status.get());
TF_DeleteTensor(t);

this->status_check(true);
@@ -92,7 +93,7 @@ void Model::restore(const std::string& ckpt) {
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
char* data = static_cast<char *>(TF_TensorData(t));
for (int i=0; i<8; i++) {data[i]=0;}
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, _status.get());

// Check errors
if (!this->status_check(false)) {
@@ -102,20 +103,20 @@ }
}

TF_Output output_file;
output_file.oper = TF_GraphOperationByName(this->graph.get(), "save/Const");
output_file.oper = TF_GraphOperationByName(_graph.get(), "save/Const");
output_file.index = 0;
TF_Output inputs[1] = {output_file};

TF_Tensor* input_values[1] = {t};
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(this->graph.get(), "save/restore_all")};
const TF_Operation* restore_op[1] = {TF_GraphOperationByName(_graph.get(), "save/restore_all")};
if (!restore_op[0]) {
TF_DeleteTensor(t);
this->error_check(false, "Error: No operation named \"save/restore_all\" exists");
}



TF_SessionRun(this->session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, this->status);
TF_SessionRun(_session, nullptr, inputs, input_values, 1, nullptr, nullptr, 0, restore_op, 1, nullptr, _status.get());
TF_DeleteTensor(t);

this->status_check(true);
@@ -164,7 +165,7 @@ std::vector<std::string> Model::get_operations() const {
TF_Operation* oper;

// Iterate through the operations of a graph
while ((oper = TF_GraphNextOperation(this->graph.get(), &pos)) != nullptr) {
while ((oper = TF_GraphNextOperation(_graph.get(), &pos)) != nullptr) {
result.emplace_back(TF_OperationName(oper));
}

@@ -198,7 +199,7 @@ void Model::run(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>&
// Prepare output recipients
auto ov = new TF_Tensor*[outputs.size()];

TF_SessionRun(this->session, nullptr, io.data(), iv.data(), inputs.size(), oo.data(), ov, outputs.size(), nullptr, 0, nullptr, this->status);
TF_SessionRun(_session, nullptr, io.data(), iv.data(), inputs.size(), oo.data(), ov, outputs.size(), nullptr, 0, nullptr, _status.get());
this->status_check(true);

// Save results on outputs and mark as full
@@ -240,9 +241,9 @@ void Model::run(Tensor *input, const std::vector<Tensor*> &outputs) {

bool Model::status_check(bool throw_exc) const {

if (TF_GetCode(this->status) != TF_OK) {
if (TF_GetCode(_status.get()) != TF_OK) {
if (throw_exc) {
throw std::runtime_error(TF_Message(status));
throw std::runtime_error(TF_Message(_status.get()));
} else {
return false;
}
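With _graph and _status now initialized as unique_ptrs in the constructor, the public API is unchanged. A minimal usage sketch, assuming a frozen graph file named model.pb and operations named "input" and "output" (hypothetical names, not taken from this repository):

// Usage sketch; file name and operation names are assumptions.
#include <vector>
#include "Model.h"
#include "Tensor.h"

int main() {
    Model model("model.pb");        // imports the graph and creates the session
    Tensor input(model, "input");   // wraps a graph operation by name
    Tensor output(model, "output");

    input.set_data(std::vector<float>{1.0f, 2.0f, 3.0f});
    model.run(&input, &output);     // single-input / single-output overload
    std::vector<float> result = output.get_data<float>();
    return result.empty() ? 1 : 0;
}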
58 changes: 5 additions & 53 deletions service/src/Tensor.cpp
@@ -5,7 +5,7 @@
Tensor::Tensor(const Model& model, const std::string& operation) {

// Get operation by the name
this->op.oper = TF_GraphOperationByName(model.graph.get(), operation.c_str());
this->op.oper = TF_GraphOperationByName(model._graph.get(), operation.c_str());
this->op.index = 0;

// Operation did not exists
@@ -14,7 +14,7 @@ Tensor::Tensor(const Model& model, const std::string& operation) {
// DIMENSIONS

// Get number of dimensions
int n_dims = TF_GraphGetTensorNumDims(model.graph.get(), this->op, model.status);
int n_dims = TF_GraphGetTensorNumDims(model._graph.get(), this->op, model._status.get());

// DataType
this->type = TF_OperationOutputType(this->op);
@@ -23,7 +23,7 @@ Tensor::Tensor(const Model& model, const std::string& operation) {
if (n_dims > 0) {
// Get dimensions
auto *dims = new int64_t[n_dims];
TF_GraphGetTensorShape(model.graph.get(), this->op, dims, n_dims, model.status);
TF_GraphGetTensorShape(model._graph.get(), this->op, dims, n_dims, model._status.get());

// Check error on Model Status
model.status_check(true);
@@ -153,24 +153,6 @@ TF_DataType Tensor::deduce_type() {
return TF_FLOAT;
if (std::is_same<T, double>::value)
return TF_DOUBLE;
if (std::is_same<T, int32_t >::value)
return TF_INT32;
if (std::is_same<T, uint8_t>::value)
return TF_UINT8;
if (std::is_same<T, int16_t>::value)
return TF_INT16;
if (std::is_same<T, int8_t>::value)
return TF_INT8;
if (std::is_same<T, int64_t>::value)
return TF_INT64;
// if constexpr (std::is_same<T, bool>::value)
// return TF_BOOL;
if (std::is_same<T, uint16_t>::value)
return TF_UINT16;
if (std::is_same<T, uint32_t>::value)
return TF_UINT32;
if (std::is_same<T, uint64_t>::value)
return TF_UINT64;

throw std::runtime_error{"Could not deduce type!"};
}
@@ -194,50 +176,20 @@ void Tensor::deduce_shape() {
template TF_DataType Tensor::deduce_type<float>();
template TF_DataType Tensor::deduce_type<double>();
//template TF_DataType Tensor::deduce_type<bool>();
template TF_DataType Tensor::deduce_type<int8_t>();
template TF_DataType Tensor::deduce_type<int16_t>();
template TF_DataType Tensor::deduce_type<int32_t>();
template TF_DataType Tensor::deduce_type<int64_t>();
template TF_DataType Tensor::deduce_type<uint8_t>();
template TF_DataType Tensor::deduce_type<uint16_t>();
template TF_DataType Tensor::deduce_type<uint32_t>();
template TF_DataType Tensor::deduce_type<uint64_t>();

// VALID get_data TEMPLATES
template std::vector<float> Tensor::get_data<float>();
template std::vector<double> Tensor::get_data<double>();
template std::vector<bool> Tensor::get_data<bool>();
template std::vector<int8_t> Tensor::get_data<int8_t>();
template std::vector<int16_t> Tensor::get_data<int16_t>();
template std::vector<int32_t> Tensor::get_data<int32_t>();
template std::vector<int64_t> Tensor::get_data<int64_t>();
template std::vector<uint8_t> Tensor::get_data<uint8_t>();
template std::vector<uint16_t> Tensor::get_data<uint16_t>();
template std::vector<uint32_t> Tensor::get_data<uint32_t>();
template std::vector<uint64_t> Tensor::get_data<uint64_t>();


// VALID set_data TEMPLATES
template void Tensor::set_data<float>(std::vector<float> new_data);
template void Tensor::set_data<double>(std::vector<double> new_data);
//template void Tensor::set_data<bool>(std::vector<bool> new_data);
template void Tensor::set_data<int8_t>(std::vector<int8_t> new_data);
template void Tensor::set_data<int16_t>(std::vector<int16_t> new_data);
template void Tensor::set_data<int32_t>(std::vector<int32_t> new_data);
template void Tensor::set_data<int64_t>(std::vector<int64_t> new_data);
template void Tensor::set_data<uint8_t>(std::vector<uint8_t> new_data);
template void Tensor::set_data<uint16_t>(std::vector<uint16_t> new_data);
template void Tensor::set_data<uint32_t>(std::vector<uint32_t> new_data);
template void Tensor::set_data<uint64_t>(std::vector<uint64_t> new_data);


// VALID set_data TEMPLATES
template void Tensor::set_data<float>(std::vector<float> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<double>(std::vector<double> new_data, const std::vector<int64_t>& new_shape);
//template void Tensor::set_data<bool>(std::vector<bool> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<int8_t>(std::vector<int8_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<int16_t>(std::vector<int16_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<int32_t>(std::vector<int32_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<int64_t>(std::vector<int64_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<uint8_t>(std::vector<uint8_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<uint16_t>(std::vector<uint16_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<uint32_t>(std::vector<uint32_t> new_data, const std::vector<int64_t>& new_shape);
template void Tensor::set_data<uint64_t>(std::vector<uint64_t> new_data, const std::vector<int64_t>& new_shape);
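The explicit instantiations are trimmed down to float and double in this commit, and the commented-out bool branch in deduce_type points at C++17 if constexpr. A minimal sketch, assuming C++17, of how the type mapping could be expressed that way (an illustration, not the repository's implementation):

// Sketch only: if-constexpr variant of the type mapping, assuming C++17.
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <tensorflow/c/c_api.h>

template <typename T>
TF_DataType deduce_tf_type() {
    if constexpr (std::is_same_v<T, float>)        return TF_FLOAT;
    else if constexpr (std::is_same_v<T, double>)  return TF_DOUBLE;
    else if constexpr (std::is_same_v<T, int32_t>) return TF_INT32;
    else if constexpr (std::is_same_v<T, int64_t>) return TF_INT64;
    else if constexpr (std::is_same_v<T, bool>)    return TF_BOOL;
    else throw std::runtime_error{"Could not deduce type!"};
}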
