Skip to content

Commit

Permalink
feat: .agent accepts session name (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 5, 2024
1 parent 42deaa0 commit 931d7e0
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 106 deletions.
39 changes: 13 additions & 26 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1111,29 +1111,8 @@ impl Config {
self.last_message = None;
Ok(())
}

pub fn list_sessions(&self) -> Vec<String> {
let sessions_dir = match self.sessions_dir() {
Ok(dir) => dir,
Err(_) => return vec![],
};
match read_dir(sessions_dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd.flatten() {
let name = entry.file_name();
if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") {
if name.starts_with(TEMP_SESSION_NAME) {
continue;
}
names.push(name.to_string());
}
}
names.sort_unstable();
names
}
Err(_) => vec![],
}
list_file_names(self.sessions_dir().ok(), ".yaml")
}

pub fn should_compress_session(&mut self) -> bool {
Expand Down Expand Up @@ -1366,8 +1345,8 @@ impl Config {

pub async fn use_agent(
config: &GlobalConfig,
name: &str,
session: Option<&str>,
agent_name: &str,
session_name: Option<&str>,
abort_signal: AbortSignal,
) -> Result<()> {
if !config.read().function_calling {
Expand All @@ -1376,8 +1355,8 @@ impl Config {
if config.read().agent.is_some() {
bail!("Already in a agent, please run '.exit agent' first to exit the current agent.");
}
let agent = Agent::init(config, name, abort_signal).await?;
let session = session
let agent = Agent::init(config, agent_name, abort_signal).await?;
let session = session_name
.map(|v| v.to_string())
.or_else(|| agent.agent_prelude().map(|v| v.to_string()));
config.write().rag = agent.rag();
Expand Down Expand Up @@ -1650,6 +1629,14 @@ impl Config {
};
values = candidates.into_iter().map(|v| (v, None)).collect();
filter = args[1];
} else if cmd == ".agent" && args.len() >= 2 {
let dir = Self::agent_data_dir(args[0])
.ok()
.map(|v| v.join(SESSIONS_DIR_NAME));
values = list_file_names(dir, ".yaml")
.into_iter()
.map(|v| (v, None))
.collect();
} else if cmd == ".starter" && args.len() >= 2 {
if let Some(agent) = &self.agent {
values = agent
Expand Down
2 changes: 1 addition & 1 deletion src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct Session {

#[serde(skip_serializing_if = "Option::is_none")]
role_name: Option<String>,
#[serde(skip_serializing_if = "IndexMap::is_empty")]
#[serde(default, skip_serializing_if = "IndexMap::is_empty")]
agent_variables: IndexMap<String, String>,

#[serde(default, skip_serializing_if = "HashMap::is_empty")]
Expand Down
145 changes: 66 additions & 79 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,17 @@ impl Repl {
".rag" => {
Config::use_rag(&self.config, args, self.abort_signal.clone()).await?;
}
".agent" => match args {
Some(name) => {
Config::use_agent(&self.config, name, None, self.abort_signal.clone())
.await?;
".agent" => match split_args(args) {
Some((agent_name, session_name)) => {
Config::use_agent(
&self.config,
agent_name,
session_name,
self.abort_signal.clone(),
)
.await?;
}
None => println!(r#"Usage: .agent <name>"#),
None => println!(r#"Usage: .agent <agent-name> [session-name]"#),
},
".starter" => match args {
Some(value) => {
Expand All @@ -330,58 +335,43 @@ impl Repl {
println!("Usage: .variable <key> <value>")
}
},
".save" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("role", name)) => {
self.config.write().save_role(name)?;
}
Some(("session", name)) => {
self.config.write().save_session(name)?;
}
_ => {
println!(r#"Usage: .save <role|session> [name]"#)
}
".save" => match split_args(args) {
Some(("role", name)) => {
self.config.write().save_role(name)?;
}
}
".edit" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("role", _)) => {
self.config.write().edit_role()?;
}
Some(("session", _)) => {
self.config.write().edit_session()?;
}
Some(("rag-docs", _)) => {
Config::edit_rag_docs(&self.config, self.abort_signal.clone()).await?;
}
_ => {
println!(r#"Usage: .edit <role|session|rag-docs>"#)
}
Some(("session", name)) => {
self.config.write().save_session(name)?;
}
}
".compress" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("session", _)) => {
let spinner = create_spinner("Compressing").await;
let ret = Config::compress_session(&self.config).await;
spinner.stop();
ret?;
println!("✨ Successfully compressed the session.");
}
_ => {
println!(r#"Usage: .compress session"#)
}
_ => {
println!(r#"Usage: .save <role|session> [name]"#)
}
}
},
".edit" => match args {
Some("role") => {
self.config.write().edit_role()?;
}
Some("session") => {
self.config.write().edit_session()?;
}
Some("rag-docs") => {
Config::edit_rag_docs(&self.config, self.abort_signal.clone()).await?;
}
_ => {
println!(r#"Usage: .edit <role|session|rag-docs>"#)
}
},
".compress" => match args {
Some("session") => {
let spinner = create_spinner("Compressing").await;
let ret = Config::compress_session(&self.config).await;
spinner.stop();
ret?;
println!("✨ Successfully compressed the session.");
}
_ => {
println!(r#"Usage: .compress session"#)
}
},
".empty" => match args {
Some("session") => {
self.config.write().empty_session()?;
Expand All @@ -390,33 +380,23 @@ impl Repl {
println!(r#"Usage: .empty session"#)
}
},
".rebuild" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("rag", _)) => {
Config::rebuild_rag(&self.config, self.abort_signal.clone()).await?;
}
_ => {
println!(r#"Usage: .rebuild rag"#)
}
".rebuild" => match args {
Some("rag") => {
Config::rebuild_rag(&self.config, self.abort_signal.clone()).await?;
}
}
".sources" => {
match args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
}) {
Some(("rag", _)) => {
let output = Config::rag_sources(&self.config)?;
println!("{}", output);
}
_ => {
println!(r#"Usage: .sources rag"#)
}
_ => {
println!(r#"Usage: .rebuild rag"#)
}
}
},
".sources" => match args {
Some("rag") => {
let output = Config::rag_sources(&self.config)?;
println!("{}", output);
}
_ => {
println!(r#"Usage: .sources rag"#)
}
},
".file" => match args {
Some(args) => {
let (files, text) = split_files_text(args);
Expand Down Expand Up @@ -725,6 +705,13 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
}

fn split_args(args: Option<&str>) -> Option<(&str, Option<&str>)> {
args.map(|v| match v.split_once(' ') {
Some((subcmd, args)) => (subcmd, Some(args.trim())),
None => (v, None),
})
}

fn split_files_text(args: &str) -> (&str, &str) {
match SPLIT_FILES_TEXT_ARGS_RE.find(args).ok().flatten() {
Some(mat) => {
Expand Down
21 changes: 21 additions & 0 deletions src/utils/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,27 @@ pub async fn expand_glob_paths<T: AsRef<str>>(paths: &[T]) -> Result<Vec<String>
Ok(new_paths)
}

pub fn list_file_names<T: AsRef<Path>>(dir: Option<T>, ext: &str) -> Vec<String> {
let dir = match dir {
Some(v) => v,
None => return vec![],
};
match std::fs::read_dir(dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd.flatten() {
let name = entry.file_name();
if let Some(name) = name.to_string_lossy().strip_suffix(ext) {
names.push(name.to_string());
}
}
names.sort_unstable();
names
}
Err(_) => vec![],
}
}

pub fn get_patch_extension(path: &str) -> Option<String> {
Path::new(&path)
.extension()
Expand Down

0 comments on commit 931d7e0

Please sign in to comment.