Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OV JS] Expose export_model()/import_model() #23366

Merged
merged 13 commits into from
Mar 18, 2024
3 changes: 3 additions & 0 deletions src/bindings/js/node/include/compiled_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class CompiledModelWrap : public Napi::ObjectWrap<CompiledModelWrap> {
*/
Napi::Value get_inputs(const Napi::CallbackInfo& info);

/** @brief Exports the compiled model to bytes/output stream. */
Napi::Value export_model(const Napi::CallbackInfo& info);

private:
ov::CompiledModel _compiled_model;
};
3 changes: 3 additions & 0 deletions src/bindings/js/node/include/core_wrap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class CoreWrap : public Napi::ObjectWrap<CoreWrap> {
const Napi::String& device,
const std::map<std::string, ov::Any>& config);

/** @brief Imports a compiled model from the previously exported one. */
Napi::Value import_model(const Napi::CallbackInfo& info);

/** @brief Returns devices available for inference. */
Napi::Value get_available_devices(const Napi::CallbackInfo& info);

Expand Down
2 changes: 2 additions & 0 deletions src/bindings/js/node/lib/addon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ interface Core {
modelBuffer: Uint8Array, weightsBuffer?: Uint8Array): Promise<Model>;
readModelSync(modelPath: string, weightsPath?: string): Model;
readModelSync(modelBuffer: Uint8Array, weightsBuffer?: Uint8Array): Model;
importModelSync(modelStream: Buffer, device: string): CompiledModel;
getAvailableDevices(): string[];
}
interface CoreConstructor {
Expand All @@ -56,6 +57,7 @@ interface CompiledModel {
output(nameOrId?: string | number): Output;
input(nameOrId?: string | number): Output;
createInferRequest(): InferRequest;
exportModelSync(): Buffer;
}

interface Tensor {
Expand Down
10 changes: 9 additions & 1 deletion src/bindings/js/node/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ Napi::Function CompiledModelWrap::get_class(Napi::Env env) {
InstanceMethod("input", &CompiledModelWrap::get_input),
InstanceAccessor<&CompiledModelWrap::get_inputs>("inputs"),
InstanceMethod("output", &CompiledModelWrap::get_output),
InstanceAccessor<&CompiledModelWrap::get_outputs>("outputs")});
InstanceAccessor<&CompiledModelWrap::get_outputs>("outputs"),
InstanceMethod("exportModelSync", &CompiledModelWrap::export_model)});
}

Napi::Object CompiledModelWrap::wrap(Napi::Env env, ov::CompiledModel compiled_model) {
Expand Down Expand Up @@ -110,3 +111,10 @@ Napi::Value CompiledModelWrap::get_inputs(const Napi::CallbackInfo& info) {

return js_inputs;
}

Napi::Value CompiledModelWrap::export_model(const Napi::CallbackInfo& info) {
std::stringstream _stream;
_compiled_model.export_model(_stream);
const auto& exported = _stream.str();
return Napi::Buffer<const char>::Copy(info.Env(), exported.c_str(), exported.size());
}
23 changes: 23 additions & 0 deletions src/bindings/js/node/src/core_wrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Napi::Function CoreWrap::get_class(Napi::Env env) {
InstanceMethod("readModel", &CoreWrap::read_model_async),
InstanceMethod("compileModelSync", &CoreWrap::compile_model_sync_dispatch),
InstanceMethod("compileModel", &CoreWrap::compile_model_async),
InstanceMethod("importModelSync", &CoreWrap::import_model),
InstanceMethod("getAvailableDevices", &CoreWrap::get_available_devices)});
}

Expand Down Expand Up @@ -230,3 +231,25 @@ Napi::Value CoreWrap::get_available_devices(const Napi::CallbackInfo& info) {

return js_devices;
}

Napi::Value CoreWrap::import_model(const Napi::CallbackInfo& info) {
if (info.Length() != 2) {
reportError(info.Env(), "Invalid number of arguments -> " + std::to_string(info.Length()));
return info.Env().Undefined();
}
if (!info[0].IsBuffer()) {
reportError(info.Env(), "The first argument must be of type Buffer.");
return info.Env().Undefined();
}
if (!info[1].IsString()) {
reportError(info.Env(), "The second argument must be of type String.");
return info.Env().Undefined();
}
const auto& model_data = info[0].As<Napi::Buffer<uint8_t>>();
const auto model_stream = std::string(reinterpret_cast<char*>(model_data.Data()), model_data.Length());
std::stringstream _stream;
_stream << model_stream;

const auto& compiled = _core.import_model(_stream, std::string(info[1].ToString()));
return CompiledModelWrap::wrap(info.Env(), compiled);
}
14 changes: 14 additions & 0 deletions src/bindings/js/node/tests/basic.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,17 @@ describe('Input class for ov::Input<const ov::Node>', () => {
});

});

it('Test exportModel()/importModel()', () => {
const userStream = compiledModel.exportModelSync();
const newCompiled = core.importModelSync(userStream, 'CPU');
const epsilon = 0.5;
const tensor = Float32Array.from({ length: 3072 }, () => (Math.random() + epsilon));

const inferRequest = compiledModel.createInferRequest();
const res1 = inferRequest.infer([tensor]);
const newInferRequest = newCompiled.createInferRequest();
const res2 = newInferRequest.infer([tensor]);

assert.deepStrictEqual(res1['fc_out'].data[0], res2['fc_out'].data[0]);
});
Loading