LLaMa3.1 chat completion #67

Merged · 2 commits · Jul 26, 2024
7 changes: 6 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, BF16)
|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3** |✅|74 tks/s (7B)|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)|
| #2 | **Mistral** |✅|70 tks/s (7B)|
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)|
@@ -55,6 +55,11 @@ You may also run specific model using huggingface model-id, e.g.,
cargo run --release -- --port 2000 --model-id meta-llama/Llama-2-7b-chat-hf llama --repeat-last-n 64
```

Run the latest LLaMa3.1 model using local weights:

```
cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --repeat-last-n 64
```
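Once the server is up, any OpenAI-style client can talk to it. Below is a minimal Rust sketch of a chat-completion request; the endpoint path, port, and the `reqwest`/`serde_json` dependencies are assumptions for illustration, not something this PR adds.

```rust
// Hypothetical client sketch: POST a chat completion to the locally running server.
// Assumes the `reqwest` crate (with the `blocking` and `json` features) and `serde_json`,
// and that the server exposes the usual OpenAI-compatible path on port 2000.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "model": "llama3",
        "messages": [{ "role": "user", "content": "Explain paged attention in one sentence." }],
        "max_tokens": 128
    });
    let resp = reqwest::blocking::Client::new()
        .post("http://localhost:2000/v1/chat/completions")
        .json(&body)
        .send()?;
    println!("{}", resp.text()?);
    Ok(())
}
```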
### Step 2:

#### Option 1: Chat with ChatUI (recommended)
45 changes: 45 additions & 0 deletions src/lib.rs
@@ -25,6 +25,22 @@ pub enum ModelSelected {
max_gen_tokens: Option<usize>,
},

/// Select the llama3 model (default llama3.1-8b).
Llama3 {
/// Control the application of repeat penalty for the last n tokens
#[arg(long)]
repeat_last_n: Option<usize>,

#[arg(long)]
temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,

#[arg(long)]
max_gen_tokens: Option<usize>,
},

/// Select the phi2 model (default 2.7b).
Phi2 {
/// Control the application of repeat penalty for the last n tokens
@@ -159,6 +175,12 @@ impl ToString for ModelSelected {
penalty: _,
max_gen_tokens: _,
} => "llama".to_string(),
ModelSelected::Llama3 {
repeat_last_n: _,
temperature: _,
penalty: _,
max_gen_tokens: _,
} => "llama3".to_string(),
ModelSelected::Phi2 {
repeat_last_n: _,
temperature: _,
@@ -237,6 +259,29 @@ pub fn get_model_loader(
"meta-llama/Llama-2-7b-chat-hf".to_string()
},
),
ModelSelected::Llama3 {
repeat_last_n,
temperature,
penalty,
max_gen_tokens,
} => (
Box::new(DefaultLoader::new(
SpecificConfig::new(
repeat_last_n,
temperature,
None,
None,
penalty,
max_gen_tokens,
),
"llama3".to_string(),
)),
if model_id.is_some() {
model_id.unwrap()
} else {
"meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()
},
),
ModelSelected::Phi2 {
repeat_last_n,
temperature,
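The new `Llama3` arm in `get_model_loader` mirrors the existing `Llama` arm: it builds a `DefaultLoader` registered under the name "llama3" and falls back to the `meta-llama/Meta-Llama-3.1-8B-Instruct` Hub id when no `--model-id` is supplied. A standalone sketch of that fallback (a hypothetical helper, not part of the PR) shows the same logic written with `unwrap_or_else`:

```rust
// Hypothetical helper illustrating the default-model-id fallback used above.
fn resolve_model_id(model_id: Option<String>) -> String {
    model_id.unwrap_or_else(|| "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string())
}

fn main() {
    assert_eq!(
        resolve_model_id(None),
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    );
    assert_eq!(resolve_model_id(Some("my-org/my-llama".into())), "my-org/my-llama");
}
```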
27 changes: 27 additions & 0 deletions src/openai/conversation/default_conversation.rs
@@ -17,6 +17,7 @@ pub enum SeparatorStyle {
NoColonTwo,
AddNewLineSingle,
Llama,
Llama3,
Phi,
Qwen2,
Gemma,
@@ -248,6 +249,32 @@ impl Conversation for DefaultConversation {
accum
}

SeparatorStyle::Llama3 => {
let mut accum = "<|begin_of_text|>".to_string();
for (i, message) in self.messages.iter().enumerate() {
let Message((_role, message)) = message;
if _role.clone() == self.roles.0 {
//user message
if let Some(message) = message {
accum += &format!(
"<|start_header_id|>user<|end_header_id|>\n\n {message} <|eot_id|>"
);
} else {
accum +=
&format!("<|start_header_id|>user<|end_header_id|>\n\n <|eot_id|>");
}
} else if _role.clone() == self.roles.1 {
//assistant message
if let Some(message) = message {
accum += &format!("<|start_header_id|>assistant<|end_header_id|>\n\n {message} <|eot_id|>");
}
} else if i == 0 && !system_prompt.is_empty() {
accum += &system_prompt;
}
}
accum
}

SeparatorStyle::Phi => {
let mut accum = "".to_string();
for (i, message) in self.messages.iter().enumerate() {
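For reference, the new `SeparatorStyle::Llama3` arm wraps each turn in Llama 3 header and end-of-turn tokens. A standalone sketch (illustrative only, reproducing the format strings above for a single user turn with an empty system prompt):

```rust
// Illustrative only: mirrors the token layout built by the Llama3 arm above
// for one user message and an empty system prompt.
fn llama3_prompt(user_msg: &str) -> String {
    let mut accum = "<|begin_of_text|>".to_string();
    accum += &format!("<|start_header_id|>user<|end_header_id|>\n\n {user_msg} <|eot_id|>");
    accum
}

fn main() {
    assert_eq!(
        llama3_prompt("Hello"),
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n Hello <|eot_id|>"
    );
}
```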
6 changes: 3 additions & 3 deletions src/openai/models/gemma.rs
@@ -5,9 +5,9 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_core as candle;
use candle_nn::Activation;
use candle_nn::{linear_b, linear_no_bias as linear, Linear, RmsNorm, VarBuilder};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

#[derive(serde::Deserialize, Debug, Clone)]
pub struct GemmaConfig {
pub attention_bias: bool,
@@ -45,8 +45,8 @@ impl GemmaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings.unwrap_or(4096),
sliding_window: None,
hidden_act: hidden_act,
8 changes: 5 additions & 3 deletions src/openai/models/llama.rs
@@ -6,6 +6,7 @@ use candle_core as candle;
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
pub const MAX_SEQ_LEN: usize = 4096;
use crate::openai::models::TokenID;
use std::iter::zip;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
@@ -18,8 +19,9 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub bos_token_id: TokenID,
pub eos_token_id: TokenID,
pub max_position_embeddings: Option<usize>,
}

fn default_rope() -> f32 {
@@ -40,7 +42,7 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
max_seq_len: MAX_SEQ_LEN,
max_seq_len: self.max_position_embeddings.unwrap_or(MAX_SEQ_LEN),
sliding_window: None,
hidden_act: None,
tie_word_embeddings: false,
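With `max_position_embeddings` now threaded through `LlamaConfig`, `max_seq_len` follows the checkpoint's declared context length (131072 for Llama 3.1) instead of the previous hard 4096 cap. A minimal sketch of the derivation, under the same fallback:

```rust
// Sketch of the new max_seq_len derivation: prefer the value declared in
// config.json, fall back to the old 4096 cap when the field is absent.
const MAX_SEQ_LEN: usize = 4096;

fn max_seq_len(max_position_embeddings: Option<usize>) -> usize {
    max_position_embeddings.unwrap_or(MAX_SEQ_LEN)
}

fn main() {
    assert_eq!(max_seq_len(Some(131_072)), 131_072); // Llama 3.1 style config
    assert_eq!(max_seq_len(None), 4096);             // older configs without the field
}
```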
5 changes: 3 additions & 2 deletions src/openai/models/mistral.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -37,8 +38,8 @@ impl MistralConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
9 changes: 7 additions & 2 deletions src/openai/models/mod.rs
@@ -14,6 +14,11 @@ use std::collections::HashMap;
#[derive(Deserialize, Debug, Clone)]
pub struct RopeScaling(#[serde(with = "either::serde_untagged")] pub Either<Vec<f64>, String>);

#[derive(Deserialize, Debug, Clone)]
pub struct TokenID(
#[serde(with = "either::serde_untagged")] pub Either<Option<u32>, Option<Vec<u32>>>,
);

#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
@@ -25,8 +30,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub bos_token_id: TokenID,
pub eos_token_id: TokenID,
pub max_seq_len: usize,
pub sliding_window: Option<usize>,
pub hidden_act: Option<candle_nn::Activation>,
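The untagged `TokenID` wrapper is what makes Llama 3.1 configs parse: `eos_token_id` may be a single integer (Llama 2 style) or a list of ids (Llama 3.1 style), and `either::serde_untagged` accepts both. A standalone sketch mirroring the type added above (the concrete id values are illustrative):

```rust
// Standalone sketch of the TokenID deserialization added above: a single id or a
// list of ids are both accepted from config.json. Requires the `either` crate with
// its `serde` feature, plus `serde` and `serde_json`.
use either::Either;
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct TokenID(#[serde(with = "either::serde_untagged")] Either<Option<u32>, Option<Vec<u32>>>);

fn main() -> Result<(), serde_json::Error> {
    let single: TokenID = serde_json::from_str("2")?;                      // Llama 2 style
    let list: TokenID = serde_json::from_str("[128001, 128008, 128009]")?; // Llama 3.1 style
    println!("{single:?} {list:?}");
    Ok(())
}
```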
5 changes: 3 additions & 2 deletions src/openai/models/phi2.rs
@@ -6,6 +6,7 @@ use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{
layer_norm, linear_no_bias as linear, Embedding, LayerNorm, Linear,
};
use either::Either;
use serde::Deserialize;
use std::iter::zip;

@@ -41,8 +42,8 @@ impl Phi2Config {
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
rms_norm_eps: self.layer_norm_eps,
rope_theta: self.rope_theta,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
bos_token_id: super::TokenID(Either::Left(self.bos_token_id)),
eos_token_id: super::TokenID(Either::Left(self.eos_token_id)),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
4 changes: 2 additions & 2 deletions src/openai/models/phi3.rs
@@ -43,8 +43,8 @@ impl PhiConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
bos_token_id: super::TokenID(Either::Left(self.bos_token_id)),
eos_token_id: super::TokenID(Either::Left(self.eos_token_id)),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/qwen2.rs
@@ -5,6 +5,7 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -40,8 +41,8 @@ impl QwenConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: Some(self.sliding_window),
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/stable_lm.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -41,8 +42,8 @@ impl StableLMConfig {
rms_norm_eps: self.norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
5 changes: 3 additions & 2 deletions src/openai/models/yi.rs
@@ -4,6 +4,7 @@ use crate::paged_attention::PagedAttention;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
use either::Either;
use std::iter::zip;
use std::sync::Arc;

@@ -37,8 +38,8 @@ impl YiConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))),
eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))),
max_seq_len: self.max_position_embeddings,
sliding_window: self.sliding_window,
hidden_act: Some(self.hidden_act),
37 changes: 31 additions & 6 deletions src/openai/pipelines/pipeline.rs
@@ -1,4 +1,5 @@
use super::{get_token, ModelLoader, ModelPaths, ModulePipeline, TokenOrFinishReason};
use crate::openai::models::TokenID;
use crate::openai::sampling_params::{Logprobs, TopLogprob};
use crate::scheduler::sequence::SequenceGroup;
use crate::{
@@ -31,6 +32,7 @@
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use either::Either;
use either::Either::{Left, Right};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use std::{path::PathBuf, sync::Arc};
@@ -170,7 +172,7 @@
let specific_args = self.config.clone();

let config = match self.name.as_str() {
"llama" => {
"llama" | "llama3" => {
let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!(
std::fs::read(paths.get_config_filename())
),));
@@ -238,6 +240,10 @@
LLMModel::LLAMA(try_api!(Llama::load(vb, &config, dtype, &device))),
SeparatorStyle::Llama,
),
"llama3" => (
LLMModel::LLAMA(try_api!(Llama::load(vb, &config, dtype, &device))),
SeparatorStyle::Llama3,
),
"phi2" => (
LLMModel::Phi2(try_api!(Phi2::new(vb, &config, dtype, &device))),
SeparatorStyle::Phi,
@@ -296,13 +302,32 @@

println!("{:?}", pipeline_config);

let eos_token = match tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
};
let mut stop_token_ids = Vec::<u32>::new();
stop_token_ids.push(eos_token);

match &config.eos_token_id {
//eos_token defined in the config
TokenID(Either::Left(eos_token)) => {
if let Some(tk) = eos_token {
stop_token_ids.push(*tk);
}
}
TokenID(Either::Right(eos_token_list)) => {
if let Some(tks) = eos_token_list {
stop_token_ids.extend(tks)
}
}
}

if stop_token_ids.len() == 0 {
//if no eos_token defined in the config, use default
let eos_token = match tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
_ => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap_or(0),
};
stop_token_ids.push(eos_token);
}

//custom stop tokens

if let Some(custom_stop) = &config.custom_stop_tokens {
for stop in custom_stop {
match tokenizer.get_token(&stop) {
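Putting the pieces together, the stop-token handling above first harvests whatever `eos_token_id` the config declares (one id or several) and only falls back to the tokenizer's end-of-text id when the config has none. A standalone sketch of that logic, assuming the `TokenID` shape from `src/openai/models/mod.rs`; the concrete id values are illustrative:

```rust
// Standalone sketch of the stop-token collection above. `fallback` stands in for
// the tokenizer's "<|endoftext|>" / EOS_TOKEN lookup.
use either::Either;

fn collect_stop_tokens(eos: &Either<Option<u32>, Option<Vec<u32>>>, fallback: u32) -> Vec<u32> {
    let mut stop_token_ids = Vec::new();
    match eos {
        Either::Left(Some(tk)) => stop_token_ids.push(*tk),
        Either::Right(Some(tks)) => stop_token_ids.extend(tks),
        _ => {} // no eos id declared in the config
    }
    if stop_token_ids.is_empty() {
        stop_token_ids.push(fallback);
    }
    stop_token_ids
}

fn main() {
    // Llama 3.1 style: several eos ids declared in config.json.
    let llama31 = Either::Right(Some(vec![128001, 128008, 128009]));
    assert_eq!(collect_stop_tokens(&llama31, 0), vec![128001, 128008, 128009]);
    // No eos id declared: fall back to the tokenizer's end-of-text id.
    assert_eq!(collect_stop_tokens(&Either::Left(None), 2), vec![2]);
}
```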