diff --git a/README.md b/README.md
index bbecb3a..f1c67f6 100644
--- a/README.md
+++ b/README.md
@@ -155,7 +155,11 @@ For chat history settings, set `record_conversation` to `true` to let candle-vll
 
 For chat streaming, the `stream` flag in chat request need to be set to `True`.
 
-You may revise `repetition_penalty` and `temperature` flag in chat request (http post).
+You may supply the `--penalty` and `--temperature` flags to the model to **alleviate potential repetitions**, for example:
+
+```
+cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --repeat-last-n 32 --penalty 1.1 --temperature 0.8
+```
 
 ## Report issue
 Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an
diff --git a/src/lib.rs b/src/lib.rs
index 7387e3f..1fcb508 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -15,6 +15,9 @@ pub enum ModelSelected {
         #[arg(long)]
         repeat_last_n: Option,
 
+        #[arg(long)]
+        temperature: Option,
+
         #[arg(long)]
         penalty: Option,
     },
@@ -25,6 +28,9 @@
         #[arg(long)]
         repeat_last_n: Option,
 
+        #[arg(long)]
+        temperature: Option,
+
         #[arg(long)]
         penalty: Option,
     },
@@ -73,6 +79,9 @@
         #[arg(long)]
         repeat_last_n: Option,
 
+        #[arg(long)]
+        temperature: Option,
+
         #[arg(long)]
         penalty: Option,
     },
@@ -83,6 +92,9 @@
         #[arg(long)]
         repeat_last_n: Option,
 
+        #[arg(long)]
+        temperature: Option,
+
         #[arg(long)]
         penalty: Option,
     },
@@ -93,10 +105,12 @@ impl ToString for ModelSelected {
         match self {
             ModelSelected::Llama {
                 repeat_last_n: _,
+                temperature: _,
                 penalty: _,
             } => "llama".to_string(),
             ModelSelected::Phi2 {
                 repeat_last_n: _,
+                temperature: _,
                 penalty: _,
             } => "phi2".to_string(),
             ModelSelected::Phi3 {
@@ -115,10 +129,12 @@
             } => "qwen2".to_string(),
             ModelSelected::Gemma {
                 repeat_last_n: _,
+                temperature: _,
                 penalty: _,
             } => "gemma".to_string(),
             ModelSelected::Mistral {
                 repeat_last_n: _,
+                temperature: _,
                 penalty: _,
             } => "mistral".to_string(),
         }
@@ -132,10 +148,11 @@ pub fn get_model_loader<'a>(
     match selected_model {
         ModelSelected::Llama {
             repeat_last_n,
+            temperature,
             penalty,
         } => (
             Box::new(DefaultLoader::new(
-                SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+                SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
                 "llama".to_string(),
             )),
             if model_id.is_some() {
@@ -146,10 +163,11 @@
         ),
         ModelSelected::Phi2 {
             repeat_last_n,
+            temperature,
             penalty,
         } => (
             Box::new(DefaultLoader::new(
-                SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+                SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
                 "phi2".to_string(),
             )),
             if model_id.is_some() {
@@ -194,10 +212,11 @@
         ),
         ModelSelected::Gemma {
             repeat_last_n,
+            temperature,
             penalty,
         } => (
             Box::new(DefaultLoader::new(
-                SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+                SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
                 "gemma".to_string(),
             )),
             if model_id.is_some() {
@@ -208,10 +227,11 @@
         ),
         ModelSelected::Mistral {
             repeat_last_n,
+            temperature,
             penalty,
         } => (
             Box::new(DefaultLoader::new(
-                SpecificConfig::new(repeat_last_n, None, None, None, penalty),
+                SpecificConfig::new(repeat_last_n, temperature, None, None, penalty),
                 "mistral".to_string(),
             )),
             if model_id.is_some() {
diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs
index 561795a..15be2f6 100644
--- a/src/openai/models/mistral.rs
+++ b/src/openai/models/mistral.rs
@@ -69,9 +69,9 @@ impl RotaryEmbedding {
             .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
             .collect();
         let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
+            .to_dtype(DType::F32)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
 
@@ -204,12 +204,17 @@ impl Attention {
             let v = value_states
                 .reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))?
                 .transpose(1, 2)?;
-            (q.contiguous()?, k.contiguous()?, v.contiguous()?)
+            (q, k, v.contiguous()?)
         };
 
-        let (q, k) = self
-            .rotary_emb
-            .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
+        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(
+            &q.to_dtype(DType::F32)?,
+            &k.to_dtype(DType::F32)?,
+            seqlen_offset,
+        )?;
+
+        let q = q.to_dtype(v.dtype())?;
+        let k = k.to_dtype(v.dtype())?;
 
         let y = self.attn.forward(
             &q,
diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs
index 0b3d499..f44b811 100644
--- a/src/openai/openai_server.rs
+++ b/src/openai/openai_server.rs
@@ -171,7 +171,7 @@ async fn chat_completions(
     }
 
     let sampling_params = sampling_params.unwrap();
-    println!("{:?}", sampling_params);
+    // println!("{:?}", sampling_params);
 
     let created = get_created_time_secs();
 
diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs
index 34c185d..f9701fc 100644
--- a/src/openai/pipelines/pipeline.rs
+++ b/src/openai/pipelines/pipeline.rs
@@ -160,7 +160,7 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
         dtype: DType,
         device: Device,
     ) -> Result<(Box>, PipelineConfig), APIError> {
-        let args = self.config.clone();
+        let specific_args = self.config.clone();
 
         let config = match self.name.as_str() {
             "llama" => {
@@ -260,9 +260,9 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
         let pipeline_config = PipelineConfig {
             max_model_len: config.max_seq_len,
             default_max_tokens,
-            penalty: self.config.penalty.unwrap_or(1.1),
-            repeat_last_n: self.config.repeat_last_n.unwrap_or(16),
-            temperature: self.config.temperature.unwrap_or(0.),
+            penalty: specific_args.penalty.unwrap_or(1.1),
+            repeat_last_n: specific_args.repeat_last_n.unwrap_or(32),
+            temperature: specific_args.temperature.unwrap_or(0.),
         };
 
         println!("{:?}", pipeline_config);
@@ -272,14 +272,14 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
             None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
         };
 
-        println!("{:?}", self.config);
+        println!("{:?}", specific_args);
 
         let logits_processor = {
-            let temperature = args.temperature.unwrap_or(0.) as f64;
+            let temperature = specific_args.temperature.unwrap_or(0.) as f64;
             let sampling = if temperature <= 0. {
                 Sampling::ArgMax
             } else {
-                match (args.top_k, args.top_p) {
+                match (specific_args.top_k, specific_args.top_p) {
                     (None, None) => Sampling::All { temperature },
                     (Some(k), None) => Sampling::TopK { k, temperature },
                     (None, Some(p)) => Sampling::TopP { p, temperature },
@@ -292,7 +292,7 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
         Ok((
             Box::new(DefaultPipeline {
                 model,
-                args,
+                args: specific_args,
                 tokenizer,
                 logits_processor: logits_processor,
                 conversation: DefaultConversation::new(
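
Note on the `SpecificConfig::new(repeat_last_n, temperature, None, None, penalty)` calls above: the new `temperature` argument sits in the second position, and the two `None`s correspond to `top_k` and `top_p`, which the `logits_processor` block later reads from the same config. A minimal sketch of the implied shape, assuming field types that are not confirmed by this diff (only the argument order and field names come from the hunks above):

```rust
// Sketch only: field types below are assumptions for illustration.
// The diff confirms just the constructor argument order
// (repeat_last_n, temperature, top_k, top_p, penalty) and the field names
// accessed via specific_args.{temperature, top_k, top_p, penalty, repeat_last_n}.
#[derive(Clone, Debug)]
pub struct SpecificConfig {
    pub repeat_last_n: Option<usize>,
    pub temperature: Option<f32>,
    pub top_k: Option<usize>,
    pub top_p: Option<f64>,
    pub penalty: Option<f32>,
}

impl SpecificConfig {
    pub fn new(
        repeat_last_n: Option<usize>,
        temperature: Option<f32>,
        top_k: Option<usize>,
        top_p: Option<f64>,
        penalty: Option<f32>,
    ) -> Self {
        Self {
            repeat_last_n,
            temperature,
            top_k,
            top_p,
            penalty,
        }
    }
}
```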