Skip to content

Commit

Permalink
Fix text generation blocking the main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
jason136 committed Oct 17, 2023
1 parent 70df40f commit 5a3a0c8
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 68 deletions.
36 changes: 36 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ edition = "2021"

[target.'cfg(apple)'.dependencies]
accelerate-src = "0.3"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0", features = ["accelerate"] }

[target.'cfg(any(not(apple)))'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0"}

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0", features = ["accelerate"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }

# cudarc = { version = "0.9.14", features = ["f16"] }

intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
Expand All @@ -32,6 +33,8 @@ uuid = { version = "1.4", features = ["v4", "fast-rng"]}

anyhow = "1.0"
thiserror = "1.0"
parking_lot = "0.12"
crossbeam = "0.8"
rayon = "1.8"
once_cell = "1.18"
rand = "0.8"
Expand Down
15 changes: 6 additions & 9 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use actix_web::{
web::{Data, Json, Path},
HttpResponse,
};
use tokio::task;
use uuid::Uuid;

use crate::{
Expand Down Expand Up @@ -55,7 +54,7 @@ pub async fn prompt_streaming(
return Ok(HttpResponse::BadRequest().body("Client Not Found"));
}

task::spawn(async move {
actix_web::rt::spawn(async move {
let mut clients = state_cloned.text_streaming_controller.clients.lock().await;
if let Some(client) = clients
.get_mut(&user_id) {
Expand All @@ -70,13 +69,11 @@ pub async fn prompt_streaming(
pub async fn prompt_blob(state: Data<AppState>, body: Json<TextGenerationPrompt>) -> Response {
let id = Uuid::new_v4();

task::spawn(async move {
state
.text_blob_controller
.prompt(id, body.prompt.clone(), body.sample_len)
.await
.unwrap();
});
state
.text_blob_controller
.prompt(id, body.prompt.clone(), body.sample_len)
.await
.unwrap();

Ok(HttpResponse::Ok().body(id.to_string()))
}
Expand Down
18 changes: 6 additions & 12 deletions src/text_generation.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;

use crossbeam::channel::Sender;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::Rng;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::Sender;

use crate::token_stream::TokenOutputStream;
use crate::utils::device;

Expand Down Expand Up @@ -65,7 +63,7 @@ impl Default for TextGenerationArgs {
tokenizer_file: None,
weight_files: None,
quantized: true,
repeat_penalty: 1.1,
repeat_penalty: 1.2,
repeat_last_n: 64,
}
}
Expand Down Expand Up @@ -182,7 +180,7 @@ impl TextGeneration {
}

/// prompts an already loaded LLM and streams output mpsc Sender
pub async fn run(
pub fn run(
&mut self,
prompt: &str,
sample_len: u32,
Expand Down Expand Up @@ -232,13 +230,13 @@ impl TextGeneration {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
sender.send(t).await.unwrap();
sender.send(t).unwrap();
}
}

let gen_time = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(anyhow::Error::msg)? {
sender.send(rest).await.unwrap();
sender.send(rest).unwrap();
}

println!(
Expand All @@ -247,8 +245,4 @@ impl TextGeneration {
);
Ok(())
}
}

// https://github.com/huggingface/candle/blob/main/candle-examples/examples/mistral/main.rs
// https://github.com/huggingface/candle/tree/59ab6d7832600083a1519aa0511e9c7c832ae01c/candle-examples/examples/mistral
// https://docs.rs/hf-hub/0.3.2/hf_hub/api/sync/struct.ApiRepo.html#method.get
}
45 changes: 20 additions & 25 deletions src/text_polled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use std::{
time::{Duration, SystemTime},
};

use crossbeam::channel::{bounded, Sender, Receiver};
use futures::lock::Mutex;
use tokio::{
sync::mpsc::{channel, Receiver, Sender},
task,
time::interval,
};
Expand Down Expand Up @@ -58,55 +58,50 @@ impl TextPolledController {
}

async fn remove_expired_messages(messages: TextPolledMessages) {
let mut messages_lock = messages.lock().await;

println!("items before cleaning: {}", messages_lock.len());

messages_lock.retain(|_, message| {
messages.lock().await.retain(|_, message| {
if let Some(message) = message {
message.generated_at.elapsed().unwrap().as_secs() < 60 * 10
} else {
true
}
});

println!("items after cleaning: {}", messages_lock.len());
println!("messages: {:?}", *messages_lock);
}

pub async fn prompt(&self, id: Uuid, prompt: String, sample_len: u32) -> error::Result<()> {
self.messages.lock().await.insert(id, None);

let messages_clone = self.messages.clone();
task::spawn(async move {
let (tx, mut rx): (Sender<String>, Receiver<String>) = channel(sample_len as usize);

TextGeneration::default()
.run(&prompt, sample_len, tx)
.await
.unwrap();
let (sync_tx, sync_rx): (Sender<String>, Receiver<String>) = bounded(sample_len as usize);

let handle = task::spawn_blocking(move || {
TextGeneration::default().run(&prompt, sample_len, sync_tx).unwrap();
println!("done generating");
});

let mut text = String::new();
while let Some(token) = rx.recv().await {
text.push_str(&token);
}
task::spawn(async move {
messages_clone.lock().await.insert(id, None);
handle.await.unwrap();

let text = sync_rx.try_iter().collect();

println!("text: {:?}", text);

let generated_message = PolledMessage {
text,
generated_at: SystemTime::now(),
};

println!("messages: {:?}", messages_clone.lock().await);
println!("id: {:?}", id);
*messages_clone.lock().await.get_mut(&id).unwrap() = Some(generated_message);
println!("id: {:?}", id);
});

Ok(())
}

pub async fn get_message(&self, id: &Uuid) -> PolledMessageState {
match self.messages.lock().await.remove(id) {
Some(Some(message)) => PolledMessageState::Available(message.text.clone()),
match self.messages.lock().await.get(id) {
Some(Some(message)) => {
PolledMessageState::Available(message.text.clone())
},
Some(None) => PolledMessageState::Generating,
None => PolledMessageState::Missing,
}
Expand Down
32 changes: 13 additions & 19 deletions src/text_streaming.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use actix_web_lab::sse;
use crossbeam::channel::{Sender, Receiver, unbounded};
use futures::{future::join_all, lock::Mutex};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
sync::mpsc::{Receiver, Sender},
task,
time::interval,
};
Expand All @@ -26,7 +26,7 @@ pub struct StreamingClient {
pipe_task: Arc<task::JoinHandle<()>>,
message_history: Arc<Mutex<Vec<String>>>,
model_args: TextGenerationArgs,
model: Arc<Mutex<TextGeneration>>,
model: Arc<parking_lot::Mutex<TextGeneration>>,
}

impl Default for TextStreamingController {
Expand All @@ -45,7 +45,7 @@ impl Default for TextStreamingController {
impl TextStreamingController {
/// pings clients every 10 seconds to see if they are alive and remove them from the client list if not.
fn spawn_ping(clients: TextStreamingClients) {
task::spawn(async move {
actix_web::rt::spawn(async move {
let mut interval = interval(Duration::from_secs(10));

loop {
Expand All @@ -57,9 +57,9 @@ impl TextStreamingController {

/// removes all non-responsive clients from client list
async fn remove_stale_clients(clients: TextStreamingClients) {
let mut clients_lock = clients.lock().await;
let clients_clone = clients.lock().await.clone();

let futures = clients_lock.iter().map(|(id, client)| async {
let futures = clients_clone.iter().map(|(id, client)| async {
if client
.sse_sender
.send(sse::Event::Comment("ping".into()))
Expand All @@ -75,7 +75,7 @@ impl TextStreamingController {

let ok_client_ids: Vec<Uuid> = join_all(futures).await.into_iter().flatten().collect();

clients_lock.retain(|k, _| ok_client_ids.contains(k));
clients.lock().await.retain(|k, _| ok_client_ids.contains(k));
}
}

Expand All @@ -84,26 +84,25 @@ impl StreamingClient {
let (tx, _) = sse::channel(10);
tx.send(sse::Data::new("connected")).await?;

let message_history: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let (stream_input, mut stream_output): (Sender<String>, Receiver<String>) =
tokio::sync::mpsc::channel(10);
let message_history = Arc::new(Mutex::new(Vec::new()));
let (sync_tx, sync_rx): (Sender<String>, Receiver<String>) = unbounded();

let message_history_clone = message_history.clone();
let tx_clone = tx.clone();
let pipe_task = Arc::new({
task::spawn(async move {
while let Some(msg) = stream_output.recv().await {
while let Ok(msg) = sync_rx.recv() {
message_history_clone.lock().await.push(msg.clone());
tx_clone.send(sse::Data::new(msg)).await.unwrap();
}
})
});

let model = Arc::new(Mutex::new(TextGeneration::new(&model_args)?));
let model = Arc::new(parking_lot::Mutex::new(TextGeneration::new(&model_args)?));

Ok(StreamingClient {
sse_sender: tx.clone(),
stream_input,
stream_input: sync_tx,
pipe_task,
message_history,
model_args,
Expand All @@ -113,20 +112,15 @@ impl StreamingClient {

/// refresh model, needs to be run after every prompt
pub async fn refresh_model(&mut self) -> error::Result<()> {
self.model = Arc::new(Mutex::new(TextGeneration::new(&self.model_args)?));
self.model = Arc::new(parking_lot::Mutex::new(TextGeneration::new(&self.model_args)?));
Ok(())
}

/// prompt the underlying model with message history, piping the results to the client
pub async fn prompt(&mut self, prompt: &str, sample_len: u32) -> error::Result<()> {
let full_prompt = self.message_history.lock().await.concat() + " " + prompt;

let sender = self.stream_input.clone();
self.model
.lock()
.await
.run(&full_prompt, sample_len, sender)
.await?;
self.model.lock().run(&full_prompt, sample_len, self.stream_input.clone())?;

self.refresh_model().await
}
Expand Down

0 comments on commit 5a3a0c8

Please sign in to comment.