Support streaming batched chat completion requests #69

Merged (3 commits) on Jul 30, 2024
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -12,6 +12,7 @@ tower-http = { version = "0.5.1", features = ["cors"]}
flume = "0.10.14"
#actix-web = "4.8.0"
anyhow = "1.0.75"
rand = "0.8.5"
hyper = { version = "0.14", features = ["full"] }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
candle-examples = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
@@ -39,7 +40,6 @@ tokio = { version = "1.38.0", features = ["sync"] }
env_logger = "0.10.1"
tracing = "0.1.40"
range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }
chrono = { version = "0.4.31", features = ["clock"] }
either = { version = "1.13.0", features = ["serde"] }
dirs = "5.0.1"
kernels = {path = "./kernels", version="0.1.0"}
85 changes: 70 additions & 15 deletions README.md
@@ -17,20 +17,20 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a

Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16)
|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)|
| #2 | **Mistral** |✅|70 tks/s (7B)|
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)|
| #5 | **Yi** |✅|75 tks/s (6B)|
| #6 | **StableLM** |✅|99 tks/s (3B)|
| #7 | BigCode/StarCode |TBD|TBD|
| #8 | ChatGLM |TBD|TBD|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|
| #10 | **Google Gemma** |✅|130 tks/s (2B)|
| #11 | Blip-large (Multimodal) |TBD|TBD|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|
| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16)
|--|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 386 tks/s (7B) |
| #2 | **Mistral** |✅|70 tks/s (7B)| 291 tks/s (7B) |
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 467 tks/s (3.8B)|
| #5 | **Yi** |✅|75 tks/s (6B)| 375 tks/s (6B) |
| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|
| #7 | BigCode/StarCode |TBD|TBD|TBD |
| #8 | ChatGLM |TBD|TBD|TBD |
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|TBD |
| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |
| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |


## Demo Chat with candle-vllm (61-65 tokens/s, LLaMa3.1 8B, bf16, on A100)
@@ -131,6 +131,61 @@ print(completion.choices[0].message.content)
After the `candle-vllm` service is running, run the Python script and enjoy efficient inference with an OpenAI compatible API server!


## Batched requests

Refer to `examples/benchmark.py`, which sends 16 concurrent streaming chat requests and gathers the streamed responses:

``` python
async def benchmark():
    model = "mistral7b"
    max_tokens = 1024
    # 16 requests
    prompts = ["Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world.",
               "Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world."]

    # send 16 chat requests at the same time
    tasks: List[asyncio.Task] = []
    for i in range(len(prompts)):
        tasks.append(
            asyncio.create_task(
                chat_completion(model, max_tokens, prompts[i]))
        )

    # obtain the corresponding stream object for each request
    outputs: List[Stream[ChatCompletionChunk]] = await asyncio.gather(*tasks)

    # tasks for streaming chat responses
    tasks_stream: List[asyncio.Task] = []
    for i in range(len(outputs)):
        tasks_stream.append(
            asyncio.create_task(
                stream_response(i, outputs[i]))
        )

    # gather the response texts
    outputs: List[Tuple[int, str]] = await asyncio.gather(*tasks_stream)

    # print the results; chat completion statistics are reported by the backend server (i.e., candle-vllm)
    for idx, output in outputs:
        print("\n\n Response {}: \n\n {}".format(idx, output))


asyncio.run(benchmark())
```
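
The `chat_completion` and `stream_response` helpers used above come from the same script; a condensed sketch of them (see the full `examples/benchmark.py` further down this page) is:

```python
import asyncio
from typing import List, Tuple

import openai
from openai import Stream
from openai.types.chat import ChatCompletionChunk

openai.api_key = "EMPTY"                       # placeholder key for the local server
openai.base_url = "http://localhost:2000/v1/"  # candle-vllm's OpenAI-compatible endpoint

async def chat_completion(model, max_tokens, prompt):
    # open a streaming chat request; the call returns once the stream is established
    return openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        stream=True,
    )

async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]):
    # accumulate the streamed deltas into one response string
    result = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            result += delta
    return (response_idx, result)
```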


## Usage Help
@@ -140,7 +195,7 @@ For model-specific help, run `cargo run -- --port 2000 <MODEL_TYPE> --help`

For local model weights, run `cargo run --release -- --port 2000 --weight-path /home/llama2_7b/ llama --repeat-last-n 64`, changing the path as needed.

`MODEL_TYPE` = ["llama", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]
`MODEL_TYPE` = ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]

`WEIGHT_FILE_PATH` = Corresponding weight path for the given model type

85 changes: 85 additions & 0 deletions examples/benchmark.py
@@ -0,0 +1,85 @@
import openai
import asyncio
from openai import Stream
from openai.types.chat import ChatCompletionChunk
from typing import List, Tuple
# Run: cargo run --release -- --port 2000 --model-id <MODEL_ID> <MODEL_TYPE> --repeat-last-n 64
# MODEL_ID is the huggingface model id or local weight path
# MODEL_TYPE is one of ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"]


openai.api_key = "EMPTY"

openai.base_url = "http://localhost:2000/v1/"

async def chat_completion(model, max_tokens, prompt):
    completion = openai.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": prompt,
            },
        ],
        max_tokens=max_tokens,
        stream=True,
    )
    return completion

async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]):
    result = ""
    for o in stream:
        r = o.choices[0].delta.content
        if r is not None:
            result += r
    return (response_idx, result)

async def benchmark():
    model = "mistral7b"
    max_tokens = 1024
    # 16 requests
    prompts = ["Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world.",
               "Explain how to best learn Rust.",
               "Please talk about deep learning in 100 words.",
               "Do you know the capital city of China? Talk the details of you known.",
               "Who is the best female actor in the world? Explain why.",
               "How to dealing with depression?",
               "How to make money in short time?",
               "What is the future trend of large language model?",
               "The famous tech companies in the world."]

    # send 16 chat requests at the same time
    tasks: List[asyncio.Task] = []
    for i in range(len(prompts)):
        tasks.append(
            asyncio.create_task(
                chat_completion(model, max_tokens, prompts[i]))
        )

    # obtain the corresponding stream object for each request
    outputs: List[Stream[ChatCompletionChunk]] = await asyncio.gather(*tasks)

    # tasks for streaming chat responses
    tasks_stream: List[asyncio.Task] = []
    for i in range(len(outputs)):
        tasks_stream.append(
            asyncio.create_task(
                stream_response(i, outputs[i]))
        )

    # gather the response texts
    outputs: List[Tuple[int, str]] = await asyncio.gather(*tasks_stream)

    # print the results; chat completion statistics are reported by the backend server (i.e., candle-vllm)
    for idx, output in outputs:
        print("\n\n Response {}: \n\n {}".format(idx, output))


asyncio.run(benchmark())
4 changes: 0 additions & 4 deletions src/lib.rs
@@ -452,10 +452,6 @@ pub fn get_model_loader(
}
}

pub fn log_warning(message: &str) {
    eprintln!("Warning at {:?}: '{}'", chrono::offset::Utc::now(), message);
}

pub fn hub_load_local_safetensors(
    path: &String,
    json_file: &str,
152 changes: 152 additions & 0 deletions src/openai/logits_processor.rs
@@ -0,0 +1,152 @@
use crate::candle::D;
use crate::candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};

#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
    ArgMax,
    All { temperature: f64 },
    TopK { k: usize, temperature: f64 },
    TopP { p: f64, temperature: f64 },
    TopKThenTopP { k: usize, p: f64, temperature: f64 },
}

pub struct LogitsProcessor {
    rng: rand::rngs::StdRng,
    sampling: Sampling,
}

impl LogitsProcessor {
    pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
        let rng = rand::rngs::StdRng::seed_from_u64(seed);
        Self { rng, sampling }
    }

    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
        let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
        let sampling = match temperature {
            None => Sampling::ArgMax,
            Some(temperature) => match top_p {
                None => Sampling::All { temperature },
                Some(p) => Sampling::TopP { p, temperature },
            },
        };
        Self::from_sampling(seed, sampling)
    }

    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
        // let logits_v: Vec<f32> = logits.to_vec1()?;
        // Use gpu kernel
        let next_token = logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
        // let next_token = logits_v
        //     .iter()
        //     .enumerate()
        //     .max_by(|(_, u), (_, v)| u.total_cmp(v))
        //     .map(|(i, _)| i as u32)
        //     .unwrap();
        Ok(next_token)
    }

    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
        let next_token = distr.sample(&mut self.rng) as u32;
        Ok(next_token)
    }

    /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
    /// probability top_p. This way we never sample tokens that have very low probabilities and are
    /// less likely to go "off the rails".
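    /// For example, with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.8, the
    /// cumulative sum reaches 0.8 after the first two tokens, so the remaining tokens are
    /// clamped to zero before the multinomial draw.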
    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();

        // Sort by descending probability.
        argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));

        // Clamp smaller probabilities to zero.
        let mut cumsum = 0.;
        for index in &argsort_indices {
            if cumsum >= top_p {
                prs[*index] = 0.0;
            } else {
                cumsum += prs[*index];
            }
        }
        // Sample with clamped probabilities.
        self.sample_multinomial(prs)
    }

    // top-k sampling samples from the k tokens with the largest probabilities.
    fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_multinomial(prs)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let index = self.sample_multinomial(&prs)?;
            Ok(indices[index as usize] as u32)
        }
    }

    // top-k sampling samples from the k tokens with the largest probabilities,
    // then top-p sampling is applied within that subset.
    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_topp(prs, top_p)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let sum_p = prs.iter().sum::<f32>();
            let index = if top_p <= 0.0 || top_p >= sum_p {
                self.sample_multinomial(&prs)?
            } else {
                self.sample_topp(&mut prs, top_p)?
            };
            Ok(indices[index as usize] as u32)
        }
    }

    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
        self.sample_f(logits, |_| {})
    }
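    /// Like `sample`, but lets the caller mutate the temperature-scaled probability vector in
    /// place through `f` (for example, to apply a repetition penalty) before a token is drawn.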

    pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
        let logits = logits.to_dtype(DType::F32)?;
        let prs = |temperature: f64| -> Result<Vec<f32>> {
            let logits = (&logits / temperature)?;
            let prs = candle_nn::ops::softmax_last_dim(&logits)?;
            let mut prs = prs.to_vec1()?;
            f(&mut prs);
            Ok(prs)
        };

        let next_token = match &self.sampling {
            Sampling::ArgMax => self.sample_argmax(logits)?,
            Sampling::All { temperature } => {
                let prs = prs(*temperature)?;
                self.sample_multinomial(&prs)?
            }
            Sampling::TopP { p, temperature } => {
                let mut prs = prs(*temperature)?;
                if *p <= 0.0 || *p >= 1.0 {
                    // simply sample from the predicted probability distribution
                    self.sample_multinomial(&prs)?
                } else {
                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
                    self.sample_topp(&mut prs, *p as f32)?
                }
            }
            Sampling::TopK { k, temperature } => {
                let mut prs = prs(*temperature)?;
                self.sample_topk(&mut prs, *k)?
            }
            Sampling::TopKThenTopP { k, p, temperature } => {
                let mut prs = prs(*temperature)?;
                self.sample_topk_topp(&mut prs, *k, *p as f32)?
            }
        };
        Ok(next_token)
    }
}
1 change: 1 addition & 0 deletions src/openai/mod.rs
@@ -49,6 +49,7 @@ pub struct OpenAIServerData {
}

pub mod conversation;
pub mod logits_processor;
pub mod models;
pub mod openai_server;
pub mod pipelines;