Skip to content

Commit

Permalink
Add speaker identification APIs for HarmonyOS (#1607)
Browse files Browse the repository at this point in the history
* Add speaker embedding extractor API for HarmonyOS

* Add ArkTS API for speaker identification
  • Loading branch information
csukuangfj authored Dec 9, 2024
1 parent a743a44 commit 314545f
Show file tree
Hide file tree
Showing 19 changed files with 374 additions and 60 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,5 @@ sherpa-onnx-online-punct-en-2024-08-06
sherpa-onnx-pyannote-segmentation-3-0
sherpa-onnx-moonshine-tiny-en-int8
sherpa-onnx-moonshine-base-en-int8
harmony-os/SherpaOnnxHar/sherpa_onnx/LICENSE
harmony-os/SherpaOnnxHar/sherpa_onnx/CHANGELOG.md
6 changes: 6 additions & 0 deletions harmony-os/SherpaOnnxHar/sherpa_onnx/Index.ets
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,9 @@ export {
TtsOutput,
TtsInput,
} from './src/main/ets/components/NonStreamingTts';

export {
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingExtractor,
SpeakerEmbeddingManager,
} from './src/main/ets/components/SpeakerIdentification';
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
static Napi::External<SherpaOnnxSpeakerEmbeddingExtractor>
CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();

#if __OHOS__
if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 2 arguments. Given: " << info.Length();

Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

return {};
}
#else
if (info.Length() != 1) {
std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length();
Expand All @@ -19,6 +30,7 @@ CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {

return {};
}
#endif

if (!info[0].IsObject()) {
Napi::TypeError::New(env, "You should pass an object as the only argument.")
Expand Down Expand Up @@ -46,8 +58,18 @@ CreateSpeakerEmbeddingExtractorWrapper(const Napi::CallbackInfo &info) {

SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);

#if __OHOS__
std::unique_ptr<NativeResourceManager,
decltype(&OH_ResourceManager_ReleaseNativeResourceManager)>
mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]),
&OH_ResourceManager_ReleaseNativeResourceManager);

const SherpaOnnxSpeakerEmbeddingExtractor *extractor =
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(&c, mgr.get());
#else
const SherpaOnnxSpeakerEmbeddingExtractor *extractor =
SherpaOnnxCreateSpeakerEmbeddingExtractor(&c);
#endif

if (c.model) {
delete[] c.model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,18 @@ export type TtsOutput = {

export const offlineTtsGenerate: (handle: object, input: object) => TtsOutput;
export const offlineTtsGenerateAsync: (handle: object, input: object) => Promise<TtsOutput>;

export const createSpeakerEmbeddingExtractor: (config: object, mgr?: object) => object;
export const speakerEmbeddingExtractorDim: (handle: object) => number;
export const speakerEmbeddingExtractorCreateStream: (handle: object) => object;
export const speakerEmbeddingExtractorIsReady: (handle: object, stream: object) => boolean;
export const speakerEmbeddingExtractorComputeEmbedding: (handle: object, stream: object, enableExternalBuffer: boolean) => Float32Array;
export const createSpeakerEmbeddingManager: (dim: number) => object;
export const speakerEmbeddingManagerAdd: (handle: object, speaker: {name: string, v: Float32Array}) => boolean;
export const speakerEmbeddingManagerAddListFlattened: (handle: object, speaker: {name: string, vv: Float32Array, n: number}) => boolean;
export const speakerEmbeddingManagerRemove: (handle: object, name: string) => boolean;
export const speakerEmbeddingManagerSearch: (handle: object, obj: {v: Float32Array, threshold: number}) => string;
export const speakerEmbeddingManagerVerify: (handle: object, obj: {name: string, v: Float32Array, threshold: number}) => boolean;
export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean;
export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number;
export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<string>;
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
getOfflineTtsSampleRate,
offlineTtsGenerate,
offlineTtsGenerateAsync,
} from "libsherpa_onnx.so";
} from 'libsherpa_onnx.so';

export class OfflineTtsVitsModelConfig {
public model: string = '';
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import {
createSpeakerEmbeddingExtractor,
createSpeakerEmbeddingManager,
speakerEmbeddingExtractorComputeEmbedding,
speakerEmbeddingExtractorCreateStream,
speakerEmbeddingExtractorDim,
speakerEmbeddingExtractorIsReady,
speakerEmbeddingManagerAdd,
speakerEmbeddingManagerAddListFlattened,
speakerEmbeddingManagerContains,
speakerEmbeddingManagerGetAllSpeakers,
speakerEmbeddingManagerNumSpeakers,
speakerEmbeddingManagerRemove,
speakerEmbeddingManagerSearch,
speakerEmbeddingManagerVerify
} from 'libsherpa_onnx.so';
import { OnlineStream } from './StreamingAsr';

export class SpeakerEmbeddingExtractorConfig {
public model: string = '';
public numThreads: number = 1;
public debug: boolean = false;
public provider: string = 'cpu';
}

export class SpeakerEmbeddingExtractor {
public config: SpeakerEmbeddingExtractorConfig = new SpeakerEmbeddingExtractorConfig();
public dim: number;
private handle: object;

constructor(config: SpeakerEmbeddingExtractorConfig, mgr?: object) {
this.handle = createSpeakerEmbeddingExtractor(config, mgr);
this.config = config;
this.dim = speakerEmbeddingExtractorDim(this.handle);
}

createStream(): OnlineStream {
return new OnlineStream(
speakerEmbeddingExtractorCreateStream(this.handle));
}

isReady(stream: OnlineStream): boolean {
return speakerEmbeddingExtractorIsReady(this.handle, stream.handle);
}

compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array {
return speakerEmbeddingExtractorComputeEmbedding(
this.handle, stream.handle, enableExternalBuffer);
}
}

function flatten(arrayList: Float32Array[]): Float32Array {
let n = 0;
for (let i = 0; i < arrayList.length; ++i) {
n += arrayList[i].length;
}
let ans = new Float32Array(n);

let offset = 0;
for (let i = 0; i < arrayList.length; ++i) {
ans.set(arrayList[i], offset);
offset += arrayList[i].length;
}
return ans;
}

interface SpeakerNameWithEmbedding {
name: string;
v: Float32Array;
}

interface SpeakerNameWithEmbeddingList {
name: string;
v: Float32Array[];
}

interface SpeakerNameWithEmbeddingN {
name: string;
vv: Float32Array;
n: number;
}

interface EmbeddingWithThreshold {
v: Float32Array;
threshold: number;
}

interface SpeakerNameEmbeddingThreshold {
name: string;
v: Float32Array;
threshold: number;
}

export class SpeakerEmbeddingManager {
public dim: number;
private handle: object;

constructor(dim: number) {
this.handle = createSpeakerEmbeddingManager(dim);
this.dim = dim;
}

add(speaker: SpeakerNameWithEmbedding): boolean {
return speakerEmbeddingManagerAdd(this.handle, speaker);
}

addMulti(speaker: SpeakerNameWithEmbeddingList): boolean {
const c: SpeakerNameWithEmbeddingN = {
name: speaker.name,
vv: flatten(speaker.v),
n: speaker.v.length,
};
return speakerEmbeddingManagerAddListFlattened(this.handle, c);
}

remove(name: string): boolean {
return speakerEmbeddingManagerRemove(this.handle, name);
}

search(obj: EmbeddingWithThreshold): string {
return speakerEmbeddingManagerSearch(this.handle, obj);
}

verify(obj: SpeakerNameEmbeddingThreshold): boolean {
return speakerEmbeddingManagerVerify(this.handle, obj);
}

contains(name: string): boolean {
return speakerEmbeddingManagerContains(this.handle, name);
}

getNumSpeakers(): number {
return speakerEmbeddingManagerNumSpeakers(this.handle);
}

getAllSpeakerNames(): string[] {
return speakerEmbeddingManagerGetAllSpeakers(this.handle);
}
}
33 changes: 31 additions & 2 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1328,8 +1328,8 @@ struct SherpaOnnxSpeakerEmbeddingExtractor {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingExtractor> impl;
};

const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
static sherpa_onnx::SpeakerEmbeddingExtractorConfig
GetSpeakerEmbeddingExtractorConfig(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
sherpa_onnx::SpeakerEmbeddingExtractorConfig c;
c.model = SHERPA_ONNX_OR(config->model, "");
Expand All @@ -1342,9 +1342,21 @@ SherpaOnnxCreateSpeakerEmbeddingExtractor(
}

if (config->debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", c.ToString().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str());
#endif
}

return c;
}

const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
auto c = GetSpeakerEmbeddingExtractorConfig(config);

if (!c.Validate()) {
SHERPA_ONNX_LOGE("Errors in config!");
return nullptr;
Expand Down Expand Up @@ -1983,6 +1995,23 @@ SherpaOnnxVoiceActivityDetector *SherpaOnnxCreateVoiceActivityDetectorOHOS(
return p;
}

const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
NativeResourceManager *mgr) {
if (!mgr) {
return SherpaOnnxCreateSpeakerEmbeddingExtractor(config);
}

auto c = GetSpeakerEmbeddingExtractorConfig(config);

auto p = new SherpaOnnxSpeakerEmbeddingExtractor;

p->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingExtractor>(mgr, c);

return p;
}

#if SHERPA_ONNX_ENABLE_TTS == 1
SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
const SherpaOnnxOfflineTtsConfig *config, NativeResourceManager *mgr) {
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,11 @@ SherpaOnnxCreateVoiceActivityDetectorOHOS(

SHERPA_ONNX_API SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
const SherpaOnnxOfflineTtsConfig *config, NativeResourceManager *mgr);

SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
NativeResourceManager *mgr);
#endif

#if defined(__GNUC__)
Expand Down
4 changes: 2 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
for (const auto &f : files) {
if (config.model.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str());
#else
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
#endif
}
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
Expand Down
16 changes: 13 additions & 3 deletions sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ class SpeakerEmbeddingExtractorGeneralImpl
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}

#if __ANDROID_API__ >= 9
template <typename Manager>
SpeakerEmbeddingExtractorGeneralImpl(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
Manager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: model_(mgr, config) {}
#endif

int32_t Dim() const override { return model_.GetMetaData().output_dim; }

Expand All @@ -46,9 +45,15 @@ class SpeakerEmbeddingExtractorGeneralImpl
std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %{public}d",
num_frames);
#else
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
#endif
return {};
}

Expand All @@ -64,8 +69,13 @@ class SpeakerEmbeddingExtractorGeneralImpl
if (meta_data.feature_normalize_type == "global-mean") {
SubtractGlobalMean(features.data(), num_frames, feat_dim);
} else {
#if __OHOS__
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %{public}s",
meta_data.feature_normalize_type.c_str());
#else
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
meta_data.feature_normalize_type.c_str());
#endif
exit(-1);
}
}
Expand Down
Loading

0 comments on commit 314545f

Please sign in to comment.