diff --git a/Cargo.toml b/Cargo.toml index c7a21b5..e29ecd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aleph-alpha-client" -version = "0.10.1" +version = "0.11.0" edition = "2021" description = "Interact with large language models provided by the Aleph Alpha API in Rust code" license = "MIT" diff --git a/Changelog.md b/Changelog.md index 741fa93..2a011ee 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,12 @@ ## Unreleased +## 0.11.0 + +* Add `with_maximum_tokens` method to `Prompt` +* Remove maximum tokens argument from `Prompt::from_text` +* Make maximum tokens optional + ## 0.10.1 * Fix: Version number in Cargo.toml diff --git a/README.md b/README.md index ccd4531..54f35a5 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ fn main() { let model = "luminous-base"; // The task we want to perform. Here we want to continue the sentence: "An apple a day ..." - let task = TaskCompletion::from_text("An apple a day", 10); + let task = TaskCompletion::from_text("An apple a day"); // Retrieve the answer from the API let response = client.completion(&task, model, &How::default()).await.unwrap(); diff --git a/src/completion.rs b/src/completion.rs index 3b68ac2..d59c5b7 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -14,15 +14,19 @@ pub struct TaskCompletion<'a> { } impl<'a> TaskCompletion<'a> { - /// Convenience constructor leaving most setting to default, just completing a given text and - /// taking the maximum anticipated length of the completion. - pub fn from_text(text: &'a str, maximum_tokens: u32) -> Self { + /// Convenience constructor leaving most setting to default, just completing a given text + pub fn from_text(text: &'a str) -> Self { TaskCompletion { prompt: Prompt::from_text(text), - stopping: Stopping::from_maximum_tokens(maximum_tokens), + stopping: Stopping::NO_TOKEN_LIMIT, sampling: Sampling::MOST_LIKELY, } } + + pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self { + self.stopping.maximum_tokens = Some(maximum_tokens); + self + } } /// Sampling controls how the tokens ("words") are selected for the completion. @@ -69,10 +73,13 @@ impl Default for Sampling<'_> { /// Controls the conditions under which the language models stops generating text. pub struct Stopping<'a> { /// The maximum number of tokens to be generated. Completion will terminate after the maximum - /// number of tokens is reached.Increase this value to allow for longer outputs. A text is split + /// number of tokens is reached. Increase this value to allow for longer outputs. A text is split /// into tokens. Usually there are more tokens than words. The total number of tokens of prompt /// and maximum_tokens depends on the model. - pub maximum_tokens: u32, + /// If maximum tokens is set to None, no outside limit is opposed on the number of maximum tokens. + /// The model will generate tokens until it either emits a stop token or it reaches its technical + /// limit, which usually is its context window. + pub maximum_tokens: Option, /// List of strings which will stop generation if they are generated. Stop sequences are /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of /// lines starting with either "Question: " or "Answer: " (alternating). After producing an @@ -83,15 +90,28 @@ pub struct Stopping<'a> { } impl<'a> Stopping<'a> { + /// Only stop once the model generates end of text, or it reaches its technical limit, usually the + /// context window. + pub const NO_TOKEN_LIMIT: Self = Stopping { + maximum_tokens: None, + stop_sequences: &[], + }; + /// Only stop once the model generates end of text, or maximum tokens are reached. pub fn from_maximum_tokens(maximum_tokens: u32) -> Self { Self { - maximum_tokens, + maximum_tokens: Some(maximum_tokens), stop_sequences: &[], } } } +impl Default for Stopping<'_> { + fn default() -> Self { + Self::NO_TOKEN_LIMIT + } +} + /// Body send to the Aleph Alpha API on the POST `/completion` Route #[derive(Serialize, Debug)] struct BodyCompletion<'a> { @@ -100,7 +120,8 @@ struct BodyCompletion<'a> { /// Prompt to complete. The modalities supported depend on `model`. pub prompt: Prompt<'a>, /// Limits the number of tokens, which are generated for the completion. - pub maximum_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum_tokens: Option, /// List of strings which will stop generation if they are generated. Stop sequences are /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of /// lines starting with either "Question: " or "Answer: " (alternating). After producing an diff --git a/src/http.rs b/src/http.rs index 0bba6f9..84e37b2 100644 --- a/src/http.rs +++ b/src/http.rs @@ -114,7 +114,7 @@ impl HttpClient { /// /// // The task we want to perform. Here we want to continue the sentence: "An apple a day /// // ..." - /// let task = TaskCompletion::from_text("An apple a day", 10); + /// let task = TaskCompletion::from_text("An apple a day"); /// /// // Retrieve answer from API /// let response = client.output_of(&task.with_model(model), &How::default()).await?; diff --git a/src/lib.rs b/src/lib.rs index 251105e..9bee967 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ //! let model = "luminous-base"; //! //! // The task we want to perform. Here we want to continue the sentence: "An apple a day ..." -//! let task = TaskCompletion::from_text("An apple a day", 10); +//! let task = TaskCompletion::from_text("An apple a day"); //! //! // Retrieve the answer from the API //! let response = client.completion(&task, model, &How::default()).await.unwrap(); @@ -65,7 +65,7 @@ impl Client { /// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API. /// For "normal" client applications you may likely rather use [`Self::with_authentication`] or /// [`Self::with_base_url`]. - /// + /// /// You may want to only use request based authentication and skip default authentication. This /// is useful if writing an application which invokes the client on behalf of many different /// users. Having neither request, nor default authentication is considered a bug and will cause @@ -81,7 +81,7 @@ impl Client { } /// Use your on-premise inference with your API token for all requests. - /// + /// /// In production you typically would want set this to . Yet /// you may want to use a different instances for testing. pub fn with_base_url(host: String, api_token: impl Into) -> Result { @@ -103,7 +103,7 @@ impl Client { /// /// // The task we want to perform. Here we want to continue the sentence: "An apple a day /// // ..." - /// let task = TaskCompletion::from_text("An apple a day", 10); + /// let task = TaskCompletion::from_text("An apple a day"); /// /// // Retrieve answer from API /// let response = client.execute(model, &task, &How::default()).await?; @@ -167,7 +167,7 @@ impl Client { /// /// // The task we want to perform. Here we want to continue the sentence: "An apple a day /// // ..." - /// let task = TaskCompletion::from_text("An apple a day", 10); + /// let task = TaskCompletion::from_text("An apple a day"); /// /// // Retrieve answer from API /// let response = client.completion(&task, model, &How::default()).await?; diff --git a/tests/integration.rs b/tests/integration.rs index 23e5713..95aec72 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -21,7 +21,7 @@ fn api_token() -> &'static str { #[tokio::test] async fn completion_with_luminous_base() { // When - let task = TaskCompletion::from_text("Hello", 1); + let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_authentication(api_token()).unwrap(); @@ -39,7 +39,7 @@ async fn completion_with_luminous_base() { #[tokio::test] async fn request_authentication_has_priority() { let bad_aa_api_token = "DUMMY"; - let task = TaskCompletion::from_text("Hello", 1); + let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_authentication(bad_aa_api_token).unwrap(); @@ -64,7 +64,7 @@ async fn request_authentication_has_priority() { async fn authentication_only_per_request() { // Given let model = "luminous-base"; - let task = TaskCompletion::from_text("Hello", 1); + let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap(); @@ -88,7 +88,7 @@ async fn authentication_only_per_request() { async fn must_panic_if_authentication_is_missing() { // Given let model = "luminous-base"; - let task = TaskCompletion::from_text("Hello", 1); + let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1); // When let client = Client::new("https://api.aleph-alpha.com".to_owned(), None).unwrap(); @@ -172,7 +172,7 @@ async fn complete_structured_prompt() { let task = TaskCompletion { prompt: Prompt::from_text(prompt), stopping: Stopping { - maximum_tokens: 64, + maximum_tokens: Some(64), stop_sequences: &stop_sequences[..], }, sampling: Sampling::MOST_LIKELY, @@ -190,6 +190,29 @@ async fn complete_structured_prompt() { assert!(!response.completion.contains("User:")); } +#[tokio::test] +async fn context_window_stopping() { + // Given + let prompt = "Bot: Hello user!\nUser: Hello Bot, how are you doing?\nBot:"; + let stopping = Stopping::NO_TOKEN_LIMIT; + + // When + let task = TaskCompletion { + prompt: Prompt::from_text(prompt), + stopping, + sampling: Sampling::MOST_LIKELY, + }; + let model = "luminous-base"; + let client = Client::with_authentication(api_token()).unwrap(); + let response = client + .output_of(&task.with_model(model), &How::default()) + .await + .unwrap(); + + // Then + assert!(!response.completion.is_empty()); +} + #[tokio::test] async fn explain_request() { // Given diff --git a/tests/unit.rs b/tests/unit.rs index e6b2726..e29c6a7 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -32,7 +32,7 @@ async fn completion_with_luminous_base() { .await; // When - let task = TaskCompletion::from_text("Hello,", 1); + let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); let response = client @@ -72,7 +72,7 @@ async fn detect_rate_limiting() { .await; // When - let task = TaskCompletion::from_text("Hello,", 1); + let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); let error = client @@ -116,7 +116,7 @@ async fn detect_queue_full() { .await; // When - let task = TaskCompletion::from_text("Hello,", 1); + let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); let error = client @@ -153,7 +153,7 @@ async fn detect_service_unavailable() { .await; // When - let task = TaskCompletion::from_text("Hello,", 1); + let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); let error = client @@ -175,7 +175,7 @@ async fn be_nice() { let mock_server = MockServer::start().await; // When - let task = TaskCompletion::from_text("Hello,", 1); + let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1); let model = "luminous-base"; let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); // Drop result, answer is meaningless anyway @@ -211,7 +211,9 @@ async fn client_timeout() { // When let result = client .output_of( - &TaskCompletion::from_text("Hello,", 1).with_model("any"), + &TaskCompletion::from_text("Hello,") + .with_maximum_tokens(1) + .with_model("any"), &How { client_timeout: response_time / 2, ..Default::default()