Skip to content

Commit

Permalink
Merge pull request ggerganov#33 from tazz4843/whisper-cpp-v1.3.0
Browse files Browse the repository at this point in the history
Update to whisper.cpp v1.3.0.
  • Loading branch information
tazz4843 authored Apr 18, 2023
2 parents 400d03f + bde74db commit 3e5d0f3
Show file tree
Hide file tree
Showing 13 changed files with 916 additions and 182 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# Version 0.6.0 (2023-04-17)
* Update upstream whisper.cpp to v1.3.0
* Fix breaking changes in update, which cascade to users:
* `WhisperContext`s now have a generic type parameter, which is a hashable key for a state map.
This allows for a single context to be reused for multiple different states, saving memory.
* You must create a new state upon creation, even if you are using the context only once, by calling `WhisperContext::create_key`.
* Each method that now takes a state now takes a key, which internally is used to look up the state.
* This also turns `WhisperContext` into an entirely immutable object, meaning it can be shared across threads and used concurrently, safely.
* Send feedback on these changes to the PR: https://github.com/tazz4843/whisper-rs/pull/33

# Version 0.2.0 (2022-10-28)
* Update upstream whisper.cpp to 2c281d190b7ec351b8128ba386d110f100993973.
* Fix breaking changes in update, which cascade to users:
Expand Down
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]

[package]
name = "whisper-rs"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "Rust bindings for whisper.cpp"
license = "Unlicense"
Expand All @@ -14,7 +14,8 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
whisper-rs-sys = { path = "sys", version = "0.3" }
whisper-rs-sys = { path = "sys", version = "0.4" }
dashmap = "5"

[dev-dependencies]
hound = "3.5.0"
Expand Down
24 changes: 17 additions & 7 deletions examples/audio_transcription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() -> Result<(), &'static str> {
// Load a context and model.
let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
.expect("failed to load model");
// Create a single global key.
ctx.create_key(()).expect("failed to create key");

// Create a params object for running the model.
// Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP.
// The number of past samples to consider defaults to 0.
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });

Expand Down Expand Up @@ -62,18 +63,27 @@ fn main() -> Result<(), &'static str> {
}

// Run the model.
ctx.full(params, &audio[..]).expect("failed to run model");
ctx.full(&(), params, &audio[..])
.expect("failed to run model");

// Create a file to write the transcript to.
let mut file = File::create("transcript.txt").expect("failed to create file");

// Iterate through the segments of the transcript.
let num_segments = ctx.full_n_segments();
let num_segments = ctx
.full_n_segments(&())
.expect("failed to get number of segments");
for i in 0..num_segments {
// Get the transcribed text and timestamps for the current segment.
let segment = ctx.full_get_segment_text(i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(i);
let end_timestamp = ctx.full_get_segment_t1(i);
let segment = ctx
.full_get_segment_text(&(), i)
.expect("failed to get segment");
let start_timestamp = ctx
.full_get_segment_t0(&(), i)
.expect("failed to get start timestamp");
let end_timestamp = ctx
.full_get_segment_t1(&(), i)
.expect("failed to get end timestamp");

// Print the segment to stdout.
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
Expand Down
24 changes: 18 additions & 6 deletions examples/basic_use.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
// more dependencies than the base library.
pub fn usage() -> Result<(), &'static str> {
// load a context and model
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
let ctx = WhisperContext::new("path/to/model").expect("failed to load model");
// make a sample key
// here, since we only use this model once, we use a unique global key
ctx.create_key(()).expect("failed to create key");

// create a params object
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
Expand Down Expand Up @@ -41,15 +44,24 @@ pub fn usage() -> Result<(), &'static str> {
)?;

// now we can run the model
ctx.full(params, &audio_data[..])
// note the key we use here is the one we created above
ctx.full(&(), params, &audio_data[..])
.expect("failed to run model");

// fetch the results
let num_segments = ctx.full_n_segments();
let num_segments = ctx
.full_n_segments(&())
.expect("failed to get number of segments");
for i in 0..num_segments {
let segment = ctx.full_get_segment_text(i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(i);
let end_timestamp = ctx.full_get_segment_t1(i);
let segment = ctx
.full_get_segment_text(&(), i)
.expect("failed to get segment");
let start_timestamp = ctx
.full_get_segment_t0(&(), i)
.expect("failed to get segment start timestamp");
let end_timestamp = ctx
.full_get_segment_t1(&(), i)
.expect("failed to get segment end timestamp");
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
}

Expand Down
13 changes: 7 additions & 6 deletions examples/full_usage/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ fn main() {
let original_samples = parse_wav_file(audio_path);
let samples = whisper_rs::convert_integer_to_float_audio(&original_samples);

let mut ctx =
let ctx =
WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
ctx.create_key(()).expect("failed to create key");
let params = FullParams::new(SamplingStrategy::default());

ctx.full(params, &samples)
ctx.full(&(), params, &samples)
.expect("failed to convert samples");

let num_segments = ctx.full_n_segments();
let num_segments = ctx.full_n_segments(&()).expect("failed to get number of segments");
for i in 0..num_segments {
let segment = ctx.full_get_segment_text(i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(i);
let end_timestamp = ctx.full_get_segment_t1(i);
let segment = ctx.full_get_segment_text(&(), i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(&(), i).expect("failed to get start timestamp");
let end_timestamp = ctx.full_get_segment_t1(&(), i).expect("failed to get end timestamp");
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
}
}
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ pub enum WhisperError {
GenericError(c_int),
/// Whisper failed to convert the provided text into tokens.
InvalidText,
/// Creating a state pointer failed. Check stderr for more information.
FailedToCreateState,
/// State pointer ID already exists.
StateIdAlreadyExists,
/// State pointer ID does not exist.
StateIdDoesNotExist,
}

impl From<Utf8Error> for WhisperError {
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod standalone;
mod utilities;
mod whisper_ctx;
mod whisper_params;
mod whisper_state;

pub use error::WhisperError;
pub use standalone::*;
Expand All @@ -17,3 +18,5 @@ pub type WhisperTokenData = whisper_rs_sys::whisper_token_data;
pub type WhisperToken = whisper_rs_sys::whisper_token;
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
Loading

0 comments on commit 3e5d0f3

Please sign in to comment.