-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
806 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Copyright (c) 2020 Gemfield <gemfield@civilnet.cn> | ||
* This file is part of libdeepvac, licensed under the GPLv3 (the "License") | ||
* You may not use this file except in compliance with the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <memory> | ||
#include <vector> | ||
#include <string> | ||
#include "gemfield.h" | ||
#include "syszux_tensorrt_buffers.h" | ||
#include "NvInfer.h" | ||
|
||
class TrtLogger : public nvinfer1::ILogger { | ||
void log(Severity severity, const char* msg) override{ | ||
// suppress info-level messages | ||
if (severity != Severity::kVERBOSE){ | ||
std::cout << msg << std::endl; | ||
} | ||
} | ||
}; | ||
|
||
struct InferDeleter{ | ||
template <typename T> | ||
void operator()(T* obj) const{ | ||
if (obj){ | ||
obj->destroy(); | ||
} | ||
} | ||
}; | ||
|
||
namespace deepvac { | ||
class DeepvacNV{ | ||
template <typename T> | ||
using SampleUniquePtr = std::unique_ptr<T, InferDeleter>; | ||
|
||
public: | ||
DeepvacNV() = default; | ||
explicit DeepvacNV(const char* model_path, std::string device); | ||
explicit DeepvacNV(std::string model_path, std::string device):DeepvacNV(model_path.c_str(), device){} | ||
explicit DeepvacNV(std::vector<unsigned char>&& buffer, std::string device); | ||
DeepvacNV(const DeepvacNV& rhs) = delete; | ||
DeepvacNV& operator=(const DeepvacNV& rhs) = delete; | ||
DeepvacNV(DeepvacNV&&) = default; | ||
DeepvacNV& operator=(DeepvacNV&&) = default; | ||
virtual ~DeepvacNV() = default; | ||
|
||
public: | ||
void setDevice(std::string device){device_ = device;} | ||
void setModel(const char* model_path); | ||
virtual void setBinding(int io_num); | ||
void** forward(void** data) { | ||
bool s = trt_context_->executeV2(data); | ||
return data; | ||
} | ||
|
||
protected: | ||
//all data members must be movable !! | ||
//all data members need dynamic memory must be managed by smart ptr !! | ||
SampleUniquePtr<nvinfer1::ICudaEngine> trt_module_; | ||
SampleUniquePtr<nvinfer1::IExecutionContext> trt_context_; | ||
template <typename T> | ||
SampleUniquePtr<T> makeUnique(T* t){ | ||
return SampleUniquePtr<T>{t}; | ||
} | ||
TrtLogger logger_; | ||
std::vector<gemfield_org::ManagedBuffer> datas_; | ||
std::string device_; | ||
}; | ||
|
||
}// namespace deepvac |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Copyright (c) 2020 Gemfield <gemfield@civilnet.cn> | ||
* This file is part of libdeepvac, licensed under the GPLv3 (the "License") | ||
* You may not use this file except in compliance with the License. | ||
*/ | ||
#include <type_traits> | ||
#include <chrono> | ||
#include <ctime> | ||
#include <iostream> | ||
#include <cassert> | ||
#include "deepvac_nv.h" | ||
|
||
namespace deepvac { | ||
DeepvacNV::DeepvacNV(const char* path, std::string device){ | ||
GEMFIELD_SI; | ||
auto start = std::chrono::system_clock::now(); | ||
try{ | ||
device_ = device; | ||
setModel(path); | ||
}catch(...){ | ||
std::string msg = "Internal ERROR!"; | ||
GEMFIELD_E(msg.c_str()); | ||
throw std::runtime_error(msg); | ||
} | ||
std::chrono::duration<double> model_loading_duration = std::chrono::system_clock::now() - start; | ||
std::string msg = gemfield_org::format("NV Model loading time: %f", model_loading_duration.count()); | ||
GEMFIELD_DI(msg.c_str()); | ||
} | ||
|
||
DeepvacNV::DeepvacNV(std::vector<unsigned char>&& buffer, std::string device){ | ||
GEMFIELD_SI; | ||
auto start = std::chrono::system_clock::now(); | ||
try{ | ||
device_ = device; | ||
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_); | ||
trt_module_ = makeUnique(runtime->deserializeCudaEngine((void*)buffer.data(), buffer.size(), nullptr)); | ||
assert(trt_module_ != nullptr); | ||
trt_context_ = makeUnique(trt_module_->createExecutionContext()); | ||
assert(trt_context_ != nullptr); | ||
runtime->destroy(); | ||
runtime = nullptr; | ||
}catch(...){ | ||
std::string msg = "Internal ERROR!"; | ||
GEMFIELD_E(msg.c_str()); | ||
throw std::runtime_error(msg); | ||
} | ||
std::chrono::duration<double> model_loading_duration = std::chrono::system_clock::now() - start; | ||
std::string msg = gemfield_org::format("NV Model loading time: %f", model_loading_duration.count()); | ||
GEMFIELD_DI(msg.c_str()); | ||
} | ||
|
||
void DeepvacNV::setBinding(int io_num) { | ||
for(int i = 0; i < io_num; ++i) { | ||
gemfield_org::ManagedBuffer buffer{}; | ||
datas_.emplace_back(std::move(buffer)); | ||
} | ||
} | ||
|
||
void DeepvacNV::setModel(const char* model_path) { | ||
std::ifstream in(model_path, std::ifstream::binary); | ||
if(in.is_open()) { | ||
auto const start_pos = in.tellg(); | ||
in.ignore(std::numeric_limits<std::streamsize>::max()); | ||
size_t bufCount = in.gcount(); | ||
in.seekg(start_pos); | ||
std::unique_ptr<char[]> engineBuf(new char[bufCount]); | ||
in.read(engineBuf.get(), bufCount); | ||
//initLibNvInferPlugins(&logger_, ""); | ||
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_); | ||
trt_module_ = makeUnique(runtime->deserializeCudaEngine((void*)engineBuf.get(), bufCount, nullptr)); | ||
assert(trt_module_ != nullptr); | ||
trt_context_ = makeUnique(trt_module_->createExecutionContext()); | ||
assert(trt_context_ != nullptr); | ||
//mBatchSize = trt_module_->getMaxBatchSize(); | ||
//spdlog::info("max batch size of deserialized engine: {}",mEngine->getMaxBatchSize()); | ||
runtime->destroy(); | ||
runtime = nullptr; | ||
} | ||
} | ||
} //namespace deepvac |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) 2020 Gemfield <gemfield@civilnet.cn> | ||
* This file is part of libdeepvac, licensed under the GPLv3 (the "License") | ||
* You may not use this file except in compliance with the License. | ||
*/ | ||
|
||
#include <chrono> | ||
#include "syszux_face_retina_nv.h" | ||
#include "syszux_face_reg_nv.h" | ||
#include "syszux_img2tensor.h" | ||
|
||
using namespace deepvac; | ||
int main(int argc, const char* argv[]) { | ||
if (argc != 3) { | ||
GEMFIELD_E("usage: deepvac <device> <img_path>"); | ||
return -1; | ||
} | ||
|
||
std::string device = argv[1]; | ||
std::string img_path = argv[2]; | ||
SyszuxFaceRetinaNV face_detect("detect.trt", device); | ||
SyszuxFaceRegNV face_reg("reg.trt", device); | ||
|
||
auto start = std::chrono::system_clock::now(); | ||
|
||
for(int i = 0; i < 155; ++i) { | ||
auto start1 = std::chrono::system_clock::now(); | ||
std::string img_name = img_path + std::to_string(i*10) + ".jpg"; | ||
auto mat_opt = gemfield_org::img2CvMat(img_name); | ||
|
||
if(!mat_opt){ | ||
throw std::runtime_error("illegal image detected"); | ||
return 1; | ||
} | ||
|
||
auto mat_out = mat_opt.value(); | ||
auto detect_out_opt = face_detect.process(mat_out); | ||
std::chrono::duration<double> model_loading_duration_d = std::chrono::system_clock::now() - start1; | ||
std::string msg = gemfield_org::format("Model loading time: %f", model_loading_duration_d.count()); | ||
std::cout << msg << std::endl; | ||
|
||
if(detect_out_opt){ | ||
face_reg.process(detect_out_opt); | ||
} | ||
|
||
std::chrono::duration<double> model_loading_duration_d1 = std::chrono::system_clock::now() - start1; | ||
std::string msg1 = gemfield_org::format("Model loading time: %f", model_loading_duration_d1.count()); | ||
std::cout << msg1 << std::endl; | ||
} | ||
|
||
std::chrono::duration<double> model_loading_duration = std::chrono::system_clock::now() - start; | ||
std::string msg = gemfield_org::format("Model loading time: %f", model_loading_duration.count()); | ||
std::cout << msg << std::endl; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
|
||
#include <tuple> | ||
#include <vector> | ||
#include "deepvac_nv.h" | ||
#include "syszux_tensorrt_buffers.h" | ||
#include "syszux_img2tensor.h" | ||
|
||
namespace deepvac{ | ||
class SyszuxFaceRegNV : public DeepvacNV{ | ||
public: | ||
SyszuxFaceRegNV(std::string path, std::string device = "cpu"); | ||
SyszuxFaceRegNV(std::vector<unsigned char>&& buffer, std::string device = "cpu"); | ||
SyszuxFaceRegNV(const SyszuxFaceRegNV&) = delete; | ||
SyszuxFaceRegNV& operator=(const SyszuxFaceRegNV&) = delete; | ||
SyszuxFaceRegNV(SyszuxFaceRegNV&&) = default; | ||
SyszuxFaceRegNV& operator=(SyszuxFaceRegNV&&) = default; | ||
virtual ~SyszuxFaceRegNV() = default; | ||
virtual std::tuple<int, std::string, float> process(cv::Mat& frame); | ||
void initBinding(); | ||
}; | ||
}//namespace | ||
|
||
|
Oops, something went wrong.