Skip to content

Commit

Permalink
whisper.android : support benchmark for Android example. (ggerganov#542)
Browse files Browse the repository at this point in the history
* whisper.android: Support benchmark for Android example.

* whisper.android: update screenshot in README.

* update: Make text selectable for copy & paste.

* Update whisper.h to restore API name

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* whisper.android: Restore original API names.

---------

Co-authored-by: tinoue <tinoue@xevo.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
3 people authored Mar 7, 2023
1 parent 308d581 commit 343bf57
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/whisper.android/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ To use:
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
[^1]: I recommend the tiny or base models for running on an Android device.

<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg">
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.whispercppdemo.ui.main

import androidx.compose.foundation.layout.*
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.text.selection.SelectionContainer
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.*
import androidx.compose.runtime.Composable
Expand All @@ -19,6 +20,7 @@ fun MainScreen(viewModel: MainScreenViewModel) {
canTranscribe = viewModel.canTranscribe,
isRecording = viewModel.isRecording,
messageLog = viewModel.dataLog,
onBenchmarkTapped = viewModel::benchmark,
onTranscribeSampleTapped = viewModel::transcribeSample,
onRecordTapped = viewModel::toggleRecord
)
Expand All @@ -30,6 +32,7 @@ private fun MainScreen(
canTranscribe: Boolean,
isRecording: Boolean,
messageLog: String,
onBenchmarkTapped: () -> Unit,
onTranscribeSampleTapped: () -> Unit,
onRecordTapped: () -> Unit
) {
Expand All @@ -45,8 +48,11 @@ private fun MainScreen(
.padding(innerPadding)
.padding(16.dp)
) {
Row(horizontalArrangement = Arrangement.SpaceBetween) {
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
Column(verticalArrangement = Arrangement.SpaceBetween) {
Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) {
BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
}
RecordButton(
enabled = canTranscribe,
isRecording = isRecording,
Expand All @@ -60,7 +66,16 @@ private fun MainScreen(

/**
 * Shows [log] as vertically scrollable text.
 *
 * The text is wrapped in a [SelectionContainer] so it can be selected for copy & paste.
 */
@Composable
private fun MessageLog(log: String) {
    // Render the log exactly once, inside the selection container.
    SelectionContainer {
        Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
    }
}

/**
 * Button labelled "Benchmark" that invokes [onClick] when tapped.
 *
 * @param enabled whether the button accepts input.
 * @param onClick callback passed straight through to the underlying [Button].
 */
@Composable
private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
    Button(
        enabled = enabled,
        onClick = onClick,
    ) {
        Text(text = "Benchmark")
    }
}

@Composable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,15 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {

init {
viewModelScope.launch {
printSystemInfo()
loadData()
}
}

/** Prints the string returned by [WhisperContext.getSystemInfo], prefixed with "System Info: ". */
private suspend fun printSystemInfo() {
    printMessage("System Info: ${WhisperContext.getSystemInfo()}\n")
}

private suspend fun loadData() {
printMessage("Loading data...\n")
try {
Expand Down Expand Up @@ -81,10 +86,29 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
}

/** Launches [runBenchmark] with 6 threads in [viewModelScope] and returns the launched job. */
fun benchmark() = viewModelScope.launch {
    runBenchmark(nthreads = 6)
}

/** Launches a coroutine in [viewModelScope] that transcribes the file returned by [getFirstSample]. */
fun transcribeSample() = viewModelScope.launch {
    val sample = getFirstSample()
    transcribeAudio(sample)
}

/**
 * Runs the memory benchmark followed by the ggml mul-mat benchmark and prints each report.
 *
 * Returns immediately when [canTranscribe] is false. While the benchmarks run,
 * [canTranscribe] is held at false; it is restored in a `finally` block so that a
 * failing or cancelled benchmark cannot leave it stuck at false.
 *
 * @param nthreads thread count forwarded to both benchmark calls.
 */
private suspend fun runBenchmark(nthreads: Int) {
    if (!canTranscribe) {
        return
    }

    canTranscribe = false
    try {
        printMessage("Running benchmark. This will take minutes...\n")
        whisperContext?.benchMemory(nthreads)?.let { printMessage(it) }
        printMessage("\n")
        whisperContext?.benchGgmlMulMat(nthreads)?.let { printMessage(it) }
    } finally {
        // Always re-enable, even when a benchmark call throws or the coroutine is cancelled.
        canTranscribe = true
    }
}

/**
 * Returns the first entry of the `samplesPath` listing, resolved on [Dispatchers.IO].
 *
 * @throws IllegalStateException if `samplesPath` cannot be listed or contains no entries.
 */
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
    // listFiles() yields null when the path is not a listable directory.
    val samples = checkNotNull(samplesPath.listFiles()) { "Cannot list samples in $samplesPath" }
    samples.firstOrNull() ?: error("No sample files found in $samplesPath")
}
Expand Down Expand Up @@ -114,11 +138,14 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
canTranscribe = false

try {
printMessage("Reading wave samples...\n")
printMessage("Reading wave samples... ")
val data = readAudioSamples(file)
printMessage("${data.size / (16000 / 1000)} ms\n")
printMessage("Transcribing data...\n")
val start = System.currentTimeMillis()
val text = whisperContext?.transcribeData(data)
printMessage("Done: $text\n")
val elapsed = System.currentTimeMillis() - start
printMessage("Done ($elapsed ms): $text\n")
} catch (e: Exception) {
Log.w(LOG_TAG, e)
printMessage("${e.localizedMessage}\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class WhisperContext private constructor(private var ptr: Long) {
}
}

/**
 * Calls [WhisperLib.benchMemcpy] on this context's `scope` coroutine context.
 *
 * @param nthreads thread count forwarded to the native call.
 * @return the string produced by the native call.
 */
suspend fun benchMemory(nthreads: Int): String =
    withContext(scope.coroutineContext) { WhisperLib.benchMemcpy(nthreads) }

/**
 * Calls [WhisperLib.benchGgmlMulMat] on this context's `scope` coroutine context.
 *
 * @param nthreads thread count forwarded to the native call.
 * @return the string produced by the native call.
 */
suspend fun benchGgmlMulMat(nthreads: Int): String =
    withContext(scope.coroutineContext) { WhisperLib.benchGgmlMulMat(nthreads) }

suspend fun release() = withContext(scope.coroutineContext) {
if (ptr != 0L) {
WhisperLib.freeContext(ptr)
Expand Down Expand Up @@ -66,6 +74,10 @@ class WhisperContext private constructor(private var ptr: Long) {
}
return WhisperContext(ptr)
}

/** Returns the system information string reported by the native [WhisperLib.getSystemInfo]. */
fun getSystemInfo(): String = WhisperLib.getSystemInfo()
}
}

Expand Down Expand Up @@ -117,6 +129,9 @@ private class WhisperLib {
// Native entry points operating on a context pointer.
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String
// Native entry points that take no context pointer; each returns a String built on the native side.
// Their JNI implementations are the Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_* functions in jni.c.
external fun getSystemInfo(): String
external fun benchMemcpy(nthread: Int): String
external fun benchGgmlMulMat(nthread: Int): String
}
}

Expand Down
29 changes: 28 additions & 1 deletion examples/whisper.android/app/src/main/jni/whisper/jni.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sys/sysinfo.h>
#include <string.h>
#include "whisper.h"
#include "ggml.h"

#define UNUSED(x) (void)(x)
#define TAG "JNI"
Expand Down Expand Up @@ -213,4 +214,30 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
const char *text = whisper_full_get_segment_text(context, index);
jstring string = (*env)->NewStringUTF(env, text);
return string;
}
}

/* JNI binding for WhisperLib.getSystemInfo(): wraps whisper_print_system_info() in a Java string. */
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
        JNIEnv *env, jobject thiz
) {
    UNUSED(thiz);
    const char *info = whisper_print_system_info();
    return (*env)->NewStringUTF(env, info);
}

/*
 * JNI binding for WhisperLib.benchMemcpy(): runs whisper_bench_memcpy_str() with
 * n_threads and returns its report as a Java string.
 */
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
                                                                      jint n_threads) {
    UNUSED(thiz);
    const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
    jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
    /* The function is declared to return jstring; without this the caller reads an undefined value. */
    return string;
}

/*
 * JNI binding for WhisperLib.benchGgmlMulMat(): runs whisper_bench_ggml_mul_mat_str()
 * with n_threads and returns its report as a Java string.
 */
JNIEXPORT jstring JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
                                                                          jint n_threads) {
    UNUSED(thiz);
    const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
    jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
    /* The function is declared to return jstring; without this the caller reads an undefined value. */
    return string;
}
31 changes: 26 additions & 5 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4551,6 +4551,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
//

// Runs the memcpy benchmark via whisper_bench_memcpy_str() and writes its report to stderr.
// Always returns 0.
WHISPER_API int whisper_bench_memcpy(int n_threads) {
    const char * report = whisper_bench_memcpy_str(n_threads);
    fputs(report, stderr);
    return 0;
}

WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
static std::string s;
s = "";
char strbuf[256];

ggml_time_init();

size_t n = 50;
Expand Down Expand Up @@ -4580,24 +4589,35 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
src[0] = rand();
}

fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
s += strbuf;

// needed to prevent the compile from optimizing the memcpy away
{
double sum = 0.0;

for (size_t i = 0; i < size; i++) sum += dst[i];

fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
s += strbuf;
}

free(src);
free(dst);

return 0;
return s.c_str();
}

// Runs the ggml mul-mat benchmark via whisper_bench_ggml_mul_mat_str() and writes its report
// to stderr. Always returns 0.
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
    const char * report = whisper_bench_ggml_mul_mat_str(n_threads);
    fputs(report, stderr);
    return 0;
}

WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
static std::string s;
s = "";
char strbuf[256];

ggml_time_init();

const int n_max = 128;
Expand Down Expand Up @@ -4673,11 +4693,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
s = ((2.0*N*N*N*n)/tsum)*1e-9;
}

fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
N, N, s_fp16, n_fp16, s_fp32, n_fp32);
s += strbuf;
}

return 0;
return s.c_str();
}

// =================================================================================================
Expand Down
2 changes: 2 additions & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,9 @@ extern "C" {
// Temporary helpers needed for exposing ggml interface

// Runs the memcpy benchmark and writes the report to stderr; always returns 0.
WHISPER_API int whisper_bench_memcpy(int n_threads);
// Same benchmark, but returns the report as a string. The pointer refers to a
// function-local static std::string that is overwritten by the next call.
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
// Runs the ggml mul-mat benchmark and writes the report to stderr; always returns 0.
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
// Same benchmark, but returns the report as a string. The pointer refers to a
// function-local static std::string that is overwritten by the next call.
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);

#ifdef __cplusplus
}
Expand Down

0 comments on commit 343bf57

Please sign in to comment.