Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
ksyeo1010 committed Nov 23, 2023
1 parent 47d4081 commit 3133c85
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ public void create(
String accessKey,
String modelPath,
boolean enableAutomaticPunctuation,
boolean enableDiarization,
Promise promise) {
try {

Leopard leopard = new Leopard.Builder()
.setAccessKey(accessKey)
.setModelPath(modelPath.isEmpty() ? null : modelPath)
.setEnableAutomaticPunctuation(enableAutomaticPunctuation)
.setEnableDiarization(enableDiarization)
.build(reactContext);
leopardPool.put(String.valueOf(System.identityHashCode(leopard)), leopard);

Expand Down
8 changes: 5 additions & 3 deletions binding/react-native/ios/Leopard.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright 2022 Picovoice Inc.
// Copyright 2022-2023 Picovoice Inc.
//
// You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE"
// file accompanying this source.
Expand All @@ -20,19 +20,21 @@ class PvLeopard: NSObject {
Leopard.setSdk(sdk: "react-native")
}

@objc(create:modelPath:enableAutomaticPunctuation:resolver:rejecter:)
@objc(create:modelPath:enableAutomaticPunctuation:enableDiarization:resolver:rejecter:)
func create(
accessKey: String,
modelPath: String,
enableAutomaticPunctuation: Bool,
enableDiarization: Bool,
resolver resolve: RCTPromiseResolveBlock,
rejecter reject: RCTPromiseRejectBlock) {

do {
let leopard = try Leopard(
accessKey: accessKey,
modelPath: modelPath,
enableAutomaticPunctuation: enableAutomaticPunctuation
enableAutomaticPunctuation: enableAutomaticPunctuation,
enableDiarization: enableDiarization
)

let handle: String = String(describing: leopard)
Expand Down
159 changes: 118 additions & 41 deletions binding/react-native/test-app/LeopardTestApp/Tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ export type Result = {
errorString?: string;
};

type TestWord = {
word: string;
start_sec: number;
end_sec: number;
confidence: number;
speaker_tag: number;
};

const levenshteinDistance = (words1: string[], words2: string[]) => {
const res = Array.from(
Array(words1.length + 1),
Expand Down Expand Up @@ -54,34 +62,34 @@ const wordErrorRate = (

const validateMetadata = (
words: LeopardWord[],
transcript: string,
audioLength: number,
expectedWords: TestWord[],
enableDiarization: boolean,
): string | null => {
const normTranscript = transcript.toUpperCase();
if (words.length !== expectedWords.length) {
return `Length ${words.length} does not match ${expectedWords.length}`;
}
for (let i = 0; i < words.length; i++) {
if (!normTranscript.includes(words[i].word.toUpperCase())) {
return `${words[i].word} is not in transcript.`;
if (words[i].word !== expectedWords[i].word) {
return `Word ${words[i].word} is not equal to ${expectedWords[i].word}`;
}
if (words[i].startSec <= 0) {
return `${words[i].word} has invalid startSec: '${words[i].startSec}'`;
if (Math.abs(words[i].startSec - expectedWords[i].start_sec) > 0.01) {
return `Start sec ${words[i].startSec} is not equal to ${expectedWords[i].start_sec}`;
}
if (words[i].startSec > words[i].endSec) {
return `${words[i].word} invalid meta: startSec '${words[i].startSec}' > endSec '${words[i].endSec}`;
if (Math.abs(words[i].endSec - expectedWords[i].end_sec) > 0.01) {
return `End sec ${words[i].endSec} is not equal to ${expectedWords[i].end_sec}`;
}
if (i < words.length - 1) {
if (words[i].endSec > words[i + 1].startSec) {
return `${words[i].word} invalid meta: endSec '${words[i].endSec}'
is greater than next word startSec
'${words[i + 1].word} - ${words[i + 1].startSec}'`;
if (Math.abs(words[i].confidence - expectedWords[i].confidence) > 0.01) {
return `Confidence ${words[i].confidence} is not equal to ${expectedWords[i].confidence}`;
}
if (enableDiarization) {
if (words[i].speakerTag !== expectedWords[i].speaker_tag) {
return `Speaker ${words[i].speakerTag} is not equal to ${expectedWords[i].speaker_tag}`;
}
} else {
if (words[i].endSec > audioLength) {
return `${words[i].word} invalid meta: endSec '${words[i].endSec}' is greater than audio length '${audioLength}'`;
if (words[i].speakerTag !== -1) {
return `Invalid speaker tag ${words[i].speakerTag}`;
}
}
if (!(words[i].confidence >= 0 || words[i].confidence <= 1)) {
return `${words[i].word} invalid meta: invalid confidence value '${words[i].confidence}'`;
}
}
return null;
};
Expand Down Expand Up @@ -150,7 +158,7 @@ async function getPcmFromFile(
}

const pcm: number[] = [];
for (let i = (headerOffset / Int16Array.BYTES_PER_ELEMENT); i < fileBytes.length; i += 2) {
for (let i = headerOffset; i < fileBytes.length; i += 2) {
pcm.push(dataView.getInt16(i, true));
}

Expand Down Expand Up @@ -208,14 +216,19 @@ async function runProcTestCase(
language: string,
audioFile: string,
expectedTranscript: string,
punctuations: string[],
expectedWords: TestWord[],
errorRate: number,
params: {
asFile?: boolean;
enablePunctuation?: boolean;
enableDiarization?: boolean;
} = {},
): Promise<Result> {
const {asFile = false, enablePunctuation = false} = params;
const {
asFile = false,
enablePunctuation = false,
enableDiarization = false,
} = params;

const result: Result = {testName: '', success: false};

Expand All @@ -228,6 +241,7 @@ async function runProcTestCase(

const leopard = await Leopard.create(TEST_ACCESS_KEY, modelPath, {
enableAutomaticPunctuation: enablePunctuation,
enableDiarization: enableDiarization
});

const pcm = await getPcmFromFile(audioPath, leopard.sampleRate);
Expand All @@ -238,24 +252,17 @@ async function runProcTestCase(

await leopard.delete();

let normalizedTranscript = expectedTranscript;
if (!enablePunctuation) {
for (const punctuation of punctuations) {
normalizedTranscript = normalizedTranscript.replace(punctuation, '');
}
}

const wer = wordErrorRate(
transcript,
normalizedTranscript,
expectedTranscript,
language === 'ja',
);
if (wer > errorRate) {
result.errorString = `Expected WER '${wer}' to be less than '${errorRate}'`;
return result;
}

const errorMessage = validateMetadata(words, transcript, pcm.length);
const errorMessage = validateMetadata(words, expectedWords, enableDiarization);
if (errorMessage) {
result.errorString = errorMessage;
return result;
Expand Down Expand Up @@ -291,28 +298,28 @@ async function initTests(): Promise<Result[]> {
return results;
}

async function paramTests(): Promise<Result[]> {
async function languageTests(): Promise<Result[]> {
const results: Result[] = [];

for (const testParam of testData.tests.parameters) {
for (const testParam of testData.tests.language_tests) {
const result = await runProcTestCase(
testParam.language,
testParam.audio_file,
testParam.transcript,
testParam.punctuations,
testParam.words,
testParam.error_rate,
);
result.testName = `Process test for '${testParam.language}'`;
logResult(result);
results.push(result);
}

for (const testParam of testData.tests.parameters) {
for (const testParam of testData.tests.language_tests) {
const result = await runProcTestCase(
testParam.language,
testParam.audio_file,
testParam.transcript,
testParam.punctuations,
testParam.transcript_with_punctuation,
testParam.words,
testParam.error_rate,
{
enablePunctuation: true,
Expand All @@ -323,12 +330,12 @@ async function paramTests(): Promise<Result[]> {
results.push(result);
}

for (const testParam of testData.tests.parameters) {
for (const testParam of testData.tests.language_tests) {
const result = await runProcTestCase(
testParam.language,
testParam.audio_file,
testParam.transcript,
testParam.punctuations,
testParam.words,
testParam.error_rate,
{
asFile: true,
Expand All @@ -339,11 +346,81 @@ async function paramTests(): Promise<Result[]> {
results.push(result);
}

for (const testParam of testData.tests.language_tests) {
const result = await runProcTestCase(
testParam.language,
testParam.audio_file,
testParam.transcript,
testParam.words,
testParam.error_rate,
{
enableDiarization: true,
},
);
result.testName = `Process test with diarization for '${testParam.language}'`;
logResult(result);
results.push(result);
}

return results;
}

async function diarizationTests(): Promise<Result[]> {
const results: Result[] = [];

for (const testParam of testData.tests.diarization_tests) {
const result: Result = {testName: '', success: false};

const { language, audio_file, words: expectedWords } = testParam;

try {
const modelPath =
language === 'en'
? getPath('model_files/leopard_params.pv')
: getPath(`model_files/leopard_params_${language}.pv`);

const leopard = await Leopard.create(TEST_ACCESS_KEY, modelPath, {
enableDiarization: true
});

const { words } = await leopard.processFile(await absolutePath('audio_samples', audio_file));
await leopard.delete();

let errorMessage: string | null = null;
if (words.length !== expectedWords.length) {
errorMessage = `Length ${words.length} does not match ${expectedWords.length}`;
}
for (let i = 0; i < words.length; i++) {
if (words[i].word !== expectedWords[i].word) {
errorMessage = `Word ${words[i].word} is not equal to ${expectedWords[i].word}`;
break;
}
if (words[i].speakerTag !== expectedWords[i].speaker_tag) {
errorMessage = `Speaker ${words[i].speakerTag} is not equal to ${expectedWords[i].speaker_tag}`;
break;
}
}

if (errorMessage) {
result.errorString = errorMessage;
} else {
result.success = true;
}
} catch (e) {
result.errorString = `Failed to process leopard with: ${e}`;
}

result.testName = `Diarization multiple speaker test for '${testParam.language}'`;
logResult(result);
results.push(result);
}

return results;
}

export async function runLeopardTests(): Promise<Result[]> {
const initResults = await initTests();
const paramResults = await paramTests();
return [...initResults, ...paramResults];
const languageResults = await languageTests();
const diarizationResults = await diarizationTests();
return [...initResults, ...languageResults, ...diarizationResults];
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ allprojects {
maven {
url("$rootDir/../node_modules/detox/Detox-android")
}
maven {
url 'https://s01.oss.sonatype.org/content/repositories/aipicovoice-1305'
}
}
}
6 changes: 3 additions & 3 deletions binding/react-native/test-app/LeopardTestApp/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1445,9 +1445,9 @@
version "2.0.0"

"@picovoice/react-native-voice-processor@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@picovoice/react-native-voice-processor/-/react-native-voice-processor-1.1.0.tgz#fe469f990c3b031411b585f28fc80fc0a9b9de41"
integrity sha512-TzIYL/34VbYhIcKH/LbivZwRpxUiKr4+XW5Gcp8vFdZKGP431SP4Lob2hLhSIu+E/yo1z4xwCELCznMpHyWNhg==
version "1.2.0"
resolved "https://registry.yarnpkg.com/@picovoice/react-native-voice-processor/-/react-native-voice-processor-1.2.0.tgz#82a98b41d9236ababe330dae873062ee0e1b24c3"
integrity sha512-zolTEo3qsqeUwY7JRslV/yhiA+oBrkeogOTxjHIEJ//yEsr7YKlI1PcqTbU5/xjmUiukh62gmwTXhosnQYdasQ==

"@react-native-community/cli-clean@^10.0.0":
version "10.1.1"
Expand Down

0 comments on commit 3133c85

Please sign in to comment.