Error prompt for requested message exceeds model capacity #48

Merged 1 commit on Jul 8, 2024
3 changes: 2 additions & 1 deletion — src/openai/mod.rs

```diff
@@ -31,9 +31,10 @@ where
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct PipelineConfig {
     pub max_model_len: usize,
+    pub default_max_tokens: usize,
 }
 
 #[derive(Clone)]
```
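For reference, here is a minimal sketch (not taken from the PR) of how the extended `PipelineConfig` can be constructed and logged once `Debug` is derived; the concrete values are illustrative assumptions, not the project's defaults:

```rust
// Illustrative only: constructing the extended PipelineConfig and
// printing it via the newly derived Debug impl.
#[derive(Clone, Debug)]
pub struct PipelineConfig {
    pub max_model_len: usize,
    pub default_max_tokens: usize,
}

fn main() {
    // Values are assumptions; in the project they come from the model config.
    let cfg = PipelineConfig {
        max_model_len: 4096,
        default_max_tokens: 409,
    };
    println!("{:?}", cfg);
}
```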
22 changes: 11 additions & 11 deletions — src/openai/openai_server.rs

```diff
@@ -76,22 +76,20 @@ async fn check_length(
         .map_err(APIError::from)?
     };
 
-    let max_tokens = if let Some(max_toks) = request.max_tokens {
-        max_toks
-    } else {
-        data.pipeline_config.max_model_len - token_ids.len()
-    };
+    let max_gen_tokens = request
+        .max_tokens
+        .unwrap_or(data.pipeline_config.default_max_tokens);
 
-    if token_ids.len() + max_tokens > data.pipeline_config.max_model_len {
+    if token_ids.len() + max_gen_tokens > data.pipeline_config.max_model_len {
         Err(APIError::new(format!(
             "This model's maximum context length is {} tokens. \
             However, you requested {} tokens ({} in the messages, \
-            {} in the completion). Please reduce the length of the \
-            messages or completion.",
+            {} in the completion). \nPlease clear the chat history or reduce the length of the \
+            messages.",
             data.pipeline_config.max_model_len,
-            max_tokens + token_ids.len(),
+            max_gen_tokens + token_ids.len(),
             token_ids.len(),
-            max_tokens
+            max_gen_tokens
         )))
     } else {
         Ok(token_ids)
@@ -157,7 +155,9 @@ async fn chat_completions(
         request.stop.clone(),
         request.stop_token_ids.clone().unwrap_or_default(),
         request.ignore_eos.unwrap_or(false),
-        request.max_tokens.unwrap_or(1024),
+        request
+            .max_tokens
+            .unwrap_or(data.pipeline_config.default_max_tokens),
         None,
         None,
         request.skip_special_tokens.unwrap_or(true),
```
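The new length check can be read in isolation as follows. This is a hedged sketch with the actix-specific plumbing stripped away: `check_length` is shown as a free function and `prompt_len` stands in for `token_ids.len()`, both assumptions made purely for illustration.

```rust
// Standalone sketch of the new length check, outside the actix handler.
struct PipelineConfig {
    max_model_len: usize,
    default_max_tokens: usize,
}

fn check_length(
    prompt_len: usize,
    requested_max_tokens: Option<usize>,
    cfg: &PipelineConfig,
) -> Result<(), String> {
    // Fall back to the server-side default instead of "whatever context is left".
    let max_gen_tokens = requested_max_tokens.unwrap_or(cfg.default_max_tokens);

    if prompt_len + max_gen_tokens > cfg.max_model_len {
        Err(format!(
            "This model's maximum context length is {} tokens. \
             However, you requested {} tokens ({} in the messages, \
             {} in the completion). \nPlease clear the chat history or \
             reduce the length of the messages.",
            cfg.max_model_len,
            max_gen_tokens + prompt_len,
            prompt_len,
            max_gen_tokens
        ))
    } else {
        Ok(())
    }
}

fn main() {
    let cfg = PipelineConfig { max_model_len: 4096, default_max_tokens: 409 };
    // 4000 prompt tokens + 409 default generation tokens > 4096 -> error
    assert!(check_length(4000, None, &cfg).is_err());
    // A short prompt with an explicit small budget fits.
    assert!(check_length(100, Some(256), &cfg).is_ok());
}
```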
23 changes: 21 additions & 2 deletions — src/openai/pipelines/pipeline.rs

```diff
@@ -32,6 +32,8 @@ use std::{iter::zip, path::PathBuf, sync::Arc};
 use tokenizers::Tokenizer;
 const EOS_TOKEN: &str = "</s>";
 const SAMPLING_SEED: u64 = 299792458;
+const MIN_GEN_TOKENS: usize = 128;
+const MAX_GEN_TOKENS: usize = 4096;
 
 #[derive(Debug, Clone)]
 pub struct SpecificConfig {
@@ -160,6 +162,8 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
             _ => panic!(""),
         };
 
+        println!("Model {:?}", config);
+
         println!("Loading {} model.", self.name);
 
         let vb = match unsafe {
@@ -192,9 +196,24 @@ impl<'a> ModelLoader<'a> for DefaultLoader {
 
         println!("Done loading.");
 
-        //max is https://huggingface.co/docs/transformers/model_doc/llama2#transformers.LlamaConfig.max_position_embeddings
+        //max and min number of tokens generated per request
+        let mut default_max_tokens = config.max_seq_len / 10;
+        if default_max_tokens < MIN_GEN_TOKENS {
+            default_max_tokens = MIN_GEN_TOKENS;
+        } else if default_max_tokens > MAX_GEN_TOKENS {
+            default_max_tokens = MAX_GEN_TOKENS;
+        }
+
         let pipeline_config = PipelineConfig {
-            max_model_len: 4096,
+            max_model_len: config.max_seq_len,
+            default_max_tokens,
         };
+
+        println!("{:?}", pipeline_config);
 
         let eos_token = match tokenizer.get_token("<|endoftext|>") {
             Some(token) => token,
             None => tokenizer.tokenizer().token_to_id(EOS_TOKEN).unwrap(),
         };
```
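The default-token heuristic added here is one tenth of the model's context length, clamped to the `[MIN_GEN_TOKENS, MAX_GEN_TOKENS]` range. A compact, equivalent sketch using `usize::clamp` (which the PR itself does not use):

```rust
// Sketch of the default-token heuristic in isolation: one tenth of the
// model's context window, clamped to [MIN_GEN_TOKENS, MAX_GEN_TOKENS].
const MIN_GEN_TOKENS: usize = 128;
const MAX_GEN_TOKENS: usize = 4096;

fn default_max_tokens(max_seq_len: usize) -> usize {
    (max_seq_len / 10).clamp(MIN_GEN_TOKENS, MAX_GEN_TOKENS)
}

fn main() {
    assert_eq!(default_max_tokens(2048), 204);    // within bounds: 2048 / 10
    assert_eq!(default_max_tokens(512), 128);     // clamped up to the minimum
    assert_eq!(default_max_tokens(131072), 4096); // clamped down to the maximum
}
```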
11 changes: 9 additions & 2 deletions — src/openai/responses.rs

```diff
@@ -1,5 +1,5 @@
 use crate::openai::sampling_params::Logprobs;
-use actix_web::error;
+use actix_web::{error, HttpResponse};
 use derive_more::{Display, Error};
 
 use serde::{Deserialize, Serialize};
@@ -10,7 +10,14 @@ pub struct APIError {
     data: String,
 }
 
-impl error::ResponseError for APIError {}
+impl error::ResponseError for APIError {
+    fn error_response(&self) -> HttpResponse {
+        //pack error to json so that client can handle it
+        HttpResponse::BadRequest()
+            .content_type("application/json")
+            .json(self.data.to_string())
+    }
+}
 
 impl APIError {
     pub fn new(data: String) -> Self {
```
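A self-contained sketch of the same actix-web pattern, with the `Display` impl written out by hand instead of derived via `derive_more`; the type and field names mirror the diff, but the snippet is illustrative rather than the project's full `responses.rs`:

```rust
// Minimal sketch: an error type whose ResponseError impl returns a JSON
// body instead of actix-web's default plain-text response.
// Requires the actix-web crate.
use actix_web::{error, HttpResponse};
use std::fmt;

#[derive(Debug)]
struct APIError {
    data: String,
}

impl fmt::Display for APIError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.data)
    }
}

impl error::ResponseError for APIError {
    fn error_response(&self) -> HttpResponse {
        // Pack the message as JSON so clients can parse it programmatically.
        HttpResponse::BadRequest()
            .content_type("application/json")
            .json(self.data.to_string())
    }
}

fn main() {
    let err = APIError { data: "context length exceeded".to_string() };
    // Returning Err(err) from a handler would yield this 400 response
    // whose body is a JSON string rather than plain text.
    let resp = error::ResponseError::error_response(&err);
    println!("{}", resp.status());
}
```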