Merge pull request #55 from EricLBuehler/develop
Fix mistral output repetition with F32 rope and penalty & temperature parameters
guoqingbao authored Jul 11, 2024
2 parents bfa16e3 + e4c7237 commit d0f31eb
Showing 5 changed files with 49 additions and 20 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -155,7 +155,11 @@ For chat history settings, set `record_conversation` to `true` to let candle-vll

For chat streaming, the `stream` flag in the chat request needs to be set to `True`.

-You may revise the `repetition_penalty` and `temperature` flags in the chat request (HTTP POST).
+You may supply `penalty` and `temperature` to the model to **prevent potential repetitions**, for example:
+
+```
+cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --repeat-last-n 32 --penalty 1.1 --temperature 0.8
+```

## Report issue
Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an
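For context, `--penalty` and `--repeat-last-n` implement a standard repetition penalty: the logits of the last n generated tokens are damped before sampling. A minimal sketch of those semantics, assuming candle-transformers' `apply_repeat_penalty` helper (whether candle-vllm calls this exact helper internally is an assumption):

```
use candle_core::{DType, Device, Result, Tensor};
use candle_transformers::utils::apply_repeat_penalty;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy vocabulary of 8 tokens, uniform logits of 1.0.
    let logits = Tensor::ones(8, DType::F32, &dev)?;
    // Tokens generated so far; with --repeat-last-n 32 all of them
    // fall inside the penalized window. Illustrative only.
    let generated: Vec<u32> = vec![1, 3, 3, 5];
    let (penalty, repeat_last_n) = (1.1f32, 32usize);
    let start = generated.len().saturating_sub(repeat_last_n);
    let penalized = apply_repeat_penalty(&logits, penalty, &generated[start..])?;
    // Tokens 1, 3 and 5 drop to 1.0 / 1.1 ≈ 0.909; the others stay at 1.0.
    println!("{penalized}");
    Ok(())
}
```

A penalty of 1.0 is a no-op; values slightly above 1.0, such as the 1.1 default used here, damp recently emitted tokens without distorting the distribution much.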
28 changes: 24 additions & 4 deletions src/lib.rs
@@ -15,6 +15,9 @@ pub enum ModelSelected {
#[arg(long)]
repeat_last_n: Option<usize>,

+#[arg(long)]
+temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,
},
@@ -25,6 +28,9 @@
#[arg(long)]
repeat_last_n: Option<usize>,

+#[arg(long)]
+temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,
},
@@ -73,6 +79,9 @@ pub enum ModelSelected {
#[arg(long)]
repeat_last_n: Option<usize>,

+#[arg(long)]
+temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,
},
@@ -83,6 +92,9 @@
#[arg(long)]
repeat_last_n: Option<usize>,

+#[arg(long)]
+temperature: Option<f32>,

#[arg(long)]
penalty: Option<f32>,
},
@@ -93,10 +105,12 @@ impl ToString for ModelSelected {
match self {
ModelSelected::Llama {
repeat_last_n: _,
+temperature: _,
penalty: _,
} => "llama".to_string(),
ModelSelected::Phi2 {
repeat_last_n: _,
+temperature: _,
penalty: _,
} => "phi2".to_string(),
ModelSelected::Phi3 {
@@ -115,10 +129,12 @@
} => "qwen2".to_string(),
ModelSelected::Gemma {
repeat_last_n: _,
+temperature: _,
penalty: _,
} => "gemma".to_string(),
ModelSelected::Mistral {
repeat_last_n: _,
+temperature: _,
penalty: _,
} => "mistral".to_string(),
}
@@ -132,10 +148,11 @@ pub fn get_model_loader<'a>(
match selected_model {
ModelSelected::Llama {
repeat_last_n,
+temperature,
penalty,
} => (
Box::new(DefaultLoader::new(
-SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
"llama".to_string(),
)),
if model_id.is_some() {
@@ -146,10 +163,11 @@
),
ModelSelected::Phi2 {
repeat_last_n,
+temperature,
penalty,
} => (
Box::new(DefaultLoader::new(
-SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
"phi2".to_string(),
)),
if model_id.is_some() {
@@ -194,10 +212,11 @@ pub fn get_model_loader<'a>(
),
ModelSelected::Gemma {
repeat_last_n,
+temperature,
penalty,
} => (
Box::new(DefaultLoader::new(
-SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
"gemma".to_string(),
)),
if model_id.is_some() {
@@ -208,10 +227,11 @@
),
ModelSelected::Mistral {
repeat_last_n,
+temperature,
penalty,
} => (
Box::new(DefaultLoader::new(
-SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
"mistral".to_string(),
)),
if model_id.is_some() {
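The src/lib.rs hunks repeat one mechanical change per model variant: a new optional `--temperature` flag declared via clap and threaded into `SpecificConfig::new` in place of the previous `None`. A standalone sketch of that clap pattern, assuming the clap v4 derive API that the `#[arg(long)]` attributes indicate:

```
use clap::Parser;

/// Stand-in for the per-model flag set in `ModelSelected` (illustrative).
#[derive(Parser, Debug)]
struct Args {
    #[arg(long)]
    repeat_last_n: Option<usize>,

    // Newly added in this commit: sampling temperature.
    #[arg(long)]
    temperature: Option<f32>,

    #[arg(long)]
    penalty: Option<f32>,
}

fn main() {
    // e.g. `app --repeat-last-n 32 --penalty 1.1 --temperature 0.8`
    let args = Args::parse();
    println!("{args:?}");
}
```

Because the field is `Option<f32>`, omitting the flag yields `None`, which the pipeline later maps to a default (0.0, i.e. greedy decoding).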
17 changes: 11 additions & 6 deletions src/openai/models/mistral.rs
@@ -69,9 +69,9 @@ impl RotaryEmbedding {
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
-let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-.to_dtype(dtype)?
+.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;

@@ -204,12 +204,17 @@ impl Attention {
let v = value_states
.reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
-(q.contiguous()?, k.contiguous()?, v.contiguous()?)
+(q, k, v.contiguous()?)
};

-let (q, k) = self
-    .rotary_emb
-    .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
+let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(
+    &q.to_dtype(DType::F32)?,
+    &k.to_dtype(DType::F32)?,
+    seqlen_offset,
+)?;

+let q = q.to_dtype(v.dtype())?;
+let k = k.to_dtype(v.dtype())?;

let y = self.attn.forward(
&q,
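Both src/openai/models/mistral.rs hunks are the core of the repetition fix: the rotary tables are built in F32 regardless of the model dtype, and q/k are cast to F32 for RoPE and back to the value dtype afterwards, so F16/BF16 rounding error cannot creep into the position encodings. A minimal sketch of that cast-compute-cast pattern; the helper and the `affine` stand-in op are illustrative, not candle-vllm's API:

```
use candle_core::{DType, Device, Result, Tensor};

/// Run a precision-sensitive op in F32, then cast back to the input dtype.
/// Illustrative helper, not candle-vllm API.
fn apply_in_f32(x: &Tensor, f: impl Fn(&Tensor) -> Result<Tensor>) -> Result<Tensor> {
    let out_dtype = x.dtype();
    let y = f(&x.to_dtype(DType::F32)?)?;
    y.to_dtype(out_dtype)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in F16 query activations: (batch, heads, seq_len, head_dim).
    let q = Tensor::randn(0f32, 1f32, (1, 8, 4, 64), &dev)?.to_dtype(DType::F16)?;
    // `affine` stands in for `apply_rotary_emb_qkv`.
    let q_rot = apply_in_f32(&q, |t| t.affine(1.0, 0.0))?;
    assert_eq!(q_rot.dtype(), DType::F16); // back to the attention dtype
    Ok(())
}
```

Note the related change above: q and k no longer go through `.contiguous()` before RoPE, presumably because the F32 casts already produce fresh contiguous tensors.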
2 changes: 1 addition & 1 deletion src/openai/openai_server.rs
@@ -171,7 +171,7 @@ async fn chat_completions(
}
let sampling_params = sampling_params.unwrap();

-println!("{:?}", sampling_params);
+// println!("{:?}", sampling_params);

let created = get_created_time_secs();

16 changes: 8 additions & 8 deletions src/openai/pipelines/pipeline.rs
@@ -160,7 +160,7 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
dtype: DType,
device: Device,
) -> Result<(Box<dyn ModulePipeline<'a>>, PipelineConfig), APIError> {
-let args = self.config.clone();
+let specific_args = self.config.clone();

let config = match self.name.as_str() {
"llama" => {
@@ -260,9 +260,9 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
let pipeline_config = PipelineConfig {
max_model_len: config.max_seq_len,
default_max_tokens,
-penalty: self.config.penalty.unwrap_or(1.1),
-repeat_last_n: self.config.repeat_last_n.unwrap_or(16),
-temperature: self.config.temperature.unwrap_or(0.),
+penalty: specific_args.penalty.unwrap_or(1.1),
+repeat_last_n: specific_args.repeat_last_n.unwrap_or(32),
+temperature: specific_args.temperature.unwrap_or(0.),
};

println!("{:?}", pipeline_config);
@@ -272,14 +272,14 @@
None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
};

-println!("{:?}", self.config);
+println!("{:?}", specific_args);

let logits_processor = {
-let temperature = args.temperature.unwrap_or(0.) as f64;
+let temperature = specific_args.temperature.unwrap_or(0.) as f64;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
-match (args.top_k, args.top_p) {
+match (specific_args.top_k, specific_args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
@@ -292,7 +292,7 @@
Ok((
Box::new(DefaultPipeline {
model,
-args,
+args: specific_args,
tokenizer,
logits_processor: logits_processor,
conversation: DefaultConversation::new(
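The logits-processor block above maps the resolved flags onto candle-transformers' `Sampling` enum. A condensed sketch, assuming candle's `LogitsProcessor::from_sampling`; the `(Some(k), Some(p))` arm is hidden by the diff fold and is assumed to map to `Sampling::TopKThenTopP`, and the seed is illustrative:

```
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn build_processor(temperature: f64, top_k: Option<usize>, top_p: Option<f64>) -> LogitsProcessor {
    let sampling = if temperature <= 0. {
        // temperature <= 0 disables sampling entirely (greedy decoding).
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            // Assumed mapping for the arm hidden in the diff above.
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(42, sampling) // seed 42 is illustrative
}

fn main() {
    // `--temperature 0.8` with no top-k/top-p: plain temperature sampling.
    let _processor = build_processor(0.8, None, None);
}
```

With the new defaults, the temperature of 0.0 keeps decoding greedy (`ArgMax`); supplying `--temperature 0.8` switches to stochastic sampling, which together with the repetition penalty addresses the looping output this commit targets.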
