Skip to content

Commit

Permalink
perf: improve transcribe runtime (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarnphm authored Mar 1, 2023
1 parent 0ed8a9b commit 044442d
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 45 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: wheels
on:
workflow_dispatch:
push:
branch:
- main
tags:
- "v*"
pull_request:
branch:
- main
Expand All @@ -28,7 +28,7 @@ jobs:
name: Build source distribution
runs-on: ubuntu-latest
timeout-minutes: 20
if: github.repository_owner == 'aarnphm' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') # Don't run on fork repository
if: github.repository_owner == 'aarnphm' # Don't run on fork repository
steps:
- name: Checkout
uses: actions/checkout@v3
Expand All @@ -43,6 +43,7 @@ jobs:
python -m build --sdist
- name: Upload to PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
Expand All @@ -54,7 +55,7 @@ jobs:
name: Build wheels for python${{ matrix.python }} (${{ matrix.os }})
runs-on: ${{ matrix.os }}
timeout-minutes: 20
if: github.repository_owner == 'aarnphm' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') # Don't run on fork repository
if: github.repository_owner == 'aarnphm' # Don't run on fork repository
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -85,6 +86,7 @@ jobs:
VERSION=${{ github.ref_name }}
echo "version=${VERSION:1}" >>$GITHUB_OUTPUT
- name: Publish built wheels
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ skip_glob = ["test/*", "venv/*", "lib/*", "bazel-*"]

[tool.pyright]
pythonVersion = "3.11"
exclude = ["bazel-*", "extern"]
exclude = ["bazel-*", "extern", "venv"]
typeCheckingMode = "strict"
analysis.useLibraryCodeForTypes = true
enableTypeIgnoreComments = true
16 changes: 11 additions & 5 deletions src/whispercpp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from __future__ import annotations

import typing as _t
from dataclasses import dataclass as _dataclass
import typing as t
from dataclasses import dataclass

from .utils import LazyLoader
from .utils import MODELS_URL
from .utils import download_model

if _t.TYPE_CHECKING:
if t.TYPE_CHECKING:
from . import api
else:
api = LazyLoader("api", globals(), "whispercpp.api")
del LazyLoader


@_dataclass
@dataclass
class Whisper:
def __init__(self, *args: _t.Any, **kwargs: _t.Any):
def __init__(self, *args: t.Any, **kwargs: t.Any):
raise RuntimeError(
"Using '__init__()' is not allowed. Use 'from_pretrained()' instead."
)

if t.TYPE_CHECKING:
# The following will be populated by from_pretrained.
_ref: api.WhisperPreTrainedModel
context: api.Context
params: api.Params

@classmethod
def from_pretrained(cls, model_name: str):
if model_name not in MODELS_URL:
Expand Down
1 change: 1 addition & 0 deletions src/whispercpp/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Context:
) -> int: ...
def full_get_segment_text(self, segment: int) -> str: ...
def full_n_segments(self) -> int: ...
def free(self) -> None: ...

class WhisperPreTrainedModel:
context: Context
Expand Down
54 changes: 32 additions & 22 deletions src/whispercpp/api_export.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
#include "context.h"
#include <iterator>
#include <algorithm>
#include <functional>
#include <numeric>
#include <sstream>
#include <stdio.h>

namespace whisper {

// User data passed to the new-segment callback through its `void *user_data`
// argument.
struct new_segment_callback_data {
  // Destination vector for segment text; the callback appends to it through
  // this pointer.
  std::vector<std::string>* results;
};

struct Whisper {
Context ctx;
FullParams params;
Expand All @@ -13,37 +18,42 @@ struct Whisper {
Whisper(const char *model_path)
: ctx(Context::from_file(model_path)), params(defaults()) {}

// Set default params to recommended.
static FullParams defaults() {
SamplingStrategies st = SamplingStrategies();
st.type = SamplingStrategies::GREEDY;
st.greedy = SamplingGreedy();
FullParams p = FullParams::from_sampling_strategy(st);
FullParams p = FullParams::from_sampling_strategy(
SamplingStrategies::from_strategy_type(SamplingStrategies::GREEDY));
// disable printing progress
p.set_print_progress(false);
// disable realtime print, using callback
p.set_print_realtime(false);

// invoke new_segment_callback for faster transcription.
p.set_new_segment_callback([](struct whisper_context *ctx, int n_new,
void *user_data) {
const auto &results = ((new_segment_callback_data *)user_data)->results;

const int n_segments = whisper_full_n_segments(ctx);

for (int i = n_segments - n_new; i < n_segments; i++) {
const char *text = whisper_full_get_segment_text(ctx, i);
results->push_back(text);
};
});
return p;
}

std::string transcribe(std::vector<float> data, int num_proc) {
std::vector<std::string> res;
int ret;
if (num_proc > 0) {
ret = ctx.full_parallel(params, data, num_proc);
} else {
ret = ctx.full(params, data);
}
if (ret != 0) {
std::vector<std::string> results;
new_segment_callback_data user_data = {&results};
params.set_new_segment_callback_user_data(&user_data);
if (ctx.full_parallel(params, data, num_proc) != 0) {
throw std::runtime_error("transcribe failed");
}
for (int i = 0; i < ctx.full_n_segments(); i++) {
res.push_back(ctx.full_get_segment_text(i));
}

// We are copying this in memory here, not ideal.
const char *const delim = "";
std::ostringstream imploded;
std::copy(res.begin(), res.end(),
std::ostream_iterator<std::string>(imploded, delim));
return imploded.str();
// We are allocating a new string for every element in the vector.
// This is not efficient for larger files.
return std::accumulate(results.begin(), results.end(), std::string(delim));
};
};

Expand Down
18 changes: 15 additions & 3 deletions src/whispercpp/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,12 @@ int Context::full_get_segment_t1(int segment) {
}

// Get the text of the specified segment.
std::string Context::full_get_segment_text(int segment) {
const char *Context::full_get_segment_text(int segment) {
const char *ret = whisper_full_get_segment_text(ctx, segment);
if (ret == nullptr) {
throw std::runtime_error("null pointer");
}
return std::string(ret);
return ret;
}

// Get numbers of tokens in specified segments.
Expand Down Expand Up @@ -608,6 +608,17 @@ void FullParams::set_logits_filter_callback_user_data(void *user_data) {
fp.logits_filter_callback_user_data = user_data;
}

// Build a SamplingStrategies value for the requested strategy type, passing a
// value-initialized parameter struct for that strategy to the constructor.
// Throws std::invalid_argument when `type` is neither GREEDY nor BEAM_SEARCH.
SamplingStrategies SamplingStrategies::from_strategy_type(StrategyType type) {
  if (type == GREEDY) {
    return SamplingStrategies(SamplingGreedy());
  }
  if (type == BEAM_SEARCH) {
    return SamplingStrategies(SamplingBeamSearch());
  }
  // Any other value is rejected.
  throw std::invalid_argument("Invalid strategy type");
}

void ExportContextApi(py::module &m) {
py::class_<Context>(m, "Context", "A light wrapper around whisper_context")
.def_static("from_file", &Context::from_file, "filename"_a)
Expand Down Expand Up @@ -689,7 +700,8 @@ void ExportParamsApi(py::module &m) {

py::class_<SamplingStrategies>(m, "SamplingStrategies",
"Available sampling strategy for whisper")
.def(py::init<>())
.def_static("from_strategy_type", &SamplingStrategies::from_strategy_type,
"strategy"_a)
.def_property(
"type", [](SamplingStrategies &self) { return self.type; },
[](SamplingStrategies &self, SamplingStrategies::StrategyType type) {
Expand Down
12 changes: 11 additions & 1 deletion src/whispercpp/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "pybind11/stl.h"
#include "whisper.h"
#endif
#include <stdio.h>
#include <string>
#include <vector>

Expand All @@ -30,6 +31,15 @@ struct SamplingStrategies {
struct SamplingGreedy greedy;
struct SamplingBeamSearch beam_search;
};

SamplingStrategies(SamplingGreedy greedy) : type(GREEDY) {
this->greedy = greedy;
};
SamplingStrategies(SamplingBeamSearch beam_search) : type(BEAM_SEARCH) {
this->beam_search = beam_search;
};

static SamplingStrategies from_strategy_type(StrategyType type);
};

class FullParams {
Expand Down Expand Up @@ -292,7 +302,7 @@ class Context {
int full_get_segment_t0(int segment);
int full_get_segment_t1(int segment);

std::string full_get_segment_text(int segment);
const char *full_get_segment_text(int segment);
int full_n_tokens(int segment);

std::string full_get_token_text(int segment, int token);
Expand Down
24 changes: 15 additions & 9 deletions tools/release
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
set -e

if [ "$#" -eq 1 ]; then
VERSION_STR=$1
VERSION=$1
else
echo "Must provide release version string, e.g. ./script/release.sh 1.0.5"
exit 1
fi

SEMVER_REGEX="^[vV]?(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)(\\-[0-9A-Za-z-]+(\\.[0-9A-Za-z-]+)*)?(\\+[0-9A-Za-z-]+(\\.[0-9A-Za-z-]+)*)?$"

if [[ "$VERSION_STR" =~ $SEMVER_REGEX ]]; then
echo "Releasing whispercpp version v$VERSION_STR:"
if [[ "$VERSION" =~ $SEMVER_REGEX ]]; then
echo "Releasing whispercpp version v$VERSION:"
else
echo "Warning: version $VERSION_STR must follow semantic versioning schema, ignore this for preview releases"
echo "Warning: version $VERSION must follow semantic versioning schema, ignore this for preview releases"
exit 0
fi

Expand All @@ -26,7 +26,7 @@ if [ -d "$GIT_ROOT"/dist ]; then
rm -rf "$GIT_ROOT"/build
fi

tag_name="v$VERSION_STR"
tag_name="v$VERSION"

if git rev-parse "$tag_name" > /dev/null 2>&1; then
echo "git tag '$tag_name' exist, using existing tag."
Expand All @@ -36,11 +36,17 @@ if git rev-parse "$tag_name" > /dev/null 2>&1; then
else
echo "Creating git tag '$tag_name'"

sed -i.bak "s/version =.*/version = \"${VERSION_STR}\"/g" pyproject.toml
rm pyproject.toml.bak
git add pyproject.toml && git commit --signoff -S -sv -m "release(pyproject): bump version to $VERSION_STR [generated]"
sed -i.bak "s/version =.*/version = \"${VERSION}\"/g" pyproject.toml && rm pyproject.toml.bak
git add pyproject.toml && git commit --signoff -S -sv -m "release(pyproject): bump version to $VERSION [generated]"

git tag -s "$tag_name" -m "Tag generated by tools/release, version: $VERSION_STR"
git tag -s "$tag_name" -m "Tag generated by tools/release, version: $VERSION"

git push origin "$tag_name"
fi

# Bump the tree to the next development version: take the last dot-separated
# component of VERSION, add one, and append ".dev0".
LAST_COMPONENT="${VERSION##*.}"
VERSION_PREFIX="${VERSION%.*}"
NEXT_COMPONENT="$((${LAST_COMPONENT} + 1))"
DEV_VERSION="${VERSION_PREFIX}.${NEXT_COMPONENT}.dev0"

# Write the development version into pyproject.toml and BUILD.bazel; each sed
# backup file is removed once the in-place edit succeeds.
sed -i.bak "s/version =.*/version = \"$DEV_VERSION\"/g" pyproject.toml && rm pyproject.toml.bak
sed -i.bak "s/\"\/\/conditions:default\":.*/\"\/\/conditions:default\": \"$DEV_VERSION\",/g" BUILD.bazel && rm BUILD.bazel.bak

# Commit the bump (signed, with sign-off) and push it to main.
git add pyproject.toml BUILD.bazel && git commit --signoff -S -sv -m "chore: bump development version [generated]"
git push origin main

0 comments on commit 044442d

Please sign in to comment.