Skip to content

Commit

Permalink
feat: support overriding agent config with env vars (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 6, 2024
1 parent 0fac7fe commit 3ffa876
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 34 deletions.
29 changes: 28 additions & 1 deletion src/config/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Agent {
let functions_file_path = functions_dir.join("functions.json");
let rag_path = Config::agent_rag_file(name, DEFAULT_AGENT_NAME)?;
let config_path = Config::agent_config_file(name)?;
let agent_config = if config_path.exists() {
let mut agent_config = if config_path.exists() {
AgentConfig::load(&config_path)?
} else {
AgentConfig::new(&config.read())
Expand All @@ -54,6 +54,8 @@ impl Agent {
};
definition.replace_tools_placeholder(&functions);

agent_config.load_envs(&definition.name);

let model = {
let config = config.read();
match agent_config.model_id.as_ref() {
Expand Down Expand Up @@ -330,6 +332,31 @@ impl AgentConfig {
.with_context(|| format!("Failed to load agent config at '{}'", path.display()))?;
Ok(config)
}

fn load_envs(&mut self, name: &str) {
let with_prefix = |v: &str| normalize_env_name(&format!("{name}_{v}"));

if let Some(v) = read_env_value::<String>(&with_prefix("model")) {
self.model_id = v;
}
if let Some(v) = read_env_value::<f64>(&with_prefix("temperature")) {
self.temperature = v;
}
if let Some(v) = read_env_value::<f64>(&with_prefix("top_p")) {
self.top_p = v;
}
if let Some(v) = read_env_value::<String>(&with_prefix("use_tools")) {
self.use_tools = v;
}
if let Some(v) = read_env_value::<String>(&with_prefix("agent_prelude")) {
self.agent_prelude = v;
}
if let Ok(v) = env::var(with_prefix("variables")) {
if let Ok(v) = serde_json::from_str(&v) {
self.variables = v;
}
}
}
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
Expand Down
67 changes: 34 additions & 33 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1942,94 +1942,95 @@ impl Config {
if let Ok(v) = env::var(get_env_name("model")) {
self.model_id = v;
}
if let Some(v) = read_env_value::<f64>("temperature") {
if let Some(v) = read_env_value::<f64>(&get_env_name("temperature")) {
self.temperature = v;
}
if let Some(v) = read_env_value::<f64>("top_p") {
if let Some(v) = read_env_value::<f64>(&get_env_name("top_p")) {
self.top_p = v;
}

if let Some(Some(v)) = read_env_bool("dry_run") {
if let Some(Some(v)) = read_env_bool(&get_env_name("dry_run")) {
self.dry_run = v;
}
if let Some(Some(v)) = read_env_bool("stream") {
if let Some(Some(v)) = read_env_bool(&get_env_name("stream")) {
self.stream = v;
}
if let Some(Some(v)) = read_env_bool("save") {
if let Some(Some(v)) = read_env_bool(&get_env_name("save")) {
self.save = v;
}
if let Ok(v) = env::var(get_env_name("keybindings")) {
if v == "vi" {
self.keybindings = v;
}
}
if let Some(v) = read_env_value::<String>("editor") {
if let Some(v) = read_env_value::<String>(&get_env_name("editor")) {
self.editor = v;
}
if let Some(v) = read_env_value::<String>("wrap") {
if let Some(v) = read_env_value::<String>(&get_env_name("wrap")) {
self.wrap = v;
}
if let Some(Some(v)) = read_env_bool("wrap_code") {
if let Some(Some(v)) = read_env_bool(&get_env_name("wrap_code")) {
self.wrap_code = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
if let Some(Some(v)) = read_env_bool(&get_env_name("function_calling")) {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
if let Some(v) = read_env_value::<String>(&get_env_name("use_tools")) {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("prelude") {
if let Some(v) = read_env_value::<String>(&get_env_name("prelude")) {
self.prelude = v;
}
if let Some(v) = read_env_value::<String>("repl_prelude") {
if let Some(v) = read_env_value::<String>(&get_env_name("repl_prelude")) {
self.repl_prelude = v;
}
if let Some(v) = read_env_value::<String>("agent_prelude") {
if let Some(v) = read_env_value::<String>(&get_env_name("agent_prelude")) {
self.agent_prelude = v;
}

if let Some(v) = read_env_bool("save_session") {
if let Some(v) = read_env_bool(&get_env_name("save_session")) {
self.save_session = v;
}
if let Some(Some(v)) = read_env_value::<usize>("compress_threshold") {
if let Some(Some(v)) = read_env_value::<usize>(&get_env_name("compress_threshold")) {
self.compress_threshold = v;
}
if let Some(v) = read_env_value::<String>("summarize_prompt") {
if let Some(v) = read_env_value::<String>(&get_env_name("summarize_prompt")) {
self.summarize_prompt = v;
}
if let Some(v) = read_env_value::<String>("summary_prompt") {
if let Some(v) = read_env_value::<String>(&get_env_name("summary_prompt")) {
self.summary_prompt = v;
}

if let Some(v) = read_env_value::<String>("rag_embedding_model") {
if let Some(v) = read_env_value::<String>(&get_env_name("rag_embedding_model")) {
self.rag_embedding_model = v;
}
if let Some(v) = read_env_value::<String>("rag_reranker_model") {
if let Some(v) = read_env_value::<String>(&get_env_name("rag_reranker_model")) {
self.rag_reranker_model = v;
}
if let Some(Some(v)) = read_env_value::<usize>("rag_top_k") {
if let Some(Some(v)) = read_env_value::<usize>(&get_env_name("rag_top_k")) {
self.rag_top_k = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_size") {
if let Some(v) = read_env_value::<usize>(&get_env_name("rag_chunk_size")) {
self.rag_chunk_size = v;
}
if let Some(v) = read_env_value::<usize>("rag_chunk_overlap") {
if let Some(v) = read_env_value::<usize>(&get_env_name("rag_chunk_overlap")) {
self.rag_chunk_overlap = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_vector_search") {
if let Some(Some(v)) = read_env_value::<f32>(&get_env_name("rag_min_score_vector_search")) {
self.rag_min_score_vector_search = v;
}
if let Some(Some(v)) = read_env_value::<f32>("rag_min_score_keyword_search") {
if let Some(Some(v)) = read_env_value::<f32>(&get_env_name("rag_min_score_keyword_search"))
{
self.rag_min_score_keyword_search = v;
}
if let Some(v) = read_env_value::<String>("rag_template") {
if let Some(v) = read_env_value::<String>(&get_env_name("rag_template")) {
self.rag_template = v;
}

Expand All @@ -2039,13 +2040,13 @@ impl Config {
}
}

if let Some(Some(v)) = read_env_bool("highlight") {
if let Some(Some(v)) = read_env_bool(&get_env_name("highlight")) {
self.highlight = v;
}
if *NO_COLOR {
self.highlight = false;
}
if let Some(Some(v)) = read_env_bool("light_theme") {
if let Some(Some(v)) = read_env_bool(&get_env_name("light_theme")) {
self.light_theme = v;
} else if !self.light_theme {
if let Ok(v) = env::var("COLORFGBG") {
Expand All @@ -2054,17 +2055,17 @@ impl Config {
}
}
}
if let Some(v) = read_env_value::<String>("left_prompt") {
if let Some(v) = read_env_value::<String>(&get_env_name("left_prompt")) {
self.left_prompt = v;
}
if let Some(v) = read_env_value::<String>("right_prompt") {
if let Some(v) = read_env_value::<String>(&get_env_name("right_prompt")) {
self.right_prompt = v;
}

if let Some(v) = read_env_value::<String>("serve_addr") {
if let Some(v) = read_env_value::<String>(&get_env_name("serve_addr")) {
self.serve_addr = v;
}
if let Some(v) = read_env_value::<String>("user_agent") {
if let Some(v) = read_env_value::<String>(&get_env_name("user_agent")) {
self.user_agent = v;
}
}
Expand Down Expand Up @@ -2230,7 +2231,7 @@ fn read_env_value<T>(key: &str) -> Option<Option<T>>
where
T: std::str::FromStr,
{
let value = env::var(get_env_name(key)).ok()?;
let value = env::var(key).ok()?;
let value = parse_value(&value).ok()?;
Some(value)
}
Expand All @@ -2252,7 +2253,7 @@ where
}

fn read_env_bool(key: &str) -> Option<Option<bool>> {
let value = env::var(get_env_name(key)).ok()?;
let value = env::var(key).ok()?;
Some(parse_bool(&value))
}

Expand Down

0 comments on commit 3ffa876

Please sign in to comment.