diff --git a/Cargo.lock b/Cargo.lock index 0f4b4b4..6ddac2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,7 +137,7 @@ dependencies = [ ] [[package]] -name = "ata2" +name = "ata" version = "3.0.0" dependencies = [ "ansi-colors", @@ -158,6 +158,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "tokio-stream", "toml 0.6.0", ] @@ -1866,6 +1867,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 845dd76..9e9958f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ members = [ ] resolver = "2" -[profile.release.package.ata2] +[profile.release.package.ata] strip = false [profile.release] diff --git a/GNUmakefile b/GNUmakefile new file mode 100644 index 0000000..a656df7 --- /dev/null +++ b/GNUmakefile @@ -0,0 +1,12 @@ +all: README.md ata + +README.md: README.md.sh + ./$< > $@ + +.PHONY: ata +ata: + cargo build --release + +.PHONY: install +install: + cargo install --path ./ata² diff --git a/README.md b/README.md index 36bf44e..397fb5e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +

ata²: Ask the Terminal Anything

ChatGPT in the terminal

@@ -74,6 +75,45 @@ You may also: $ cargo install --path . ``` +### Keybindings +```text +Keyboard shortcuts: +ata²-specific: +Ctrl-D, EOF (In multiline mode) Send the current message. +F2 Save the current conversation (not including the message + you're typing) to a file. + +rustyline: +Ctrl-A, Home Move cursor to the beginning of line +Ctrl-B, Left Move cursor one character left +Ctrl-E, End Move cursor to end of line +Ctrl-F, Right Move cursor one character right +Ctrl-H, Backspace Delete character before cursor +Ctrl-I, Tab Next completion +Ctrl-K Delete from cursor to end of line +Ctrl-L Clear screen +Ctrl-N, Down Next match from history +Ctrl-P, Up Previous match from history +Ctrl-X Ctrl-U Undo +Ctrl-Y Paste from Yank buffer (Meta-Y to paste next yank instead) +Meta-< Move to first entry in history +Meta-> Move to last entry in history +Meta-B, Alt-Left Move cursor to previous word +Meta-C Capitalize the current word +Meta-D Delete forwards one word +Meta-F, Alt-Right Move cursor to next word +Meta-L Lower-case the next word +Meta-T Transpose words +Meta-U Upper-case the next word +Meta-Y See Ctrl-Y +Meta-Backspace Kill from the start of the current word, or, if between + words, to the start of the previous word +Meta-0, 1, ..., - Specify the digit to the argument. – starts a negative + argument. + +Thanks to . +``` + # License Copyright 2023 Fredrick R. Brennan <copypaste@kittens.ph>, Rik Huijzer <rikhuijzer@pm.me>, & ATA Project Authors diff --git a/README.md.sh b/README.md.sh new file mode 100755 index 0000000..11b6d07 --- /dev/null +++ b/README.md.sh @@ -0,0 +1,103 @@ +#!/bin/bash +cat << 'EOF' + +

ata²: Ask the Terminal Anything

+ +

ChatGPT in the terminal

+ +[![asciicast](https://asciinema.org/a/sOgAo4BkUXBJTSgyjIZw2mnFr.svg)](https://asciinema.org/a/sOgAo4BkUXBJTSgyjIZw2mnFr) + +## This is a fork! + +The original project, `ata`, by Rik Huijzer is [elsewhere](https://github.com/rikhuijzer/ata). + +This fork implements many new config options and features. + +

+TIP:
+ Run a terminal with this tool in your background and show/hide it with a keypress.
+ This can be done via: Iterm2 (Mac), Guake (Ubuntu), scratchpad (i3/sway), yakuake (KDE), or the quake mode for the Windows Terminal. +

+ +## Productivity benefits + +- The terminal starts more quickly and requires **less resources** than a browser. +- The **keyboard shortcuts** allow for quick interaction with the query. For example, press `CTRL + c` to cancel the stream, `CTRL + ↑` to get the previous query again, and `CTRL + w` to remove the last word. +- A terminal can be set to **run in the background and show/hide with one keypress**. To do this, use iTerm2 (Mac), Guake (Ubuntu), scratchpad (i3/sway), or the quake mode for the Windows Terminal. +- The prompts are **reproducible** because each prompt is sent as a stand-alone prompt without history. Tweaking the prompt can be done by pressing `CTRL + ↑` and making changes. + +## Usage + +Download the binary for your system from [Releases](https://github.com/ctrlcctrlv/ata2/releases). +If you're running Arch Linux, then you can use the AUR package: [ata2](https://aur.archlinux.org/packages/ata2) + +To specify the API key and some basic model settings, start the application. +It should give an error and the option to create a configuration file called `ata2.toml` for you. +Press `y` and `ENTER` to create a `ata2.toml` file. + +Next, request an API key via and update the key in the example configuration file. + +For more information, see: + +```sh +$ ata2 --help +``` + +## FAQ + +**How much will I have to pay for the API?** + +Using OpenAI's API for chat is very cheap. +Let's say that an average response is about 500 tokens, so costs $0.001. +That means that if you do 100 requests per day, which is a lot, then that will cost you about $0.10 per day ($3 per month). +OpenAI grants you $18.00 for free, so you can use the API for about 180 days (6 months) before having to pay. + +**How does this compare to LLM-based search engines such as You.com or Bing Chat?** + +At the time of writing, the OpenAI API responds much quicker than the large language model-based search engines and contains no adds. +It is particularly useful to quickly look up some things like Unicode symbols, historical facts, or word meanings. + +**Can I build the binary myself?** + +Yes, you can clone the repository and build the project via [`Cargo`](https://github.com/rust-lang/cargo). +Make sure that you have `Cargo` installed and then run: + +```sh +$ git clone https://github.com/ctrlcctrlv/ata2.git + +$ cd ata2/ + +$ cargo build --release +``` +After this, your binary should be available at `target/release/ata2` (Unix-based) or `target/release/ata2.exe` (Windows). + +You may also: + +```sh +$ cargo install --path . +``` + +### Keybindings +```text +EOF +cat ./ata²/src/help/keybindings.txt +cat << 'EOF' +``` + +# License + + Copyright 2023 Fredrick R. Brennan <copypaste@kittens.ph>, Rik Huijzer <rikhuijzer@pm.me>, & ATA Project Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +EOF diff --git "a/ata\302\262/Cargo.toml" "b/ata\302\262/Cargo.toml" index 170dc9f..e1af5c3 100644 --- "a/ata\302\262/Cargo.toml" +++ "b/ata\302\262/Cargo.toml" @@ -1,8 +1,12 @@ [package] -name = "ata2" -version = "3.0.0" +name = "ata" +version = "3.1.0" edition = "2021" authors = ["Fredrick R. Brennan ", "Rik Huijzer ", "ATA Project Authors"] +homepage = "https://github.com/ctrlcctrlv/ata2" +repository = "https://github.com/ctrlcctrlv/ata2" +readme = "README.md" +description = "Ask the Terminal Anything² — ChatGPT¾ in your terminal" license = "Apache-2.0" [[bin]] @@ -28,6 +32,7 @@ once_cell = "1.18.0" atty = "0.2.14" async-openai = { version = "0.16.2", features = ["native-tls-vendored"] } futures-util = { version = "0.3.29", features = ["io"] } +tokio-stream = { version = "0.1.14", features = ["sync", "full"] } [dev-dependencies] pretty_assertions = "1" diff --git "a/ata\302\262/src/args.rs" "b/ata\302\262/src/args.rs" index a2f2242..0782b38 100644 --- "a/ata\302\262/src/args.rs" +++ "b/ata\302\262/src/args.rs" @@ -1,3 +1,5 @@ +//! Command-line argument parsing (using [`clap`]). +//! //! # ata² //! //! © 2023 Fredrick R. Brennan @@ -15,20 +17,14 @@ //! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //! See the License for the specific language governing permissions and //! limitations under the License. -//! -//! Ask the Terminal Anything (ATA): OpenAI GPT in the terminal use crate::config::ConfigLocation; use clap::Parser; - -use once_cell::sync::Lazy; - -#[allow(non_upper_case_globals)] -static AUTHORS: Lazy<&'static str> = Lazy::new(|| crate_authors!("\n\t")); +use clap::{crate_authors, crate_version}; #[derive(Parser, Debug)] -#[command(author = &*AUTHORS, version = crate_version!(), +#[command(author = crate_authors!(), version = crate_version!(), about, long_about = None, help_template = "{before-help}{name} {version} — {about}\ \n\n\ @@ -49,4 +45,8 @@ pub struct Ata2 { /// Print the keyboard shortcuts. #[arg(long)] pub print_shortcuts: bool, + + /// Conversation file to load. + #[arg(short = 'l', long = "load")] + pub load: Option, } diff --git "a/ata\302\262/src/config.rs" "b/ata\302\262/src/config.rs" index b0a4331..b0b6a23 100644 --- "a/ata\302\262/src/config.rs" +++ "b/ata\302\262/src/config.rs" @@ -1,3 +1,5 @@ +//! Configuration file parsing and validation. +//! //! # ata² //! //! © 2023 Fredrick R. Brennan @@ -18,6 +20,7 @@ use std::collections::HashMap as StdHashMap; use std::convert::Infallible; +use std::env; use std::ffi::OsString; use std::fmt::{self, Display}; @@ -59,7 +62,7 @@ pub struct UiConfig { pub history_file: PathBuf, } -/// For definitions, see https://platform.openai.com/docs/api-reference/completions/create +/// For definitions, see . #[repr(C)] #[derive(Clone, Deserialize, Debug, Serialize, Reflect)] #[serde(default)] @@ -76,15 +79,15 @@ pub struct Config { pub presence_penalty: f64, pub frequency_penalty: f64, pub logit_bias: HashMap, + pub user_id: Option, pub ui: UiConfig, } impl Config { pub fn validate(&self) -> Result<(), String> { - if let Some(api_key) = &self.api_key { - if api_key.is_empty() { - return Err(String::from("API key is empty")); - } + match self.api_key.as_ref().map(|s| s.as_str()) { + Some("") | None => return Err(String::from("API key is missing")), + _ => {} } if self.model.is_empty() { @@ -127,6 +130,11 @@ impl Config { )); } + match self.user_id.as_ref().map(|s| s.as_str()) { + Some("") => return Err(String::from("User ID cannot be an empty string")), + _ => {} + } + for (key, value) in &self.logit_bias { if value < &-2.0 || value > &2.0 { return Err(format!( @@ -140,35 +148,106 @@ impl Config { } } +/// Note: the result is heavily based on the environment variables. +/// +/// * `ATA2_MODEL` sets the model ID. Default: `gpt-3.5-turbo`. +/// * `ATA2_MAX_TOKENS` sets the maximum amount of tokens that the server can answer with. Longer answers will be truncated. Default: `2048`. +/// * `ATA2_TEMPERATURE`. Default: `0.8`. +/// * `ATA2_SUFFIX` sets the suffix. Default: `None`. +/// * `ATA2_TOP_P`. Default: `1.0`. +/// * `ATA2_N`. Default: `1`. +/// * `ATA2_STOP` sets the stop phrases. Default: `[]`. +/// * `ATA2_PRESENCE_PENALTY`. Default: `0.0`. +/// * `ATA2_FREQUENCY_PENALTY`. Default: `0.0`. +/// * `ATA2_LOGIT_BIAS` sets the logit bias. Default: `{}`. impl Default for Config { fn default() -> Self { Self { - model: "text-davinci-003".into(), - max_tokens: 16, - temperature: 0.5, - suffix: None, - top_p: 1.0, - n: 1, + model: env::var("ATA2_MODEL") + .ok() + .unwrap_or_else(|| "gpt-3.5-turbo".to_string()), + max_tokens: env::var("ATA2_MAX_TOKENS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(2048), + temperature: env::var("ATA2_TEMPERATURE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.8), + suffix: env::var("ATA2_SUFFIX").ok(), + top_p: env::var("ATA2_TOP_P") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1.0), + n: env::var("ATA2_N") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1), stream: true, - stop: vec![], - presence_penalty: 0.0, - frequency_penalty: 0.0, - logit_bias: HashMap::new(), - api_key: None, + stop: env::var("ATA2_STOP") + .ok() + .map(|s| serde_json::from_str(&s).unwrap()) + .unwrap_or_else(|| vec![]), + presence_penalty: env::var("ATA2_PRESENCE_PENALTY") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.0), + frequency_penalty: env::var("ATA2_FREQUENCY_PENALTY") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.0), + logit_bias: env::var("ATA2_LOGIT_BIAS") + .ok() + .map(|s| serde_json::from_str(&s).unwrap()) + .unwrap_or_else(|| HashMap::default()), + api_key: env::var("OPENAI_API_KEY").ok(), + user_id: env::var("ATA2_USER_ID").ok(), ui: UiConfig::default(), } } } +/// Note: the result is heavily based on the environment variables. +/// +/// * `ATA2_DOUBLE_CTRLC` sets whether to require user to press ^C twice. Default: `true`. +/// * `ATA2_HIDE_CONFIG` sets whether to hide config on run. Default: `false`. +/// * `ATA2_REDACT_API_KEY` sets whether to redact API key. Default: `true`. +/// * `ATA2_MULTILINE_INSERTIONS` sets whether to allow multiline insertions. Default: `true`. +/// * `ATA2_SAVE_HISTORY` sets whether to save history. Default: `true`. +/// * `ATA2_HISTORY_FILE` sets the history file. Default: `~/.config/ata2/history`. impl Default for UiConfig { fn default() -> Self { Self { - double_ctrlc: true, - hide_config: false, - redact_api_key: true, - multiline_insertions: false, - save_history: true, - history_file: PathBuf::from(get_config_dir::<2>().join("history")), + double_ctrlc: env::var("ATA2_DOUBLE_CTRLC") + .ok() + .map(|s| s.len() > 0) + .unwrap_or(true), + hide_config: env::var("ATA2_HIDE_CONFIG") + .ok() + .map(|s| s.len() > 0) + .unwrap_or(false), + redact_api_key: env::var("ATA2_REDACT_API_KEY") + .ok() + .map(|s| s.len() > 0) + .unwrap_or(true), + multiline_insertions: env::var("ATA2_MULTILINE_INSERTIONS") + .ok() + .map(|s| s.len() > 0) + .unwrap_or(true), + save_history: env::var("ATA2_SAVE_HISTORY") + .ok() + .map(|s| s.len() > 0) + .unwrap_or(true), + history_file: env::var("ATA2_HISTORY_FILE") + .ok() + .map(|s| PathBuf::from(s)) + .unwrap_or_else(|| { + get_config_dir::<2>() + .join("history") + .to_string_lossy() + .to_string() + .into() + }), } } } @@ -208,7 +287,7 @@ impl<'a> Into for &'a Config { if !self.stream { warn!("Stream is disabled. This is not supported anymore and will be ignored."); } - CreateChatCompletionRequestArgs::default() + let mut args = CreateChatCompletionRequestArgs::default() .n(self.n as u8) .model(&self.model) .max_tokens(self.max_tokens as u16) @@ -225,7 +304,13 @@ impl<'a> Into for &'a Config { .top_p(self.top_p as f32) .stop(self.stop.clone()) .stream(true) - .to_owned() + .to_owned(); + + if let Some(user_id) = &self.user_id { + args = args.user(user_id).to_owned(); + } + + args } } diff --git "a/ata\302\262/src/help.rs" "b/ata\302\262/src/help.rs" index 7b828c6..c8d3039 100644 --- "a/ata\302\262/src/help.rs" +++ "b/ata\302\262/src/help.rs" @@ -1,3 +1,5 @@ +//! Help messages for the command-line interface. +//! //! # ata² //! //! © 2023 Fredrick R. Brennan @@ -22,36 +24,11 @@ use crate::config; use config::DEFAULT_CONFIG_FILENAME; use std::fs::{self, File}; use std::io::Write as _; +use std::process::exit; pub fn commands() { - println!(" -Ctrl-A, Home Move cursor to the beginning of line -Ctrl-B, Left Move cursor one character left -Ctrl-E, End Move cursor to end of line -Ctrl-F, Right Move cursor one character right -Ctrl-H, Backspace Delete character before cursor -Ctrl-I, Tab Next completion -Ctrl-K Delete from cursor to end of line -Ctrl-L Clear screen -Ctrl-N, Down Next match from history -Ctrl-P, Up Previous match from history -Ctrl-X Ctrl-U Undo -Ctrl-Y Paste from Yank buffer (Meta-Y to paste next yank instead) -Meta-< Move to first entry in history -Meta-> Move to last entry in history -Meta-B, Alt-Left Move cursor to previous word -Meta-C Capitalize the current word -Meta-D Delete forwards one word -Meta-F, Alt-Right Move cursor to next word -Meta-L Lower-case the next word -Meta-T Transpose words -Meta-U Upper-case the next word -Meta-Y See Ctrl-Y -Meta-Backspace Kill from the start of the current word, or, if between words, to the start of the previous word -Meta-0, 1, ..., - Specify the digit to the argument. – starts a negative argument. - -Thanks to . - "); + println!(include_str!("help/keybindings.txt")); + exit(0); } const EXAMPLE_TOML: &str = r#"api_key = "" @@ -107,5 +84,5 @@ The `temperature` sets the `sampling temperature`. From the OpenAI API docs: "Wh .expect("Unable to write to file"); } } - std::process::exit(1); + exit(1); } diff --git "a/ata\302\262/src/help/keybindings.txt" "b/ata\302\262/src/help/keybindings.txt" new file mode 100644 index 0000000..ac4334f --- /dev/null +++ "b/ata\302\262/src/help/keybindings.txt" @@ -0,0 +1,35 @@ +Keyboard shortcuts: +ata²-specific: +Ctrl-D, EOF (In multiline mode) Send the current message. +F2 Save the current conversation (not including the message + you're typing) to a file. + +rustyline: +Ctrl-A, Home Move cursor to the beginning of line +Ctrl-B, Left Move cursor one character left +Ctrl-E, End Move cursor to end of line +Ctrl-F, Right Move cursor one character right +Ctrl-H, Backspace Delete character before cursor +Ctrl-I, Tab Next completion +Ctrl-K Delete from cursor to end of line +Ctrl-L Clear screen +Ctrl-N, Down Next match from history +Ctrl-P, Up Previous match from history +Ctrl-X Ctrl-U Undo +Ctrl-Y Paste from Yank buffer (Meta-Y to paste next yank instead) +Meta-< Move to first entry in history +Meta-> Move to last entry in history +Meta-B, Alt-Left Move cursor to previous word +Meta-C Capitalize the current word +Meta-D Delete forwards one word +Meta-F, Alt-Right Move cursor to next word +Meta-L Lower-case the next word +Meta-T Transpose words +Meta-U Upper-case the next word +Meta-Y See Ctrl-Y +Meta-Backspace Kill from the start of the current word, or, if between + words, to the start of the previous word +Meta-0, 1, ..., - Specify the digit to the argument. – starts a negative + argument. + +Thanks to . diff --git "a/ata\302\262/src/main.rs" "b/ata\302\262/src/main.rs" index 665303c..ef5d7dd 100644 --- "a/ata\302\262/src/main.rs" +++ "b/ata\302\262/src/main.rs" @@ -1,4 +1,4 @@ -//! # ata² +//! # ata² — Ask the Terminal Anything² //! //! © 2023 Fredrick R. Brennan //! © 2023 Rik Huijzer @@ -16,80 +16,47 @@ //! See the License for the specific language governing permissions and //! limitations under the License. #[macro_use] -extern crate clap; -#[macro_use] extern crate lazy_static; #[macro_use] extern crate log; mod args; +pub use crate::args::Ata2; mod config; +pub use crate::config::Config; mod help; mod prompt; +use crate::prompt::load_conversation; +mod readline; +mod state; +pub use crate::state::*; use ansi_colors::ColouredStr; -use clap::Parser as _; -use futures_util::lock::Mutex; +use futures_util::future::FutureExt as _; use futures_util::task::Context; -use futures_util::FutureExt; -use rustyline::{error::ReadlineError, Cmd, Editor, KeyCode, KeyEvent, Modifiers}; -use tokio::sync::mpsc::{Receiver, Sender}; -use tokio::task::JoinHandle; +use futures_util::task::Poll; +use std::error::Error; use std::fs::File; -use std::io::Read; -use std::fs; +use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; -use std::sync::atomic::{AtomicBool, AtomicUsize}; use std::sync::Arc; -use std::task::Poll; use std::time::Duration; -use crate::args::Ata2; -use crate::config::Config; -use crate::prompt::TokioResult; - +pub type TokioResult> = Result; #[tokio::main] -pub async fn main() -> prompt::TokioResult<()> { - init_logger(); - let flags: Ata2 = Ata2::parse(); - if flags.print_shortcuts { - help::commands(); - return Ok(()); +pub async fn main() -> TokioResult<()> { + if EXIT.load(Ordering::Acquire) { + std::process::exit(0); + } else { + init_logger(); } - let filename = flags.config.location(); - if !filename.exists() { - let v1_filename = flags.config.location_v1(); - if v1_filename.exists() { - fs::create_dir_all(&config::default_path::<2>(None).parent().unwrap()) - .expect("Could not make configuration directory"); - fs::copy(&v1_filename, &filename).expect(&format!( - "Failed to copy {} to {}", - v1_filename.to_string_lossy(), - filename.to_string_lossy() - )); - warn!( - "{}", - &format!( - "Copied old configuration file to ata¹'s location {}", - filename.to_string_lossy() - ), - ); - } else { - help::missing_toml(); - } + if FLAGS.load.is_some() { + load_conversation(FLAGS.load.as_ref().unwrap()).await?; } - let mut contents = String::new(); - File::open(filename) - .unwrap() - .read_to_string(&mut contents) - .expect("Could not read configuration file"); - - let config = Arc::new(Config::from(&contents)); - let config_clone = config.clone(); - let config_clone2 = config.clone(); - let had_first_interrupt: AtomicBool = AtomicBool::new(false); + let mut rl = readline::Readline::new(); + let config = CONFIGURATION.clone(); config.validate().unwrap_or_else(|e| { error!("Config error!: {e}. Dying."); panic!() @@ -102,22 +69,11 @@ pub async fn main() -> prompt::TokioResult<()> { eprint!("{}", header); } - if !flags.hide_config && !config.ui.hide_config && atty::is(atty::Stream::Stderr) { + if !FLAGS.hide_config && !config.ui.hide_config && atty::is(atty::Stream::Stderr) { eprintln!("{config}"); } - let mut rl = Editor::<()>::new()?; - if config.ui.multiline_insertions { - if atty::is(atty::Stream::Stdin) { - // Cmd::Newline inserts a newline, Cmd::AcceptLine accepts the line - rl.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::NONE), Cmd::Newline); - rl.bind_sequence( - KeyEvent(KeyCode::Char('d'), Modifiers::CTRL), - Cmd::AcceptLine, - ); - } - } if atty::is(atty::Stream::Stdin) && config.ui.save_history { - if rl.load_history(&config.ui.history_file).is_err() { + if rl.load_history().await.is_err() { warn!("No history file found. Creating a new one."); File::create(&config.ui.history_file).unwrap_or_else(|e| { error!("Could not create history file: {e}"); @@ -126,17 +82,13 @@ pub async fn main() -> prompt::TokioResult<()> { }); } } + rl.enable_multiline().await; + rl.enable_request_save().await; // use tokio asynchronous message queue - let (tx, mut rx): (Sender>, Receiver>) = + let (tx, mut rx): (tokio::sync::mpsc::Sender>, _) = tokio::sync::mpsc::channel(1); - let is_running = Arc::new(AtomicBool::new(false)); - let is_running_clone = is_running.clone(); - let abort = Arc::new(AtomicBool::new(false)); - let abort_clone = abort.clone(); let handle = tokio::spawn(async move { - let abort = abort_clone.clone(); - let is_running = is_running.clone(); let n_pending_debug_log_notices = Arc::new(AtomicUsize::new(0)); loop { let msg = Box::pin(rx.recv()).poll_unpin(&mut Context::from_waker( @@ -144,14 +96,7 @@ pub async fn main() -> prompt::TokioResult<()> { )); match msg { Poll::Ready(Some(Some(line))) => { - let result = prompt::request( - abort.clone(), - is_running.clone(), - &config_clone, - line.to_string(), - 0, - ) - .await; + let result = prompt::request(line.to_string(), 0).await; match result { Ok(_) => {} Err(e) => { @@ -160,23 +105,23 @@ pub async fn main() -> prompt::TokioResult<()> { } n_pending_debug_log_notices.store(0, Ordering::SeqCst); } - Poll::Ready(Some(None) | None) => { + Poll::Ready(Some(None)) => { n_pending_debug_log_notices.store(0, Ordering::SeqCst); + info!("Got None in API request loop, exiting"); break; } - Poll::Pending => { + Poll::Ready(None) | Poll::Pending => { // All the next 20 or so lines are just for debug logging… { - n_pending_debug_log_notices.fetch_add(1, Ordering::SeqCst); - let n = n_pending_debug_log_notices.load(Ordering::SeqCst); + let n = n_pending_debug_log_notices.fetch_add(1, Ordering::SeqCst); static MAX_PENDING_DEBUG_LOG_NOTICES: usize = 10; macro_rules! PENDING_LOOP_MSG { () => { - "Got pending in API request loop, waiting 100ms ({n}/{max})" + "Got pending in API request loop, waiting 10ms ({n}/{max})" }; ($msg:expr) => { concat!( - "Got pending in API request loop, waiting 100ms ({n}/{max}): ", + "Got pending in API request loop, waiting 10ms ({n}/{max}): ", $msg ) }; @@ -200,94 +145,30 @@ pub async fn main() -> prompt::TokioResult<()> { } } // …and now we're done with debug logging. - tokio::time::sleep(Duration::from_millis(100)).await; + tokio::time::sleep(Duration::from_millis(10)).await; continue; } } } }); - let rl = Arc::new(Mutex::new(rl)); - let rl_clone = rl.clone(); - let readline_handle: JoinHandle> = tokio::spawn(async move { - // If stdin is not a tty, we want to read once to the end of it and then exit. - let mut already_read = false; - let mut stdin = std::io::stdin(); - prompt::print_prompt(); - while !abort.load(Ordering::Relaxed) { - // lock Readlien - let mut rl = rl.lock().await; - // Using an empty prompt text because otherwise the user would - // "see" that the prompt is ready again during response printing. - // Also, the current readline is cleared in some cases by rustyline, - // so being on a newline is the only way to avoid that. - let readline = if atty::is(atty::Stream::Stdin) { - rl.readline("") - } else if !already_read { - let mut buf = String::with_capacity(1024); - stdin.read_to_string(&mut buf)?; - already_read = true; - Ok(buf) - } else { - Err(ReadlineError::Eof)? - }; - match readline { - Ok(line) => { - if is_running_clone.load(Ordering::SeqCst) { - abort.store(true, Ordering::SeqCst); - } - if line.is_empty() { - continue; - } - rl.add_history_entry(line.as_str()); - tx.send(Some(line)).await?; - had_first_interrupt.store(false, Ordering::Relaxed); - } - Err(ReadlineError::Interrupted) => { - if is_running_clone.load(Ordering::SeqCst) { - abort.store(true, Ordering::SeqCst); - } else { - if config.ui.double_ctrlc && !had_first_interrupt.load(Ordering::Relaxed) { - had_first_interrupt.store(true, Ordering::Relaxed); - eprintln!("\nPress Ctrl-C again to exit."); - tokio::time::sleep(Duration::from_millis(1000)).await; - eprintln!(); - prompt::print_prompt(); - continue; - } else { - tx.send(None).await?; - break; - } - } - } - Err(ReadlineError::Eof) => { - tx.send(None).await?; - break; - } - Err(err) => { - eprintln!("{err:?}"); - tx.send(None).await?; - break; - } - } - } - return Ok(()); - }); + let readline_handle = rl.handle(tx).await; tokio::select! { - _ = readline_handle => {} - _ = handle => {} + _ = readline_handle => { + info!("Readline died"); + } + _ = handle => { + info!("API request loop died"); + } } - if atty::is(atty::Stream::Stdin) && config_clone2.ui.save_history { - let mut rl_clone = rl_clone.lock().await; - rl_clone - .save_history(&config_clone2.ui.history_file) - .unwrap_or_else(|e| error!("Could not save history: {e}")); + if atty::is(atty::Stream::Stdin) && config.ui.save_history { + rl.save_history().await?; info!( "Saved history to {history_file}. Number of entries: {entries}", - history_file = config_clone2.ui.history_file.to_string_lossy(), - entries = rl_clone.history().len() + history_file = config.ui.history_file.to_string_lossy(), + entries = rl.history_len().await ); } @@ -295,7 +176,7 @@ pub async fn main() -> prompt::TokioResult<()> { } fn init_logger() { - let env = env_logger::Env::default().default_filter_or("warn"); + let env = env_logger::Env::default().default_filter_or("info"); env_logger::Builder::from_env(env) .format_timestamp(None) .init(); diff --git "a/ata\302\262/src/prompt.rs" "b/ata\302\262/src/prompt.rs" index 4827126..c854943 100644 --- "a/ata\302\262/src/prompt.rs" +++ "b/ata\302\262/src/prompt.rs" @@ -1,3 +1,5 @@ +//! REPL +//! //! # ata² //! //! © 2023 Fredrick R. Brennan @@ -20,36 +22,65 @@ use ansi_colors::ColouredStr; use async_openai::{ config::OpenAIConfig, types::{ - ChatCompletionRequestUserMessageArgs, ChatCompletionResponseStreamMessage, + ChatCompletionRequestMessage, ChatCompletionResponseStreamMessage, CreateChatCompletionRequestArgs, FinishReason, }, Client, }; use atty; use log::debug; +use tokio::sync::Mutex; +use tokio_stream::StreamExt as _; -use futures_util::StreamExt as _; - -use std::error::Error; -use std::io::Write; -use std::result::Result; +use std::io::{self, Stderr, Stdout}; +use std::io::{Read as _, Write as _}; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; -pub type TokioResult> = Result; +use crate::readline::{ + string_to_chat_completion_assistant_message, string_to_chat_completion_request_user_message, +}; +use crate::TokioResult; +use crate::ABORT; +use crate::CONFIGURATION; +use crate::IS_RUNNING; + +lazy_static! { + static ref STDOUT: Stdout = io::stdout(); + static ref STDERR: Stderr = io::stderr(); + pub static ref CONVERSATION: Mutex> = Mutex::new(vec![]); +} + +pub async fn load_conversation>(path: P) -> TokioResult<()> { + let mut file = std::fs::File::open(path)?; + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + let lines = contents.split("\n").collect::>(); + let mut conversation = CONVERSATION.lock().await; + let loaded_conversation = serde_json::from_str::>( + &lines + .into_iter() + .filter(|o| !o.is_empty()) + .collect::>() + .join("\n"), + )?; + conversation.clear(); + conversation.extend(loaded_conversation); + Ok(()) +} fn print_and_flush(text: &str) { print!("{text}"); - std::io::stdout().flush().unwrap(); + (&*STDOUT).flush().unwrap(); } fn eprint_and_flush(text: &str) { eprint!("{text}"); - std::io::stderr().flush().unwrap(); + (&*STDERR).flush().unwrap(); } -pub fn eprint_bold(msg: &str) { +fn eprint_bold(msg: &str) { if atty::is(atty::Stream::Stderr) { let mut bold = ColouredStr::new(msg); bold.bold(); @@ -62,25 +93,24 @@ pub fn eprint_bold(msg: &str) { pub fn print_prompt() { if atty::is(atty::Stream::Stderr) { - eprint_bold("Prompt:\n"); + eprint_bold("\nPrompt:\n"); } } fn print_response_prompt() { if atty::is(atty::Stream::Stderr) { - eprint_bold("Response:\n"); + eprint_bold("\nResponse:\n"); } } -fn finish_prompt(is_running: Arc) { - is_running.store(false, Ordering::SeqCst); - eprint_and_flush("\n\n"); +fn finish_prompt() { + IS_RUNNING.store(false, Ordering::SeqCst); print_prompt(); } -pub fn print_error(is_running: Arc, msg: &str) { +fn print_error(msg: &str) { error!("{msg}"); - finish_prompt(is_running) + finish_prompt() } fn store_and_do_nothing(print_buffer: &mut Vec, text: &str) -> String { @@ -113,31 +143,40 @@ fn post_process(print_buffer: &mut Vec, text: &str) -> String { } pub async fn request( - abort: Arc, - is_running: Arc, - config: &super::Config, prompt: String, _count: i64, ) -> TokioResult> { let mut print_buffer: Vec = Vec::new(); + let config = &*CONFIGURATION.to_owned(); let oconfig: OpenAIConfig = config.into(); let openai = Client::with_config(oconfig); let completions = openai.chat(); - let mut args: CreateChatCompletionRequestArgs = config.into(); - args.messages(vec![ChatCompletionRequestUserMessageArgs::default() - .content(prompt) - .build()? - .into()]); - let mut stream = completions.create_stream(args.build()?).await?; - is_running.store(true, Ordering::SeqCst); + let messages = { + CONVERSATION + .lock() + .await + .push(string_to_chat_completion_request_user_message( + prompt.clone(), + )); + CONVERSATION + .lock() + .await + .clone() + .into_iter() + .collect::>() + }; + let mut request: CreateChatCompletionRequestArgs = config.into(); + let mut stream = completions + .create_stream(request.messages(messages).build()?) + .await?; + IS_RUNNING.store(true, Ordering::SeqCst); let got_first_success: Arc = Arc::new(AtomicBool::new(false)); let mut ret = vec![]; - 'abort: while !abort.load(Ordering::Relaxed) { - 'outer: while let Some(completion) = stream.next().await { - debug!("Completion: {:?}", completion); - match completion { + 'abort: while !ABORT.load(Ordering::Relaxed) { + while let Some(c) = stream.next().await { + match c { Ok(completion) => { let completion = Arc::new(completion); ret.push(completion.clone()); @@ -146,49 +185,50 @@ pub async fn request( print_response_prompt(); } for choice in &completion.choices { - if abort.load(Ordering::Relaxed) { + if ABORT.load(Ordering::Relaxed) { break 'abort; } + match choice.delta.content { + Some(ref text) => { + let newline_fixed = post_process(&mut print_buffer, &text); + print_and_flush(&newline_fixed); + } + None => {} + } match choice.finish_reason { Some(FinishReason::Stop) => { debug!("Got stop from API, returning to REPL"); + IS_RUNNING.store(false, Ordering::SeqCst); break 'abort; } Some(reason) => { let msg = format!("OpenAI API error: {reason:?}"); - print_error(is_running.clone(), &msg); + print_error(&msg); continue 'abort; } None => {} } - match choice.delta.content { - Some(ref text) => { - let newline_fixed = post_process(&mut print_buffer, &text); - print_and_flush(&newline_fixed); - } - None => { - continue 'outer; - } - } } } Err(e) => { let msg = format!("OpenAI API error: {e}"); - print_error(is_running.clone(), &msg); - continue 'abort; + print_error(&msg); + break 'abort; } } } + debug!("Got end of stream, returning to REPL"); + IS_RUNNING.store(false, Ordering::SeqCst); break 'abort; } + eprint_and_flush("\n"); if !got_first_success.load(Ordering::SeqCst) { let msg = format!("Empty prompt, aborting."); - print_error(is_running.clone(), &msg); + print_error(&msg); return Ok(vec![]); } - print_and_flush("\n"); let result = ret .drain(..) .map(|o| Arc::new(o.choices.clone().into_iter().collect::>())) @@ -203,7 +243,18 @@ pub async fn request( .flatten() .collect::>(); - is_running.store(false, Ordering::SeqCst); - finish_prompt(is_running); + let complete_message = result.iter().map(|o| o.delta.clone()).collect::>(); + + let assistant_msg = string_to_chat_completion_assistant_message( + complete_message + .into_iter() + .map(|o| o.content.unwrap_or_else(String::new)) + .collect::>() + .join(""), + ); + (*CONVERSATION).lock().await.push(assistant_msg); + + IS_RUNNING.store(false, Ordering::SeqCst); + finish_prompt(); Ok(result) } diff --git "a/ata\302\262/src/readline.rs" "b/ata\302\262/src/readline.rs" new file mode 100644 index 0000000..2ba5577 --- /dev/null +++ "b/ata\302\262/src/readline.rs" @@ -0,0 +1,213 @@ +//! a wrapper around rustyline +//! +//! (rustyline is a readline-like library for Rust) +//! +//! # ata² +//! +//! © 2023 Fredrick R. Brennan +//! © 2023 Rik Huijzer +//! © 2023– ATA Project Authors +//! +//! Licensed under the Apache License, Version 2.0 (the "License"); +//! you may _not_ use this file except in compliance with the License. +//! You may obtain a copy of the License at +//! +//! http://www.apache.org/licenses/LICENSE-2.0 +//! +//! Unless required by applicable law or agreed to in writing, software +//! distributed under the License is distributed on an "AS IS" BASIS, +//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//! See the License for the specific language governing permissions and +//! limitations under the License. + +use async_openai::types::{ + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, Role, +}; +use futures_util::lock::Mutex; +use rustyline::error::ReadlineError; +use rustyline::{ + Cmd, ConditionalEventHandler, Editor, EventContext, EventHandler, KeyCode, KeyEvent, Modifiers, + RepeatCount, +}; +use std::future::IntoFuture; +use std::io::Read as _; +use std::io::Write as _; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::mpsc::Sender; +use tokio::task::JoinHandle; + +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use crate::prompt::{self, CONVERSATION}; +use crate::TokioResult; +use crate::ABORT; +use crate::CONFIGURATION as config; +use crate::HAD_FIRST_INTERRUPT; + +pub fn string_to_chat_completion_request_user_message( + string: String, +) -> ChatCompletionRequestMessage { + ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { + role: Role::User, + content: Some(ChatCompletionRequestUserMessageContent::Text(string)), + ..Default::default() + }) +} + +pub fn string_to_chat_completion_assistant_message(string: String) -> ChatCompletionRequestMessage { + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + role: Role::Assistant, + content: Some(string), + ..Default::default() + }) +} + +pub struct Readline { + pub rl: Arc>>, +} + +impl Readline { + pub fn new() -> Self { + let rl = Editor::<()>::new().unwrap(); + Self { + rl: Arc::new(Mutex::new(rl)), + } + } +} + +use futures_util::FutureExt as _; + +struct RequestSaveHandler; +impl ConditionalEventHandler for RequestSaveHandler { + fn handle( + &self, + _event: &rustyline::Event, + _n: RepeatCount, + _positive: bool, + _: &EventContext, + ) -> Option { + let convo = CONVERSATION.lock().into_future(); + let convo = convo.now_or_never().unwrap(); + let convo = convo.clone(); + let convo_json = serde_json::to_string(&convo).unwrap(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + // as unix secs + let filename = format!("conversation-{}.json", now); + let _ = std::fs::remove_file(&filename); + let convo_file = std::fs::File::create(&filename).unwrap(); + let mut convo_file = std::io::BufWriter::new(convo_file); + convo_file.write_all(convo_json.as_bytes()).unwrap(); + info!("Saved conversation to {filename}"); + Some(Cmd::Noop) + } +} + +impl Readline { + pub async fn handle(&mut self, tx: Sender>) -> JoinHandle> { + let rl = self.rl.clone(); + let readline_handle: JoinHandle> = tokio::spawn(async move { + // If stdin is not a tty, we want to read once to the end of it and then exit. + let mut already_read = false; + let mut stdin = std::io::stdin(); + prompt::print_prompt(); + while !ABORT.load(Ordering::Relaxed) { + // lock Readlien + let mut rl = rl.lock().await; + // Using an empty prompt text because otherwise the user would + // "see" that the prompt is ready again during response printing. + // Also, the current readline is cleared in some cases by rustyline, + // so being on a newline is the only way to avoid that. + let readline = if atty::is(atty::Stream::Stdin) { + rl.readline("") + } else if !already_read { + let mut buf = String::with_capacity(1024); + stdin.read_to_string(&mut buf)?; + already_read = true; + Ok(buf) + } else { + Err(ReadlineError::Eof)? + }; + match readline { + Ok(line) => { + if line.is_empty() { + continue; + } + rl.add_history_entry(line.as_str()); + tx.send(Some(line)).await?; + HAD_FIRST_INTERRUPT.store(false, Ordering::Relaxed); + } + Err(ReadlineError::Interrupted) => { + if config.ui.double_ctrlc && !HAD_FIRST_INTERRUPT.load(Ordering::Relaxed) { + HAD_FIRST_INTERRUPT.store(true, Ordering::Relaxed); + eprint!("\nPress Ctrl-C again to exit."); + prompt::print_prompt(); + continue; + } else { + tx.send(None).await?; + ABORT.store(true, Ordering::Relaxed); + break; + } + } + Err(ReadlineError::Eof) => { + HAD_FIRST_INTERRUPT.store(false, Ordering::Relaxed); + tx.send(None).await?; + break; + } + Err(err) => { + eprintln!("{err:?}"); + tx.send(None).await?; + break; + } + } + } + return Ok(()); + }); + readline_handle + } + + pub async fn enable_multiline(&mut self) { + let mut rl = self.rl.lock().await; + if config.ui.multiline_insertions { + if atty::is(atty::Stream::Stdin) { + // Cmd::Newline inserts a newline, Cmd::AcceptLine accepts the line + rl.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::NONE), Cmd::Newline); + rl.bind_sequence( + KeyEvent(KeyCode::Char('d'), Modifiers::CTRL), + Cmd::AcceptLine, + ); + } + } + } + + pub async fn enable_request_save(&mut self) { + let mut rl = self.rl.lock().await; + if atty::is(atty::Stream::Stdin) { + rl.bind_sequence( + KeyEvent(KeyCode::F(2), Modifiers::NONE), + EventHandler::Conditional(Box::new(RequestSaveHandler)), + ); + } + } + + pub async fn save_history(&mut self) -> TokioResult<()> { + let mut rl = self.rl.lock().await; + rl.save_history(&config.ui.history_file)?; + Ok(()) + } + + pub async fn load_history(&mut self) -> TokioResult<()> { + let mut rl = self.rl.lock().await; + rl.load_history(&config.ui.history_file)?; + Ok(()) + } + + pub async fn history_len(&mut self) -> usize { + let rl = self.rl.lock().await; + rl.history().len() + } +} diff --git "a/ata\302\262/src/state.rs" "b/ata\302\262/src/state.rs" new file mode 100644 index 0000000..f549c51 --- /dev/null +++ "b/ata\302\262/src/state.rs" @@ -0,0 +1,76 @@ +//! the global state of ata² +//! +//! # ata² +//! +//! © 2023 Fredrick R. Brennan +//! © 2023 Rik Huijzer +//! © 2023– ATA Project Authors +//! +//! Licensed under the Apache License, Version 2.0 (the "License"); +//! you may _not_ use this file except in compliance with the License. +//! You may obtain a copy of the License at +//! +//! http://www.apache.org/licenses/LICENSE-2.0 +//! +//! Unless required by applicable law or agreed to in writing, software +//! distributed under the License is distributed on an "AS IS" BASIS, +//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//! See the License for the specific language governing permissions and +//! limitations under the License. + +use clap::Parser as _; + +use crate::args::Ata2; +use crate::config::{self, Config}; +use crate::help; + +use std::fs; +use std::fs::File; +use std::io::Read as _; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +lazy_static! { + pub static ref FLAGS: Ata2 = Ata2::parse(); + pub static ref EXIT: Arc = Arc::new(AtomicBool::new(false)); + pub static ref CONFIGURATION: Arc = { + if FLAGS.print_shortcuts { + help::commands(); + EXIT.store(true, Ordering::Relaxed); + } + let filename = FLAGS.config.location(); + if !filename.exists() { + let v1_filename = FLAGS.config.location_v1(); + if v1_filename.exists() { + fs::create_dir_all(&config::default_path::<2>(None).parent().unwrap()) + .expect("Could not make configuration directory"); + fs::copy(&v1_filename, &filename).expect(&format!( + "Failed to copy {} to {}", + v1_filename.to_string_lossy(), + filename.to_string_lossy() + )); + warn!( + "{}", + &format!( + "Copied old configuration file to ata¹'s location {}", + filename.to_string_lossy() + ), + ); + } else { + help::missing_toml(); + } + } + let mut contents = String::new(); + File::open(filename) + .unwrap() + .read_to_string(&mut contents) + .expect("Could not read configuration file"); + + let config_ = Arc::new(Config::from(&contents)); + config_ + }; + pub static ref ABORT: Arc = Arc::new(AtomicBool::new(false)); + pub static ref IS_RUNNING: Arc = Arc::new(AtomicBool::new(false)); + pub static ref HAD_FIRST_INTERRUPT: Arc = Arc::new(AtomicBool::new(false)); +}