From 50fbf6e3a49169305408803afce36bc5abaa5bed Mon Sep 17 00:00:00 2001 From: Markus Klein Date: Thu, 5 Sep 2024 12:03:17 +0200 Subject: [PATCH] Support fetching tokenizers --- Cargo.toml | 3 ++- Changelog.md | 4 ++++ src/http.rs | 20 ++++++++++++++++++++ src/lib.rs | 5 +++++ tests/integration.rs | 15 +++++++++++++++ 5 files changed, 46 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e29ecd5..536549c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aleph-alpha-client" -version = "0.11.0" +version = "0.12.0" edition = "2021" description = "Interact with large language models provided by the Aleph Alpha API in Rust code" license = "MIT" @@ -18,6 +18,7 @@ reqwest = { version = "0.12.3", features = ["json"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.115" thiserror = "1.0.58" +tokenizers = { version = "0.20.0", default-features = false, features = ["onig", "esaxx_fast"] } [dev-dependencies] dotenv = "0.15.0" diff --git a/Changelog.md b/Changelog.md index 2a011ee..11a0f71 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,10 @@ ## Unreleased +## 0.12.0 + +* Add `Client::tokenizer_by_model` to fetch the Tokenizer for a given model name + ## 0.11.0 * Add `with_maximum_tokens` method to `Prompt` diff --git a/src/http.rs b/src/http.rs index 84e37b2..f2c27e6 100644 --- a/src/http.rs +++ b/src/http.rs @@ -3,6 +3,7 @@ use std::{borrow::Cow, time::Duration}; use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode}; use serde::Deserialize; use thiserror::Error as ThisError; +use tokenizers::Tokenizer; use crate::How; @@ -163,6 +164,21 @@ impl HttpClient { auth_value.set_sensitive(true); auth_value } + + pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> { + let api_token = api_token + .as_ref() + .or(self.api_token.as_ref()) + .expect("API token needs to be set on client construction or per request"); + let response = 
self.http.get(format!("{}/models/{model}/tokenizer", self.base)) .header(header::AUTHORIZATION, Self::header_from_token(api_token)).send().await?; let response = translate_http_error(response).await?; let bytes = response.bytes().await?; let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| { Error::InvalidTokenizer { deserialization_error: e.to_string() } })?; Ok(tokenizer) } } async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> { @@ -235,6 +251,10 @@ pub enum Error { /// An error on the Http Protocol level. #[error("HTTP request failed with status code {}. Body:\n{}", status, body)] Http { status: u16, body: String }, + #[error("Tokenizer could not be correctly deserialized. Caused by:\n{}", deserialization_error)] + InvalidTokenizer { + deserialization_error: String, + }, /// Most likely either TLS errors creating the Client, or IO errors. #[error(transparent)] Other(#[from] reqwest::Error), diff --git a/src/lib.rs b/src/lib.rs index 9bee967..2ab4dbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,7 @@ use std::time::Duration; use http::HttpClient; use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput}; +use tokenizers::Tokenizer; pub use self::{ completion::{CompletionOutput, Sampling, Stopping, TaskCompletion}, @@ -303,6 +304,10 @@ impl Client { .output_of(&task.with_model(model), how) .await } + + pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> { + self.http_client.tokenizer_by_model(model, api_token).await + } } /// Controls of how to execute a task diff --git a/tests/integration.rs b/tests/integration.rs index 2b65f27..11ab979 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -524,3 +524,18 @@ async fn detokenization_with_luminous_base() { // Then assert!(response.result.contains("Hello, World!")); } + +#[tokio::test] +async fn fetch_tokenizer_for_pharia_1_llm_7b() { + // Given + let client = Client::with_authentication(api_token()).unwrap(); 
+ // When + let tokenizer = client + .tokenizer_by_model("Pharia-1-LLM-7B-control", None) + .await + .unwrap(); + + // Then + assert_eq!(128_000, tokenizer.get_vocab_size(true)); +} \ No newline at end of file