Skip to content

Commit

Permalink
[JS API] Extract code from CompiledModel getters (openvinotoolkit#23515)
Browse files Browse the repository at this point in the history
### Details:
- Extract the same logic structure from `CompileModel::input` and
`CompileModel::output`
- Add a private `CompileModel::get_node` method that gets the specified
input or output node.

Note:
No changes to argument validation or conversion.

### Tickets:
 - *127617*
  • Loading branch information
almilosz authored and alvoron committed Apr 29, 2024
1 parent ada7d82 commit 99efeb4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 38 deletions.
6 changes: 6 additions & 0 deletions src/bindings/js/node/include/compiled_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,11 @@ class CompiledModelWrap : public Napi::ObjectWrap<CompiledModelWrap> {
Napi::Value export_model(const Napi::CallbackInfo& info);

private:
/** @brief Gets node of a compiled model specified in CallbackInfo. */
Napi::Value get_node(const Napi::CallbackInfo& info,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func)() const,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func_tname)(const std::string&)const,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func_idx)(size_t) const);

ov::CompiledModel _compiled_model;
};
80 changes: 42 additions & 38 deletions src/bindings/js/node/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,17 @@ Napi::Value CompiledModelWrap::create_infer_request(const Napi::CallbackInfo& in
}

Napi::Value CompiledModelWrap::get_output(const Napi::CallbackInfo& info) {
if (info.Length() == 0) {
try {
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.output());
} catch (std::exception& e) {
reportError(info.Env(), e.what());
return Napi::Value();
}
} else if (info.Length() != 1) {
reportError(info.Env(), "Invalid number of arguments -> " + std::to_string(info.Length()));
return Napi::Value();
} else if (info[0].IsString()) {
auto tensor_name = info[0].ToString();
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.output(tensor_name));
} else if (info[0].IsNumber()) {
auto idx = info[0].As<Napi::Number>().Int32Value();
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.output(idx));
} else {
reportError(info.Env(), "Error while getting compiled model outputs.");
return Napi::Value();
try {
return get_node(
info,
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)() const>(&ov::CompiledModel::output),
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)(const std::string&)const>(
&ov::CompiledModel::output),
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)(size_t) const>(
&ov::CompiledModel::output));
} catch (std::exception& e) {
reportError(info.Env(), e.what() + std::string("outputs."));
return info.Env().Null();
}
}

Expand All @@ -79,25 +71,17 @@ Napi::Value CompiledModelWrap::get_outputs(const Napi::CallbackInfo& info) {
}

Napi::Value CompiledModelWrap::get_input(const Napi::CallbackInfo& info) {
if (info.Length() == 0) {
try {
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.input());
} catch (std::exception& e) {
reportError(info.Env(), e.what());
return Napi::Value();
}
} else if (info.Length() != 1) {
reportError(info.Env(), "Invalid number of arguments -> " + std::to_string(info.Length()));
return Napi::Value();
} else if (info[0].IsString()) {
auto tensor_name = info[0].ToString();
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.input(tensor_name));
} else if (info[0].IsNumber()) {
auto idx = info[0].As<Napi::Number>().Int32Value();
return Output<const ov::Node>::wrap(info.Env(), _compiled_model.input(idx));
} else {
reportError(info.Env(), "Error while getting compiled model inputs.");
return Napi::Value();
try {
return get_node(
info,
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)() const>(&ov::CompiledModel::input),
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)(const std::string&)const>(
&ov::CompiledModel::input),
static_cast<const ov::Output<const ov::Node>& (ov::CompiledModel::*)(size_t) const>(
&ov::CompiledModel::input));
} catch (std::exception& e) {
reportError(info.Env(), e.what() + std::string("inputs."));
return info.Env().Null();
}
}

Expand All @@ -112,6 +96,26 @@ Napi::Value CompiledModelWrap::get_inputs(const Napi::CallbackInfo& info) {
return js_inputs;
}

Napi::Value CompiledModelWrap::get_node(
const Napi::CallbackInfo& info,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func)() const,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func_tname)(const std::string&)const,
const ov::Output<const ov::Node>& (ov::CompiledModel::*func_idx)(size_t) const) {
if (info.Length() == 0) {
return Output<const ov::Node>::wrap(info.Env(), (_compiled_model.*func)());
} else if (info.Length() != 1) {
OPENVINO_THROW(std::string("Invalid number of arguments."));
} else if (info[0].IsString()) {
auto tensor_name = info[0].ToString();
return Output<const ov::Node>::wrap(info.Env(), (_compiled_model.*func_tname)(tensor_name));
} else if (info[0].IsNumber()) {
auto idx = info[0].As<Napi::Number>().Int32Value();
return Output<const ov::Node>::wrap(info.Env(), (_compiled_model.*func_idx)(idx));
} else {
OPENVINO_THROW(std::string("Error while getting compiled model "));
}
}

Napi::Value CompiledModelWrap::export_model(const Napi::CallbackInfo& info) {
std::stringstream _stream;
_compiled_model.export_model(_stream);
Expand Down

0 comments on commit 99efeb4

Please sign in to comment.