From 5a3a0c8929163a34577080dcc19c898a932dd0e7 Mon Sep 17 00:00:00 2001 From: jason136 Date: Tue, 17 Oct 2023 14:17:00 -0700 Subject: [PATCH] text gen blocking main thread fix --- Cargo.lock | 36 +++++++++++++++++++++++++++++++++ Cargo.toml | 9 ++++++--- src/controller.rs | 15 ++++++-------- src/text_generation.rs | 18 ++++++----------- src/text_polled.rs | 45 +++++++++++++++++++----------------------- src/text_streaming.rs | 32 ++++++++++++------------------ 6 files changed, 87 insertions(+), 68 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 46e8c65..357408d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -443,12 +443,14 @@ dependencies = [ "candle-core", "candle-nn", "candle-transformers", + "crossbeam", "dotenvy", "env_logger", "futures", "hf-hub", "intel-mkl-src", "once_cell", + "parking_lot", "rand", "rayon", "serde", @@ -832,6 +834,30 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.3" @@ -856,6 +882,16 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.16" diff --git a/Cargo.toml b/Cargo.toml index 25e48d3..42897a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,14 +7,15 @@ edition = "2021" [target.'cfg(apple)'.dependencies] accelerate-src = "0.3" +candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0", features = ["accelerate"] } + +[target.'cfg(any(not(apple)))'.dependencies] +candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0"} [dependencies] -candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0", features = ["accelerate"] } candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" } candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" } -# cudarc = { version = "0.9.14", features = ["f16"] } - intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } tracing-chrome = "0.7" tracing-subscriber = "0.3" @@ -32,6 +33,8 @@ uuid = { version = "1.4", features = ["v4", "fast-rng"]} anyhow = "1.0" thiserror = "1.0" +parking_lot = "0.12" +crossbeam = "0.8" rayon = "1.8" once_cell = "1.18" rand = "0.8" diff --git a/src/controller.rs b/src/controller.rs index 39dbbac..ea56690 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -3,7 +3,6 @@ use actix_web::{ web::{Data, Json, Path}, HttpResponse, }; -use tokio::task; use uuid::Uuid; use crate::{ @@ -55,7 +54,7 @@ pub async fn prompt_streaming( return Ok(HttpResponse::BadRequest().body("Client Not Found")); } - task::spawn(async move { + actix_web::rt::spawn(async move { let mut clients = state_cloned.text_streaming_controller.clients.lock().await; if let Some(client) = clients .get_mut(&user_id) { @@ -70,13 +69,11 @@ pub async fn prompt_streaming( pub async fn prompt_blob(state: Data, body: Json) -> Response { let id = Uuid::new_v4(); - task::spawn(async move { - state - .text_blob_controller - .prompt(id, body.prompt.clone(), body.sample_len) - .await - .unwrap(); - }); + state + .text_blob_controller + .prompt(id, body.prompt.clone(), body.sample_len) + .await + .unwrap(); Ok(HttpResponse::Ok().body(id.to_string())) } diff --git a/src/text_generation.rs b/src/text_generation.rs index e95c8f5..a64ba82 100644 --- a/src/text_generation.rs +++ b/src/text_generation.rs @@ -1,16 +1,14 @@ use candle_transformers::models::mistral::{Config, Model as Mistral}; use candle_transformers::models::quantized_mistral::Model as QMistral; - use candle_core::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; +use crossbeam::channel::Sender; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::Rng; use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer; -use tokio::sync::mpsc::Sender; - use crate::token_stream::TokenOutputStream; use crate::utils::device; @@ -65,7 +63,7 @@ impl Default for TextGenerationArgs { tokenizer_file: None, weight_files: None, quantized: true, - repeat_penalty: 1.1, + repeat_penalty: 1.2, repeat_last_n: 64, } } @@ -182,7 +180,7 @@ impl TextGeneration { } /// prompts an already loaded LLM and streams output mpsc Sender - pub async fn run( + pub fn run( &mut self, prompt: &str, sample_len: u32, @@ -232,13 +230,13 @@ impl TextGeneration { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { - sender.send(t).await.unwrap(); + sender.send(t).unwrap(); } } let gen_time = start_gen.elapsed(); if let Some(rest) = self.tokenizer.decode_rest().map_err(anyhow::Error::msg)? { - sender.send(rest).await.unwrap(); + sender.send(rest).unwrap(); } println!( @@ -247,8 +245,4 @@ impl TextGeneration { ); Ok(()) } -} - -// https://github.com/huggingface/candle/blob/main/candle-examples/examples/mistral/main.rs -// https://github.com/huggingface/candle/tree/59ab6d7832600083a1519aa0511e9c7c832ae01c/candle-examples/examples/mistral -// https://docs.rs/hf-hub/0.3.2/hf_hub/api/sync/struct.ApiRepo.html#method.get +} \ No newline at end of file diff --git a/src/text_polled.rs b/src/text_polled.rs index df3726d..3ab1975 100644 --- a/src/text_polled.rs +++ b/src/text_polled.rs @@ -4,9 +4,9 @@ use std::{ time::{Duration, SystemTime}, }; +use crossbeam::channel::{bounded, Sender, Receiver}; use futures::lock::Mutex; use tokio::{ - sync::mpsc::{channel, Receiver, Sender}, task, time::interval, }; @@ -58,55 +58,50 @@ impl TextPolledController { } async fn remove_expired_messages(messages: TextPolledMessages) { - let mut messages_lock = messages.lock().await; - - println!("items before cleaning: {}", messages_lock.len()); - - messages_lock.retain(|_, message| { + messages.lock().await.retain(|_, message| { if let Some(message) = message { message.generated_at.elapsed().unwrap().as_secs() < 60 * 10 } else { true } }); - - println!("items after cleaning: {}", messages_lock.len()); - println!("messages: {:?}", *messages_lock); } pub async fn prompt(&self, id: Uuid, prompt: String, sample_len: u32) -> error::Result<()> { - self.messages.lock().await.insert(id, None); - let messages_clone = self.messages.clone(); - task::spawn(async move { - let (tx, mut rx): (Sender, Receiver) = channel(sample_len as usize); - TextGeneration::default() - .run(&prompt, sample_len, tx) - .await - .unwrap(); + let (sync_tx, sync_rx): (Sender, Receiver) = bounded(sample_len as usize); + + let handle = task::spawn_blocking(move || { + TextGeneration::default().run(&prompt, sample_len, sync_tx).unwrap(); + println!("done generating"); + }); - let mut text = String::new(); - while let Some(token) = rx.recv().await { - text.push_str(&token); - } + task::spawn(async move { + messages_clone.lock().await.insert(id, None); + handle.await.unwrap(); + + let text = sync_rx.try_iter().collect(); + println!("text: {:?}", text); + let generated_message = PolledMessage { text, generated_at: SystemTime::now(), }; - println!("messages: {:?}", messages_clone.lock().await); - println!("id: {:?}", id); *messages_clone.lock().await.get_mut(&id).unwrap() = Some(generated_message); + println!("id: {:?}", id); }); Ok(()) } pub async fn get_message(&self, id: &Uuid) -> PolledMessageState { - match self.messages.lock().await.remove(id) { - Some(Some(message)) => PolledMessageState::Available(message.text.clone()), + match self.messages.lock().await.get(id) { + Some(Some(message)) => { + PolledMessageState::Available(message.text.clone()) + }, Some(None) => PolledMessageState::Generating, None => PolledMessageState::Missing, } diff --git a/src/text_streaming.rs b/src/text_streaming.rs index ffe704c..4e40eb5 100644 --- a/src/text_streaming.rs +++ b/src/text_streaming.rs @@ -1,8 +1,8 @@ use actix_web_lab::sse; +use crossbeam::channel::{Sender, Receiver, unbounded}; use futures::{future::join_all, lock::Mutex}; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{ - sync::mpsc::{Receiver, Sender}, task, time::interval, }; @@ -26,7 +26,7 @@ pub struct StreamingClient { pipe_task: Arc>, message_history: Arc>>, model_args: TextGenerationArgs, - model: Arc>, + model: Arc>, } impl Default for TextStreamingController { @@ -45,7 +45,7 @@ impl Default for TextStreamingController { impl TextStreamingController { /// pings clients every 10 seconds to see if they are alive and remove them from the client list if not. fn spawn_ping(clients: TextStreamingClients) { - task::spawn(async move { + actix_web::rt::spawn(async move { let mut interval = interval(Duration::from_secs(10)); loop { @@ -57,9 +57,9 @@ impl TextStreamingController { /// removes all non-responsive clients from client list async fn remove_stale_clients(clients: TextStreamingClients) { - let mut clients_lock = clients.lock().await; + let clients_clone = clients.lock().await.clone(); - let futures = clients_lock.iter().map(|(id, client)| async { + let futures = clients_clone.iter().map(|(id, client)| async { if client .sse_sender .send(sse::Event::Comment("ping".into())) @@ -75,7 +75,7 @@ impl TextStreamingController { let ok_client_ids: Vec = join_all(futures).await.into_iter().flatten().collect(); - clients_lock.retain(|k, _| ok_client_ids.contains(k)); + clients.lock().await.retain(|k, _| ok_client_ids.contains(k)); } } @@ -84,26 +84,25 @@ impl StreamingClient { let (tx, _) = sse::channel(10); tx.send(sse::Data::new("connected")).await?; - let message_history: Arc>> = Arc::new(Mutex::new(Vec::new())); - let (stream_input, mut stream_output): (Sender, Receiver) = - tokio::sync::mpsc::channel(10); + let message_history = Arc::new(Mutex::new(Vec::new())); + let (sync_tx, sync_rx): (Sender, Receiver) = unbounded(); let message_history_clone = message_history.clone(); let tx_clone = tx.clone(); let pipe_task = Arc::new({ task::spawn(async move { - while let Some(msg) = stream_output.recv().await { + while let Ok(msg) = sync_rx.recv() { message_history_clone.lock().await.push(msg.clone()); tx_clone.send(sse::Data::new(msg)).await.unwrap(); } }) }); - let model = Arc::new(Mutex::new(TextGeneration::new(&model_args)?)); + let model = Arc::new(parking_lot::Mutex::new(TextGeneration::new(&model_args)?)); Ok(StreamingClient { sse_sender: tx.clone(), - stream_input, + stream_input: sync_tx, pipe_task, message_history, model_args, @@ -113,7 +112,7 @@ impl StreamingClient { /// refresh model, needs to be run after every prompt pub async fn refresh_model(&mut self) -> error::Result<()> { - self.model = Arc::new(Mutex::new(TextGeneration::new(&self.model_args)?)); + self.model = Arc::new(parking_lot::Mutex::new(TextGeneration::new(&self.model_args)?)); Ok(()) } @@ -121,12 +120,7 @@ impl StreamingClient { pub async fn prompt(&mut self, prompt: &str, sample_len: u32) -> error::Result<()> { let full_prompt = self.message_history.lock().await.concat() + " " + prompt; - let sender = self.stream_input.clone(); - self.model - .lock() - .await - .run(&full_prompt, sample_len, sender) - .await?; + self.model.lock().run(&full_prompt, sample_len, self.stream_input.clone())?; self.refresh_model().await }