From 931d7e0b201d2c1c51801c3141509239a4c7a2b2 Mon Sep 17 00:00:00 2001 From: sigoden Date: Tue, 5 Nov 2024 22:00:53 +0800 Subject: [PATCH] feat: `.agent` accepts session name (#970) --- src/config/mod.rs | 39 ++++-------- src/config/session.rs | 2 +- src/repl/mod.rs | 145 +++++++++++++++++++----------------------- src/utils/path.rs | 21 ++++++ 4 files changed, 101 insertions(+), 106 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 7673133b..adf36f30 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1111,29 +1111,8 @@ impl Config { self.last_message = None; Ok(()) } - pub fn list_sessions(&self) -> Vec { - let sessions_dir = match self.sessions_dir() { - Ok(dir) => dir, - Err(_) => return vec![], - }; - match read_dir(sessions_dir) { - Ok(rd) => { - let mut names = vec![]; - for entry in rd.flatten() { - let name = entry.file_name(); - if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { - if name.starts_with(TEMP_SESSION_NAME) { - continue; - } - names.push(name.to_string()); - } - } - names.sort_unstable(); - names - } - Err(_) => vec![], - } + list_file_names(self.sessions_dir().ok(), ".yaml") } pub fn should_compress_session(&mut self) -> bool { @@ -1366,8 +1345,8 @@ impl Config { pub async fn use_agent( config: &GlobalConfig, - name: &str, - session: Option<&str>, + agent_name: &str, + session_name: Option<&str>, abort_signal: AbortSignal, ) -> Result<()> { if !config.read().function_calling { @@ -1376,8 +1355,8 @@ impl Config { if config.read().agent.is_some() { bail!("Already in a agent, please run '.exit agent' first to exit the current agent."); } - let agent = Agent::init(config, name, abort_signal).await?; - let session = session + let agent = Agent::init(config, agent_name, abort_signal).await?; + let session = session_name .map(|v| v.to_string()) .or_else(|| agent.agent_prelude().map(|v| v.to_string())); config.write().rag = agent.rag(); @@ -1650,6 +1629,14 @@ impl Config { }; values = candidates.into_iter().map(|v| (v, None)).collect(); filter = args[1]; + } else if cmd == ".agent" && args.len() >= 2 { + let dir = Self::agent_data_dir(args[0]) + .ok() + .map(|v| v.join(SESSIONS_DIR_NAME)); + values = list_file_names(dir, ".yaml") + .into_iter() + .map(|v| (v, None)) + .collect(); } else if cmd == ".starter" && args.len() >= 2 { if let Some(agent) = &self.agent { values = agent diff --git a/src/config/session.rs b/src/config/session.rs index 8e82c9ae..8cdd793e 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -29,7 +29,7 @@ pub struct Session { #[serde(skip_serializing_if = "Option::is_none")] role_name: Option, - #[serde(skip_serializing_if = "IndexMap::is_empty")] + #[serde(default, skip_serializing_if = "IndexMap::is_empty")] agent_variables: IndexMap, #[serde(default, skip_serializing_if = "HashMap::is_empty")] diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 4482b849..2b4fc8f9 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -305,12 +305,17 @@ impl Repl { ".rag" => { Config::use_rag(&self.config, args, self.abort_signal.clone()).await?; } - ".agent" => match args { - Some(name) => { - Config::use_agent(&self.config, name, None, self.abort_signal.clone()) - .await?; + ".agent" => match split_args(args) { + Some((agent_name, session_name)) => { + Config::use_agent( + &self.config, + agent_name, + session_name, + self.abort_signal.clone(), + ) + .await?; } - None => println!(r#"Usage: .agent "#), + None => println!(r#"Usage: .agent [session-name]"#), }, ".starter" => match args { Some(value) => { @@ -330,58 +335,43 @@ impl Repl { println!("Usage: .variable ") } }, - ".save" => { - match args.map(|v| match v.split_once(' ') { - Some((subcmd, args)) => (subcmd, Some(args.trim())), - None => (v, None), - }) { - Some(("role", name)) => { - self.config.write().save_role(name)?; - } - Some(("session", name)) => { - self.config.write().save_session(name)?; - } - _ => { - println!(r#"Usage: .save [name]"#) - } + ".save" => match split_args(args) { + Some(("role", name)) => { + self.config.write().save_role(name)?; } - } - ".edit" => { - match args.map(|v| match v.split_once(' ') { - Some((subcmd, args)) => (subcmd, Some(args.trim())), - None => (v, None), - }) { - Some(("role", _)) => { - self.config.write().edit_role()?; - } - Some(("session", _)) => { - self.config.write().edit_session()?; - } - Some(("rag-docs", _)) => { - Config::edit_rag_docs(&self.config, self.abort_signal.clone()).await?; - } - _ => { - println!(r#"Usage: .edit "#) - } + Some(("session", name)) => { + self.config.write().save_session(name)?; } - } - ".compress" => { - match args.map(|v| match v.split_once(' ') { - Some((subcmd, args)) => (subcmd, Some(args.trim())), - None => (v, None), - }) { - Some(("session", _)) => { - let spinner = create_spinner("Compressing").await; - let ret = Config::compress_session(&self.config).await; - spinner.stop(); - ret?; - println!("✨ Successfully compressed the session."); - } - _ => { - println!(r#"Usage: .compress session"#) - } + _ => { + println!(r#"Usage: .save [name]"#) } - } + }, + ".edit" => match args { + Some("role") => { + self.config.write().edit_role()?; + } + Some("session") => { + self.config.write().edit_session()?; + } + Some("rag-docs") => { + Config::edit_rag_docs(&self.config, self.abort_signal.clone()).await?; + } + _ => { + println!(r#"Usage: .edit "#) + } + }, + ".compress" => match args { + Some("session") => { + let spinner = create_spinner("Compressing").await; + let ret = Config::compress_session(&self.config).await; + spinner.stop(); + ret?; + println!("✨ Successfully compressed the session."); + } + _ => { + println!(r#"Usage: .compress session"#) + } + }, ".empty" => match args { Some("session") => { self.config.write().empty_session()?; @@ -390,33 +380,23 @@ impl Repl { println!(r#"Usage: .empty session"#) } }, - ".rebuild" => { - match args.map(|v| match v.split_once(' ') { - Some((subcmd, args)) => (subcmd, Some(args.trim())), - None => (v, None), - }) { - Some(("rag", _)) => { - Config::rebuild_rag(&self.config, self.abort_signal.clone()).await?; - } - _ => { - println!(r#"Usage: .rebuild rag"#) - } + ".rebuild" => match args { + Some("rag") => { + Config::rebuild_rag(&self.config, self.abort_signal.clone()).await?; } - } - ".sources" => { - match args.map(|v| match v.split_once(' ') { - Some((subcmd, args)) => (subcmd, Some(args.trim())), - None => (v, None), - }) { - Some(("rag", _)) => { - let output = Config::rag_sources(&self.config)?; - println!("{}", output); - } - _ => { - println!(r#"Usage: .sources rag"#) - } + _ => { + println!(r#"Usage: .rebuild rag"#) } - } + }, + ".sources" => match args { + Some("rag") => { + let output = Config::rag_sources(&self.config)?; + println!("{}", output); + } + _ => { + println!(r#"Usage: .sources rag"#) + } + }, ".file" => match args { Some(args) => { let (files, text) = split_files_text(args); @@ -725,6 +705,13 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> { } } +fn split_args(args: Option<&str>) -> Option<(&str, Option<&str>)> { + args.map(|v| match v.split_once(' ') { + Some((subcmd, args)) => (subcmd, Some(args.trim())), + None => (v, None), + }) +} + fn split_files_text(args: &str) -> (&str, &str) { match SPLIT_FILES_TEXT_ARGS_RE.find(args).ok().flatten() { Some(mat) => { diff --git a/src/utils/path.rs b/src/utils/path.rs index a8534fb0..b0a5614f 100644 --- a/src/utils/path.rs +++ b/src/utils/path.rs @@ -42,6 +42,27 @@ pub async fn expand_glob_paths>(paths: &[T]) -> Result Ok(new_paths) } +pub fn list_file_names>(dir: Option, ext: &str) -> Vec { + let dir = match dir { + Some(v) => v, + None => return vec![], + }; + match std::fs::read_dir(dir) { + Ok(rd) => { + let mut names = vec![]; + for entry in rd.flatten() { + let name = entry.file_name(); + if let Some(name) = name.to_string_lossy().strip_suffix(ext) { + names.push(name.to_string()); + } + } + names.sort_unstable(); + names + } + Err(_) => vec![], + } +} + pub fn get_patch_extension(path: &str) -> Option { Path::new(&path) .extension()