Skip to content

Commit

Permalink
serve --model-path accept local path for loading the model
Browse files Browse the repository at this point in the history
  • Loading branch information
tiero committed Oct 24, 2023
1 parent 2c68e04 commit c34391d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 35 deletions.
25 changes: 0 additions & 25 deletions manifest.json

This file was deleted.

1 change: 0 additions & 1 deletion prem-registry.json

This file was deleted.

40 changes: 31 additions & 9 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::Serialize;
use serde_json::to_string;
use std::fs;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::{convert::Infallible, net::SocketAddr};
use whisper_cli::{Language, Model, Size, Whisper};

Expand All @@ -31,8 +32,11 @@ enum SubCommand {
#[command(about = "Start the transcription server.")]
Serve {
/// Port to listen on
#[clap(short, long, default_value = "3030")]
#[clap(short, long, default_value = "8000")]
port: u16,
/// Path to the Whisper model
#[clap(short, long)]
model_path: String,
},
#[command(about = "Transcribe a given audio file.")]
Transcribe(TranscribeArgs),
Expand Down Expand Up @@ -69,15 +73,29 @@ struct TranscribeArgs {
/// Program entry point: parses the command-line options and runs the selected subcommand.
async fn main() {
let opts = Opts::parse();
match opts.subcmd {
SubCommand::Serve { port } => start_server(port).await,
SubCommand::Serve { port, model_path} => {
// `model_path` arrives as a `String`; borrow it as a `Path` for `start_server`.
let model_path = Path::new(&model_path);
start_server(port, &model_path).await;
}
SubCommand::Transcribe(args) => transcribe_audio(args).await,
}
}

async fn start_server(port: u16) {
async fn start_server(port: u16, model_path: &Path) {
// load model
let whisper = Arc::new(Mutex::new(
Whisper::from_model_path(model_path, Some(Language::Auto)).await
));

let make_svc = make_service_fn(move |_conn| {
let whisper_clone = whisper.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| handle_transcription(req, whisper_clone.clone())))
}
});

// start listening
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let make_svc =
make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn(handle_transcription)) });
let server = Server::bind(&addr).serve(make_svc);

println!("🏃‍♀️ Server running at: {}", addr);
Expand All @@ -87,7 +105,10 @@ async fn start_server(port: u16) {
}

// A handler for incoming requests.
async fn handle_transcription(req: Request<Body>) -> Result<Response<Body>, Infallible> {
async fn handle_transcription(
req: Request<Body>,
whisper: Arc<Mutex<Whisper>>,
) -> Result<Response<Body>, Infallible> {
// Extract the `multipart/form-data` boundary from the headers.
let boundary = req
.headers()
Expand All @@ -108,9 +129,10 @@ async fn handle_transcription(req: Request<Body>) -> Result<Response<Body>, Infa

if let Ok(trans_req) = transcription_request {
let audio = Path::new(trans_req.as_str());
let mut whisper =
Whisper::new(Model::new(Size::TinyEnglish), Some(Language::English)).await;
let transcript = whisper.transcribe(audio, false, false).unwrap();
let transcript = {
let mut whisper_guard = whisper.lock().unwrap();
whisper_guard.transcribe(audio, false, false).unwrap()
};
println!("time: {:?}", transcript.processing_time);

let transcript_text = transcript.as_text();
Expand Down
8 changes: 8 additions & 0 deletions src/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ impl Whisper {
}
}

/// Builds a `Whisper` instance from a model file at the given path.
///
/// `model` is the path handed to `WhisperContext::new`; `lang` is stored unchanged on the
/// returned value.
///
/// # Panics
///
/// Panics if `model` is not valid UTF-8, or if `WhisperContext::new` returns an error.
pub async fn from_model_path<P: AsRef<Path>>(model: P, lang: Option<Language>) -> Self {
    // The context constructor is given a `&str`, so a non-UTF-8 path cannot be passed on;
    // fail with a message that names the actual problem instead of a bare `unwrap` panic.
    let model_path = model
        .as_ref()
        .to_str()
        .expect("Model path is not valid UTF-8.");
    let ctx = WhisperContext::new(model_path).expect("Failed to load model.");

    Self { lang, ctx }
}

pub fn transcribe<P: AsRef<Path>>(
&mut self,
audio: P,
Expand Down

0 comments on commit c34391d

Please sign in to comment.