Skip to content

Commit

Permalink
Add on-device text-to-speech (TTS) demo for HarmonyOS (#1590)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 4, 2024
1 parent 47a2dd4 commit 74a8735
Show file tree
Hide file tree
Showing 61 changed files with 1,930 additions and 117 deletions.
1 change: 1 addition & 0 deletions harmony-os/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
!build-profile.json5
*.har
6 changes: 5 additions & 1 deletion harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
export { readWave, readWaveFromBinary } from "libsherpa_onnx.so";
export {
listRawfileDir,
readWave,
readWaveFromBinary,
} from "libsherpa_onnx.so";

export {
CircularBuffer,
Expand Down
2 changes: 1 addition & 1 deletion harmony-os/SherpaOnnxHar/sherpa_onnx/build-profile.json5
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"externalNativeOptions": {
"path": "./src/main/cpp/CMakeLists.txt",
"arguments": "",
"cppFlags": "",
"cppFlags": "-std=c++17",
"abiFilters": [
"arm64-v8a",
"x86_64",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
cmake_minimum_required(VERSION 3.13.0)
project(myNpmLib)

# Default to C++17 when no C++ standard has been selected yet. The value is
# stored as a cache entry, so it can still be overridden from the command line.
if (NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ version to use")
endif()

# Disable warning about
#
# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is
Expand Down Expand Up @@ -46,6 +50,7 @@ add_library(sherpa_onnx SHARED
speaker-identification.cc
spoken-language-identification.cc
streaming-asr.cc
utils.cc
vad.cc
wave-reader.cc
wave-writer.cc
Expand Down
260 changes: 257 additions & 3 deletions harmony-os/SherpaOnnxHar/sherpa_onnx/src/main/cpp/non-streaming-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,13 @@ static Napi::Number OfflineTtsNumSpeakersWrapper(
return Napi::Number::New(env, num_speakers);
}

// synchronous version
static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length();
os << "Expect only 2 arguments. Given: " << info.Length();

Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

Expand Down Expand Up @@ -298,8 +299,8 @@ static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
int32_t sid = obj.Get("sid").As<Napi::Number>().Int32Value();
float speed = obj.Get("speed").As<Napi::Number>().FloatValue();

const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerate(tts, text.c_str(), sid, speed);
const SherpaOnnxGeneratedAudio *audio;
audio = SherpaOnnxOfflineTtsGenerate(tts, text.c_str(), sid, speed);

if (enable_external_buffer) {
Napi::ArrayBuffer arrayBuffer = Napi::ArrayBuffer::New(
Expand Down Expand Up @@ -334,6 +335,256 @@ static Napi::Object OfflineTtsGenerateWrapper(const Napi::CallbackInfo &info) {
}
}

// One chunk of generated audio handed to InvokeJsCallback, together with the
// status flags that InvokeJsCallback writes back after calling into JS.
struct TtsCallbackData {
  // Copy of the samples passed to the TTS progress callback.
  std::vector<float> samples;

  // Progress value reported together with `samples`. Initialised so that a
  // default-constructed instance never exposes an indeterminate float.
  float progress = 0.0f;

  // Set to true once the JS callback has been invoked for this chunk.
  bool processed = false;

  // Set to true when the JS callback did not return a non-zero number, i.e.
  // the caller asked for generation to stop.
  bool cancelled = false;
};

// Delivers one TtsCallbackData chunk to the user-supplied JavaScript callback.
//
// Modeled on
// https://github.com/nodejs/node-addon-examples/blob/main/src/6-threadsafe-function/typed_threadsafe_function/node-addon-api/clock.cc
//
// The JS callback receives a single object {samples: Float32Array,
// progress: number}. `samples` is a fresh copy of data->samples.
//
// After the call, data->processed is set to true and data->cancelled is set
// to false only if the callback returned a non-zero number; any other return
// value (including undefined) marks the chunk as cancelled.
//
// Nothing happens when either env or callback compares equal to nullptr.
void InvokeJsCallback(Napi::Env env, Napi::Function callback,
                      Napi::Reference<Napi::Value> *context,
                      TtsCallbackData *data) {
  const bool can_call = (env != nullptr) && (callback != nullptr);
  if (!can_call) {
    return;
  }

  const auto num_samples = data->samples.size();

  // Copy the samples into a JS-owned Float32Array.
  Napi::ArrayBuffer buffer =
      Napi::ArrayBuffer::New(env, sizeof(float) * num_samples);
  Napi::Float32Array js_samples =
      Napi::Float32Array::New(env, num_samples, buffer, 0);
  std::copy(data->samples.begin(), data->samples.end(), js_samples.Data());

  Napi::Object payload = Napi::Object::New(env);
  payload.Set(Napi::String::New(env, "samples"), js_samples);
  payload.Set(Napi::String::New(env, "progress"), data->progress);

  auto ret = callback.Call(context->Value(), {payload});
  data->processed = true;

  // A non-zero numeric return value means "keep generating".
  const bool keep_going =
      ret.IsNumber() && ret.As<Napi::Number>().Int32Value();
  data->cancelled = !keep_going;
}

// Typed thread-safe function whose context is a Napi::Reference<Napi::Value>,
// whose per-call data item is a TtsCallbackData, and whose JS-side trampoline
// is InvokeJsCallback.
using TSFN = Napi::TypedThreadSafeFunction<Napi::Reference<Napi::Value>,
TtsCallbackData, InvokeJsCallback>;

class TtsGenerateWorker : public Napi::AsyncWorker {
public:
TtsGenerateWorker(const Napi::Env &env, TSFN tsfn, SherpaOnnxOfflineTts *tts,
const std::string &text, float speed, int32_t sid,
bool use_external_buffer)
: tsfn_(tsfn),
Napi::AsyncWorker{env, "TtsGenerateWorker"},
deferred_(env),
tts_(tts),
text_(text),
speed_(speed),
sid_(sid),
use_external_buffer_(use_external_buffer) {}

Napi::Promise Promise() { return deferred_.Promise(); }

~TtsGenerateWorker() {
for (auto d : data_list_) {
delete d;
}
}

protected:
void Execute() override {
auto callback = [](const float *samples, int32_t n, float progress,
void *arg) -> int32_t {
TtsGenerateWorker *_this = reinterpret_cast<TtsGenerateWorker *>(arg);

for (auto d : _this->data_list_) {
if (d->cancelled) {
OH_LOG_INFO(LOG_APP, "TtsGenerate is cancelled");
return 0;
}
}

auto data = new TtsCallbackData;
data->samples = std::vector<float>{samples, samples + n};
data->progress = progress;
_this->data_list_.push_back(data);

_this->tsfn_.NonBlockingCall(data);

return 1;
};
audio_ = SherpaOnnxOfflineTtsGenerateWithProgressCallbackWithArg(
tts_, text_.c_str(), sid_, speed_, callback, this);

tsfn_.Release();
}

void OnOK() override {
Napi::Env env = deferred_.Env();
Napi::Object ans = Napi::Object::New(env);
if (use_external_buffer_) {
Napi::ArrayBuffer arrayBuffer = Napi::ArrayBuffer::New(
env, const_cast<float *>(audio_->samples), sizeof(float) * audio_->n,
[](Napi::Env /*env*/, void * /*data*/,
const SherpaOnnxGeneratedAudio *hint) {
SherpaOnnxDestroyOfflineTtsGeneratedAudio(hint);
},
audio_);
Napi::Float32Array float32Array =
Napi::Float32Array::New(env, audio_->n, arrayBuffer, 0);

ans.Set(Napi::String::New(env, "samples"), float32Array);
ans.Set(Napi::String::New(env, "sampleRate"), audio_->sample_rate);
} else {
// don't use external buffer
Napi::ArrayBuffer arrayBuffer =
Napi::ArrayBuffer::New(env, sizeof(float) * audio_->n);

Napi::Float32Array float32Array =
Napi::Float32Array::New(env, audio_->n, arrayBuffer, 0);

std::copy(audio_->samples, audio_->samples + audio_->n,
float32Array.Data());

ans.Set(Napi::String::New(env, "samples"), float32Array);
ans.Set(Napi::String::New(env, "sampleRate"), audio_->sample_rate);
SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio_);
}

deferred_.Resolve(ans);
}

private:
TSFN tsfn_;
Napi::Promise::Deferred deferred_;
SherpaOnnxOfflineTts *tts_;
std::string text_;
float speed_;
int32_t sid_;
bool use_external_buffer_;

const SherpaOnnxGeneratedAudio *audio_;

std::vector<TtsCallbackData *> data_list_;
};

// Asynchronous counterpart of OfflineTtsGenerateWrapper, exported to JS as
// "offlineTtsGenerateAsync".
//
// Expected JS arguments:
//   info[0]: external holding a SherpaOnnxOfflineTts pointer
//   info[1]: object with
//     text:  string (required)
//     sid:   number (required)
//     speed: number (required)
//     enableExternalBuffer: boolean (optional, defaults to true)
//     callback: function (optional); called with {samples, progress} for
//               each generated chunk
//
// Returns the promise of a newly queued TtsGenerateWorker. On invalid
// arguments a JS TypeError is thrown and an empty object is returned.
static Napi::Object OfflineTtsGenerateAsyncWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(env, "Argument 0 should be an offline tts pointer.")
        .ThrowAsJavaScriptException();

    return {};
  }

  SherpaOnnxOfflineTts *tts =
      info[0].As<Napi::External<SherpaOnnxOfflineTts>>().Data();

  if (!info[1].IsObject()) {
    Napi::TypeError::New(env, "Argument 1 should be an object")
        .ThrowAsJavaScriptException();

    return {};
  }

  Napi::Object obj = info[1].As<Napi::Object>();

  // Checks that obj[name] exists and is a string (want_string == true) or a
  // number (want_string == false). Throws a JS TypeError and returns false
  // otherwise.
  auto has_valid_field = [&env, &obj](const char *name,
                                      bool want_string) -> bool {
    if (!obj.Has(name)) {
      Napi::TypeError::New(
          env, std::string("The argument object should have a field ") + name)
          .ThrowAsJavaScriptException();

      return false;
    }

    auto value = obj.Get(name);
    const bool type_ok = want_string ? value.IsString() : value.IsNumber();
    if (!type_ok) {
      Napi::TypeError::New(env, std::string("The object['") + name +
                                    "'] should be a " +
                                    (want_string ? "string" : "number"))
          .ThrowAsJavaScriptException();

      return false;
    }

    return true;
  };

  // Short-circuit evaluation keeps the order text -> sid -> speed, so only
  // the first problem is reported.
  if (!has_valid_field("text", true) || !has_valid_field("sid", false) ||
      !has_valid_field("speed", false)) {
    return {};
  }

  bool enable_external_buffer = true;
  if (obj.Has("enableExternalBuffer") &&
      obj.Get("enableExternalBuffer").IsBoolean()) {
    enable_external_buffer =
        obj.Get("enableExternalBuffer").As<Napi::Boolean>().Value();
  }

  Napi::String _text = obj.Get("text").As<Napi::String>();
  std::string text = _text.Utf8Value();
  int32_t sid = obj.Get("sid").As<Napi::Number>().Int32Value();
  float speed = obj.Get("speed").As<Napi::Number>().FloatValue();

  // Stays empty when no callback was supplied.
  Napi::Function cb;
  if (obj.Has("callback") && obj.Get("callback").IsFunction()) {
    cb = obj.Get("callback").As<Napi::Function>();
  }

  // Deleted by the finalizer passed to TSFN::New below.
  auto context =
      new Napi::Reference<Napi::Value>(Napi::Persistent(info.This()));

  TSFN tsfn = TSFN::New(
      env,
      cb,                 // JavaScript function called asynchronously
      "TtsGenerateFunc",  // Name
      0,                  // Unlimited queue
      1,                  // Only one thread will use this initially
      context,
      [](Napi::Env, void *, Napi::Reference<Napi::Value> *ctx) { delete ctx; });

  // NOTE(review): the worker is not deleted here; presumably
  // Napi::AsyncWorker destroys itself after completion -- confirm.
  TtsGenerateWorker *worker = new TtsGenerateWorker(
      env, tsfn, tts, text, speed, sid, enable_external_buffer);
  worker->Queue();
  return worker->Promise();
}

void InitNonStreamingTts(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "createOfflineTts"),
Napi::Function::New(env, CreateOfflineTtsWrapper));
Expand All @@ -346,4 +597,7 @@ void InitNonStreamingTts(Napi::Env env, Napi::Object exports) {

exports.Set(Napi::String::New(env, "offlineTtsGenerate"),
Napi::Function::New(env, OfflineTtsGenerateWrapper));

exports.Set(Napi::String::New(env, "offlineTtsGenerateAsync"),
Napi::Function::New(env, OfflineTtsGenerateAsyncWrapper));
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports);

void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports);

#if __OHOS__
void InitUtils(Napi::Env env, Napi::Object exports);
#endif

Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitStreamingAsr(env, exports);
InitNonStreamingAsr(env, exports);
Expand All @@ -41,7 +45,15 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitKeywordSpotting(env, exports);
InitNonStreamingSpeakerDiarization(env, exports);

#if __OHOS__
InitUtils(env, exports);
#endif

return exports;
}

#if __OHOS__
NODE_API_MODULE(sherpa_onnx, Init)
#else
NODE_API_MODULE(addon, Init)
#endif
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export const listRawfileDir: (mgr: object, dir: string) => Array<string>;

export const readWave: (filename: string, enableExternalBuffer: boolean = true) => {samples: Float32Array, sampleRate: number};
export const readWaveFromBinary: (data: Uint8Array, enableExternalBuffer: boolean = true) => {samples: Float32Array, sampleRate: number};
export const createCircularBuffer: (capacity: number) => object;
Expand Down Expand Up @@ -37,4 +39,11 @@ export const getOnlineStreamResultAsJson: (handle: object, streamHandle: object)
export const createOfflineTts: (config: object, mgr?: object) => object;
export const getOfflineTtsNumSpeakers: (handle: object) => number;
export const getOfflineTtsSampleRate: (handle: object) => number;
export const offlineTtsGenerate: (handle: object, input: object) => object;

export type TtsOutput = {
samples: Float32Array;
sampleRate: number;
};

export const offlineTtsGenerate: (handle: object, input: object) => TtsOutput;
export const offlineTtsGenerateAsync: (handle: object, input: object) => Promise<TtsOutput>;
Loading

0 comments on commit 74a8735

Please sign in to comment.