feat: support customizing top_p parameter (#434)
sigoden authored Apr 24, 2024
1 parent 040c48b commit a17f349
Showing 14 changed files with 124 additions and 41 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -63,8 +63,9 @@ Feel free to adjust the configuration according to your needs.
 > Get `config.yaml` path with command `aichat --info` or repl command `.info`.

 ```yaml
-model: openai:gpt-3.5-turbo # The Large Language Model (LLM) to use
-temperature: 1.0 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo # Specify the language model to use
+temperature: null # Set default temperature parameter
+top_p: null # Set default top-p parameter
 save: true # Indicates whether to persist the message
 save_session: null # Controls the persistence of the session, if null, asking the user
 highlight: true # Controls syntax highlighting
5 changes: 3 additions & 2 deletions config.example.yaml
@@ -1,5 +1,6 @@
-model: openai:gpt-3.5-turbo # The Large Language Model (LLM) to use
-temperature: 1.0 # Controls the randomness and creativity of the LLM's responses
+model: openai:gpt-3.5-turbo # Specify the language model to use
+temperature: null # Set default temperature parameter
+top_p: null # Set default top-p parameter
 save: true # Indicates whether to persist the message
 save_session: null # Controls the persistence of the session, if null, asking the user
 highlight: true # Controls syntax highlighting
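A usage sketch (the `0.9` is illustrative, not a project default): leaving a parameter at `null` defers to the provider's own default, while setting it applies the value to every request.

```yaml
model: openai:gpt-3.5-turbo
temperature: null # defer to the provider's default
top_p: 0.9        # nucleus sampling: restrict choices to the smallest token set covering 90% of probability mass
```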
4 changes: 4 additions & 0 deletions src/client/claude.rs
@@ -140,6 +140,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -205,6 +206,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(v) = temperature {
         body["temperature"] = v.into();
     }
+    if let Some(v) = top_p {
+        body["top_p"] = v.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
4 changes: 4 additions & 0 deletions src/client/cohere.rs
@@ -110,6 +110,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -173,6 +174,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["p"] = top_p.into();
+    }
     if stream {
         body["stream"] = true.into();
     }
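Note the key rename here: the shared field is `top_p`, but Cohere's chat API spells the nucleus-sampling parameter `p`, so that is the key the value lands under. A self-contained sketch of the mapping (the `0.75` is an illustrative value):

```rust
use serde_json::{json, Value};

fn main() {
    let top_p: Option<f64> = Some(0.75); // illustrative value
    let mut body: Value = json!({});
    if let Some(v) = top_p {
        // Cohere spells the parameter "p", not "top_p"
        body["p"] = v.into();
    }
    assert_eq!(body, json!({ "p": 0.75 }));
}
```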
1 change: 1 addition & 0 deletions src/client/common.rs
@@ -323,6 +323,7 @@ pub struct ExtraConfig {
 pub struct SendData {
     pub messages: Vec<Message>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub stream: bool,
 }
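`SendData` is the one struct every client's `build_body` destructures, so adding the field here is what obliges each backend to map it. A standalone sketch of the new shape (the `Message` stub and the values are illustrative, not project defaults):

```rust
#[derive(Debug)]
struct Message; // stub standing in for the crate's Message type

#[derive(Debug)]
struct SendData {
    messages: Vec<Message>,
    temperature: Option<f64>,
    top_p: Option<f64>,
    stream: bool,
}

fn main() {
    let data = SendData {
        messages: Vec::new(),
        temperature: None, // leave temperature at the provider default
        top_p: Some(0.9),  // illustrative value
        stream: false,
    };
    println!("{data:?}");
}
```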
4 changes: 4 additions & 0 deletions src/client/ernie.rs
@@ -230,6 +230,7 @@ fn build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         mut messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -242,6 +243,9 @@ fn build_body(data: SendData, model: &Model) -> Value {
     if let Some(temperature) = temperature {
         body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
+    }

     if let Some(max_output_tokens) = model.max_output_tokens {
         body["max_output_tokens"] = max_output_tokens.into();
4 changes: 4 additions & 0 deletions src/client/ollama.rs
@@ -122,6 +122,7 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -185,6 +186,9 @@ fn build_body(data: SendData, model: &Model) -> Result<Value> {
     if let Some(temperature) = temperature {
         body["options"]["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["options"]["top_p"] = top_p.into();
+    }

     Ok(body)
 }
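Unlike the other backends, Ollama nests sampling parameters under an `options` object rather than at the top level of the request body. A sketch of the resulting shape (model name and values illustrative):

```rust
use serde_json::json;

fn main() {
    // Sampling knobs sit under "options" in Ollama's chat request body.
    let body = json!({
        "model": "llama2", // illustrative model name
        "messages": [],
        "options": {
            "temperature": 0.7, // illustrative values
            "top_p": 0.9
        },
        "stream": false
    });
    println!("{body}");
}
```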
12 changes: 8 additions & 4 deletions src/client/openai.rs
@@ -130,6 +130,7 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

@@ -139,13 +140,16 @@ pub fn openai_build_body(data: SendData, model: &Model) -> Value {
     });

     if let Some(max_tokens) = model.max_output_tokens {
-        body["max_tokens"] = json!(max_tokens);
+        body["max_tokens"] = max_tokens.into();
     } else if model.name == "gpt-4-vision-preview" {
         // The default max_tokens of gpt-4-vision-preview is only 16, we need to make it larger
-        body["max_tokens"] = json!(4096);
+        body["max_tokens"] = 4096.into();
     }
-    if let Some(v) = temperature {
-        body["temperature"] = v.into();
+    if let Some(temperature) = temperature {
+        body["temperature"] = temperature.into();
     }
+    if let Some(top_p) = top_p {
+        body["top_p"] = top_p.into();
+    }
     if stream {
         body["stream"] = true.into();
53 changes: 24 additions & 29 deletions src/client/qianwen.rs
@@ -130,7 +130,6 @@ async fn send_message_streaming(
     is_vl: bool,
 ) -> Result<()> {
     let mut es = builder.eventsource()?;
-    let mut offset = 0;

     while let Some(event) = es.next().await {
         match event {

@@ -139,12 +138,10 @@ async fn send_message_streaming(
                 let data: Value = serde_json::from_str(&message.data)?;
                 catch_error(&data)?;
                 if is_vl {
-                    let text =
-                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str();
-                    if let Some(text) = text {
-                        let text = &text[offset..];
+                    if let Some(text) =
+                        data["output"]["choices"][0]["message"]["content"][0]["text"].as_str()
+                    {
                         handler.text(text)?;
-                        offset += text.len();
                     }
                 } else if let Some(text) = data["output"]["text"].as_str() {
                     handler.text(text)?;

@@ -169,11 +166,12 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
     let SendData {
         messages,
         temperature,
+        top_p,
         stream,
     } = data;

     let mut has_upload = false;
-    let (input, parameters) = if is_vl {
+    let input = if is_vl {
         let messages: Vec<Value> = messages
             .into_iter()
             .map(|message| {

@@ -199,40 +197,37 @@ fn build_body(data: SendData, model: &Model, is_vl: bool) -> Result<(Value, bool)> {
             })
             .collect();

-        let input = json!({
+        json!({
             "messages": messages,
-        });
-
-        let mut parameters = json!({});
-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
+        })
     } else {
-        let input = json!({
+        json!({
             "messages": messages,
-        });
+        })
+    };

-        let mut parameters = json!({});
-        if stream {
-            parameters["incremental_output"] = true.into();
-        }
+    let mut parameters = json!({});
+    if stream {
+        parameters["incremental_output"] = true.into();
+    }

-        if let Some(max_tokens) = model.max_output_tokens {
-            parameters["max_tokens"] = max_tokens.into();
-        }
+    if let Some(max_tokens) = model.max_output_tokens {
+        parameters["max_tokens"] = max_tokens.into();
+    }

-        if let Some(v) = temperature {
-            parameters["temperature"] = v.into();
-        }
-        (input, parameters)
-    };
+    if let Some(temperature) = temperature {
+        parameters["temperature"] = temperature.into();
+    }
+    if let Some(top_p) = top_p {
+        parameters["top_p"] = top_p.into();
+    }

     let body = json!({
         "model": &model.name,
         "input": input,
         "parameters": parameters
     });

     Ok((body, has_upload))
 }
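A note on the streaming hunk at the top of this file: previously `incremental_output` was enabled only on the non-vl path, so the vl stream returned cumulative text and the handler had to slice off an already-emitted `offset`; with `incremental_output` now set whenever streaming, each event carries only new text and the offset bookkeeping can be dropped. (This rationale is inferred from the diff, not stated in the commit.)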
7 changes: 6 additions & 1 deletion src/client/vertexai.rs
@@ -158,7 +158,8 @@ pub(crate) fn build_body(
     let SendData {
         mut messages,
         temperature,
-        ..
+        top_p,
+        stream: _,
     } = data;

     patch_system_message(&mut messages);

@@ -223,6 +224,10 @@ pub(crate) fn build_body(
         body["generationConfig"]["temperature"] = temperature.into();
     }

+    if let Some(top_p) = top_p {
+        body["generationConfig"]["topP"] = top_p.into();
+    }
+
     Ok(body)
 }
32 changes: 32 additions & 0 deletions src/config/mod.rs
@@ -53,6 +53,7 @@ pub struct Config {
     #[serde(rename(serialize = "model", deserialize = "model"))]
     pub model_id: Option<String>,
     pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
     pub dry_run: bool,
     pub save: bool,
     pub save_session: Option<bool>,

@@ -89,6 +90,7 @@ impl Default for Config {
         Self {
             model_id: None,
             temperature: None,
+            top_p: None,
             save: true,
             save_session: None,
             highlight: true,

@@ -297,6 +299,7 @@ impl Config {
         if let Some(session) = self.session.as_mut() {
             session.guard_empty()?;
             session.set_temperature(role.temperature);
+            session.set_top_p(role.top_p);
         }
         self.role = Some(role);
         Ok(())

@@ -335,6 +338,16 @@ impl Config {
         }
     }

+    pub fn set_top_p(&mut self, value: Option<f64>) {
+        if let Some(session) = self.session.as_mut() {
+            session.set_top_p(value);
+        } else if let Some(role) = self.role.as_mut() {
+            role.set_top_p(value);
+        } else {
+            self.top_p = value;
+        }
+    }
+
     pub fn set_save_session(&mut self, value: Option<bool>) {
         if let Some(session) = self.session.as_mut() {
             session.set_save_session(value);

@@ -411,6 +424,7 @@ impl Config {
         let items = vec![
             ("model", self.model.id()),
             ("temperature", format_option(&self.temperature)),
+            ("top_p", format_option(&self.top_p)),
             ("dry_run", self.dry_run.to_string()),
             ("save", self.save.to_string()),
             ("save_session", format_option(&self.save_session)),

@@ -478,6 +492,7 @@ impl Config {
             ".session" => self.list_sessions(),
             ".set" => vec![
                 "temperature ",
+                "top_p ",
                 "compress_threshold",
                 "save ",
                 "save_session ",

@@ -529,6 +544,10 @@ impl Config {
                 let value = parse_value(value)?;
                 self.set_temperature(value);
             }
+            "top_p" => {
+                let value = parse_value(value)?;
+                self.set_top_p(value);
+            }
             "compress_threshold" => {
                 let value = parse_value(value)?;
                 self.set_compress_threshold(value);
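With the completion entry and the match arm above, the parameter becomes adjustable at runtime: a repl command along the lines of `.set top_p 0.9` (value illustrative) parses the number via `parse_value` and hands it to `Config::set_top_p`.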
@@ -756,10 +775,18 @@ impl Config {
         } else {
             self.temperature
         };
+        let top_p = if let Some(session) = input.session(&self.session) {
+            session.top_p()
+        } else if let Some(role) = input.role() {
+            role.top_p
+        } else {
+            self.top_p
+        };
         self.model.max_input_tokens_limit(&messages)?;
         Ok(SendData {
             messages,
             temperature,
+            top_p,
             stream,
         })
     }
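The cascade above resolves `top_p` exactly like `temperature`: an active session wins (even when its value is unset), then the active role, then the global config. A compact sketch of that resolution order (names simplified, not the crate's API):

```rust
/// Outer `Option` = "is a session/role active?", inner = its top_p value.
fn resolve_top_p(
    session: Option<Option<f64>>,
    role: Option<Option<f64>>,
    global: Option<f64>,
) -> Option<f64> {
    session.or(role).unwrap_or(global)
}

fn main() {
    // The role's value applies when no session is active.
    assert_eq!(resolve_top_p(None, Some(Some(0.8)), Some(0.5)), Some(0.8));
    // An active session overrides the role even when its top_p is unset.
    assert_eq!(resolve_top_p(Some(None), Some(Some(0.8)), Some(0.5)), None);
    // With neither active, the global default is used.
    assert_eq!(resolve_top_p(None, None, Some(0.5)), Some(0.5));
}
```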
@@ -791,6 +818,11 @@ impl Config {
                 output.insert("temperature", temperature.to_string());
             }
         }
+        if let Some(top_p) = self.top_p {
+            if top_p != 0.0 {
+                output.insert("top_p", top_p.to_string());
+            }
+        }
         if self.dry_run {
             output.insert("dry_run", "true".to_string());
         }
