[Improvement] Add Model::ReadConfig & simplify handle creation #1738

Merged · 1 commit · Feb 21, 2023
34 changes: 1 addition & 33 deletions csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp
@@ -16,37 +16,6 @@
using namespace mmdeploy;
using namespace std;

namespace {

Value config_template(const Model& model) {
// clang-format off
static Value v{
{
"pipeline", {
{"input", {"img"}},
{"output", {"cls"}},
{
"tasks", {
{
{"name", "classifier"},
{"type", "Inference"},
{"params", {{"model", "TBD"}}},
{"input", {"img"}},
{"output", {"cls"}}
}
}
}
}
}
};
// clang-format on
auto config = v;
config["pipeline"]["tasks"][0]["params"]["model"] = model;
return config;
}

} // namespace

int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_classifier_t* classifier) {
mmdeploy_context_t context{};
@@ -73,8 +42,7 @@ int mmdeploy_classifier_create_by_path(const char* model_path, const char* devic

int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_classifier_t* classifier) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)classifier);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier);
}

int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count,
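With the inline `config_template` gone, handle creation is unchanged from the caller's point of view: the pipeline config now comes from the model's bundled `pipeline.json` rather than a hard-coded template. A minimal usage sketch, not part of this PR ("/path/to/sdk_model" and "cuda" are placeholders, and error handling is reduced to early returns):

```cpp
#include "mmdeploy/classifier.h"
#include "mmdeploy/common.h"
#include "mmdeploy/model.h"

// Minimal sketch of creating a classifier handle with the v2 API.
int create_classifier(mmdeploy_classifier_t* classifier) {
  mmdeploy_model_t model{};
  if (mmdeploy_model_create_by_path("/path/to/sdk_model", &model)) {
    return MMDEPLOY_E_FAIL;
  }
  mmdeploy_context_t context{};
  if (mmdeploy_context_create_by_device("cuda", 0, &context)) {
    mmdeploy_model_destroy(model);
    return MMDEPLOY_E_FAIL;
  }
  // With this PR the pipeline config is read from the model's pipeline.json
  // instead of a template hard-coded in classifier.cpp.
  int ec = mmdeploy_classifier_create_v2(model, context, classifier);
  mmdeploy_model_destroy(model);
  mmdeploy_context_destroy(context);
  return ec;
}
```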
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/detector.cpp
@@ -19,27 +19,11 @@
using namespace std;
using namespace mmdeploy;

namespace {

Value config_template(Model model) {
// clang-format off
return {
{"name", "detector"},
{"type", "Inference"},
{"params", {{"model", std::move(model)}}},
{"input", {"image"}},
{"output", {"dets"}}
};
// clang-format on
}

using ResultType = mmdeploy::Structure<mmdeploy_detection_t, //
std::vector<int>, //
std::deque<mmdeploy_instance_mask_t>, //
std::vector<mmdeploy::framework::Buffer>>; //

} // namespace

int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_detector_t* detector) {
mmdeploy_context_t context{};
@@ -54,8 +38,7 @@ int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, in

int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_detector_t* detector) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)detector);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
}

int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id,
9 changes: 9 additions & 0 deletions csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp
@@ -27,6 +27,15 @@ int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t cont
return MMDEPLOY_E_FAIL;
}

int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_pipeline_t* pipeline) {
auto config = Cast(model)->ReadConfig("pipeline.json");
auto _context = *Cast(context);
_context["model"] = *Cast(model);
return mmdeploy_pipeline_create_v3(Cast(&config.value()), (mmdeploy_context_t)&_context,
pipeline);
}

int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, mmdeploy_sender_t input,
mmdeploy_sender_t* output) {
if (!pipeline || !input || !output) {
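`mmdeploy_pipeline_create_from_model` reads `pipeline.json` from the model, injects the model into the context, and delegates to `mmdeploy_pipeline_create_v3`. A hedged usage sketch of the new entry point together with the pre-existing apply/destroy calls (the `model`, `context`, and `input` handles are assumed to be created and destroyed by the caller):

```cpp
#include "mmdeploy/pipeline.h"

// Sketch only: builds a pipeline straight from an SDK model and runs it once.
int run_pipeline_once(mmdeploy_model_t model, mmdeploy_context_t context,
                      mmdeploy_value_t input) {
  mmdeploy_pipeline_t pipeline{};
  if (int ec = mmdeploy_pipeline_create_from_model(model, context, &pipeline)) {
    return ec;  // pipeline.json missing or context invalid
  }
  mmdeploy_value_t output{};
  int ec = mmdeploy_pipeline_apply(pipeline, input, &output);
  if (ec == MMDEPLOY_SUCCESS) {
    mmdeploy_value_destroy(output);
  }
  mmdeploy_pipeline_destroy(pipeline);
  return ec;
}
```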
11 changes: 11 additions & 0 deletions csrc/mmdeploy/apis/c/mmdeploy/pipeline.h
@@ -5,6 +5,7 @@

#include "mmdeploy/common.h"
#include "mmdeploy/executor.h"
#include "mmdeploy/model.h"

#ifdef __cplusplus
extern "C" {
@@ -24,6 +25,16 @@ typedef struct mmdeploy_pipeline* mmdeploy_pipeline_t;
*/
MMDEPLOY_API int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context,
mmdeploy_pipeline_t* pipeline);
/**
 * @brief Create a pipeline from the pipeline config bundled with an SDK model
 * @param[in] model an SDK model whose pipeline.json describes the pipeline
 * @param[in] context the context under which the pipeline is created
 * @param[out] pipeline the created pipeline handle
 * @return status code of the operation
 */
MMDEPLOY_API int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model,
mmdeploy_context_t context,
mmdeploy_pipeline_t* pipeline);

/**
* @brief Apply pipeline
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp
@@ -16,22 +16,6 @@
using namespace std;
using namespace mmdeploy;

namespace {

Value config_template(const Model& model) {
// clang-format off
return {
{"name", "pose-detector"},
{"type", "Inference"},
{"params", {{"model", model}, {"batch_size", 1}}},
{"input", {"image"}},
{"output", {"dets"}}
};
// clang-format on
}

} // namespace

int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_pose_detector_t* detector) {
mmdeploy_context_t context{};
@@ -95,8 +79,7 @@ void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector) {

int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_pose_detector_t* detector) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)detector);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
}

int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp
@@ -14,24 +14,8 @@

using namespace mmdeploy;

namespace {

Value config_template(const Model& model) {
// clang-format off
return {
{"name", "restorer"},
{"type", "Inference"},
{"params", {{"model", model}}},
{"input", {"img"}},
{"output", {"out"}}
};
// clang-format on
}

using ResultType = mmdeploy::Structure<mmdeploy_mat_t, mmdeploy::framework::Buffer>;

} // namespace

int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_restorer_t* restorer) {
mmdeploy_context_t context{};
@@ -81,8 +65,7 @@ void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer) {

int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_restorer_t* restorer) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)restorer);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)restorer);
}

int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count,
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp
@@ -15,22 +15,6 @@
using namespace std;
using namespace mmdeploy;

namespace {

Value config_template(const Model& model) {
// clang-format off
return {
{"name", "mmrotate"},
{"type", "Inference"},
{"params", {{"model", model}}},
{"input", {"image"}},
{"output", {"det"}}
};
// clang-format on
}

} // namespace

int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_rotated_detector_t* detector) {
mmdeploy_context_t context{};
@@ -84,8 +68,7 @@ void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector) {

int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_rotated_detector_t* detector) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)detector);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
}

int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
@@ -16,24 +16,8 @@
using namespace std;
using namespace mmdeploy;

namespace {

Value config_template(const Model& model) {
// clang-format off
return {
{"name", "segmentor"},
{"type", "Inference"},
{"params", {{"model", model}}},
{"input", {"img"}},
{"output", {"mask"}}
};
// clang-format on
}

using ResultType = mmdeploy::Structure<mmdeploy_segmentation_t, mmdeploy::framework::Buffer>;

} // namespace

int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_segmentor_t* segmentor) {
mmdeploy_context_t context{};
@@ -83,8 +67,7 @@ void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor) {

int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_segmentor_t* segmentor) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)segmentor);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)segmentor);
}

int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count,
19 changes: 1 addition & 18 deletions csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp
@@ -16,22 +16,6 @@
using namespace std;
using namespace mmdeploy;

namespace {

Value config_template(const Model& model) {
// clang-format off
return {
{"name", "detector"},
{"type", "Inference"},
{"params", {{"model", model}}},
{"input", {"img"}},
{"output", {"dets"}}
};
// clang-format on
}

} // namespace

int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_text_detector_t* detector) {
mmdeploy_context_t context{};
@@ -46,8 +30,7 @@ int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_nam

int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_text_detector_t* detector) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)detector);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
}

int mmdeploy_text_detector_create_by_path(const char* model_path, const char* device_name,
2 changes: 1 addition & 1 deletion csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp
@@ -38,7 +38,7 @@ Value config_template(const Model& model) {
{"type", "Inference"},
{"input", "patches"},
{"output", "texts"},
{"params", {{"model", std::move(model)}}},
{"params", {{"model", model}}},
}
}
},
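Dropping `std::move` here is a readability fix rather than a behavioral one: `model` is a `const Model&`, so `std::move(model)` still selects the copy constructor. A standalone illustration with a hypothetical `Tracer` type (not MMDeploy code):

```cpp
#include <iostream>
#include <utility>

// std::move on a const lvalue yields a const rvalue, which cannot bind to the
// non-const move constructor, so the copy constructor is chosen anyway.
struct Tracer {
  Tracer() = default;
  Tracer(const Tracer&) { std::cout << "copy\n"; }
  Tracer(Tracer&&) noexcept { std::cout << "move\n"; }
};

int main() {
  const Tracer t;
  Tracer a(std::move(t));  // prints "copy"
  Tracer b(t);             // prints "copy"
  (void)a;
  (void)b;
}
```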
27 changes: 1 addition & 26 deletions csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp
@@ -20,30 +20,6 @@

using namespace mmdeploy;

namespace {
Value config_template(const Model& model) {
// clang-format off
return {
{"type", "Pipeline"},
{"input", {"video"}},
{
"tasks", {
{
{"name", "Video Recognizer"},
{"type", "Inference"},
{"input", "video"},
{"output", "label"},
{"params", {{"model", std::move(model)}}},
}
}
},
{"output", "label"},
};
// clang-format on
}

} // namespace

int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id,
mmdeploy_video_recognizer_t* recognizer) {
mmdeploy_context_t context{};
@@ -101,8 +77,7 @@ void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer) {

int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
mmdeploy_video_recognizer_t* recognizer) {
auto config = config_template(*Cast(model));
return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)recognizer);
return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)recognizer);
}

int mmdeploy_video_recognizer_create_input(const mmdeploy_mat_t* images,
4 changes: 4 additions & 0 deletions csrc/mmdeploy/core/model.cpp
@@ -72,6 +72,10 @@ Result<std::string> Model::ReadFile(const std::string& file_path) noexcept {
return impl_->ReadFile(file_path);
}

Result<Value> Model::ReadConfig(const string& config_path) noexcept {
return impl_->ReadConfig(config_path);
}

MMDEPLOY_DEFINE_REGISTRY(ModelImpl);

} // namespace mmdeploy::framework
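`Model::ReadConfig` simply forwards to the backing `ModelImpl`, returning the parsed config as a `Value`. A small sketch of using it directly (assumptions: the path placeholder points at an SDK model that bundles `pipeline.json`, and `Model` is constructed from that path as in the rest of the SDK):

```cpp
#include "mmdeploy/core/model.h"

// Sketch only: "/path/to/sdk_model" is a placeholder.
void inspect_pipeline_config() {
  mmdeploy::framework::Model model("/path/to/sdk_model");
  if (auto config = model.ReadConfig("pipeline.json"); config.has_value()) {
    // config.value() is the parsed pipeline config; this is exactly what
    // mmdeploy_pipeline_create_from_model hands to mmdeploy_pipeline_create_v3.
    const mmdeploy::Value& pipeline_cfg = config.value();
    (void)pipeline_cfg;
  }
}
```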
3 changes: 3 additions & 0 deletions csrc/mmdeploy/core/model.h
@@ -11,6 +11,7 @@
#include "mmdeploy/core/mpl/type_traits.h"
#include "mmdeploy/core/serialization.h"
#include "mmdeploy/core/types.h"
#include "mmdeploy/core/value.h"

namespace mmdeploy {

@@ -73,6 +74,8 @@ class MMDEPLOY_API Model {
*/
Result<std::string> ReadFile(const std::string& file_path) noexcept;

/**
 * @brief read a config file (e.g. "pipeline.json") from the SDK model and parse it
 * @param config_path path of the config file inside the model
 * @return the parsed config on success
 */
Result<Value> ReadConfig(const std::string& config_path) noexcept;

/**
* @brief get meta information of the model
* @return SDK model's meta information
2 changes: 2 additions & 0 deletions csrc/mmdeploy/core/model_impl.h
@@ -32,6 +32,8 @@ class ModelImpl {
*/
virtual Result<std::string> ReadFile(const std::string& file_path) const = 0;

/**
 * @brief read a config file from the SDK model and parse it into a Value
 * @param config_path path of the config file inside the model
 * @return the parsed config on success
 */
virtual Result<Value> ReadConfig(const std::string& config_path) const = 0;

/**
* @brief get meta information of an sdk model
* @return SDK model's meta information