From 1446c7627e83f8a918731bb697d10c77d6597212 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Tue, 5 Nov 2024 12:15:05 +0000
Subject: [PATCH 01/14] Introduce eos token locator

---
 Cargo.toml     |   5 +-
 src/lib.rs     |   2 +
 src/locator.rs | 215 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 221 insertions(+), 1 deletion(-)
 create mode 100644 src/locator.rs

diff --git a/Cargo.toml b/Cargo.toml
index 0e83d02..eea078c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,10 @@ thiserror = "1.0"
 pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
 regex = "1.10.6"
 serde-pyobject = "0.4.0"
-serde_json = { version = "1.0.125", features = ["preserve_order"] }
+serde_json = { version = "1.0", features = ["preserve_order"] }
+serde = {version = "1", features = ["derive"]}
+hf-hub = "=0.3.2"
+tokenizers = { version = "=0.20.0", features = ["http"] }

 [features]
 python-bindings = ["pyo3"]

diff --git a/src/lib.rs b/src/lib.rs
index 71787e2..576d996 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,8 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;

+mod locator;
+
 #[cfg(feature = "python-bindings")]
 mod python_bindings;

diff --git a/src/locator.rs b/src/locator.rs
new file mode 100644
index 0000000..272e02e
--- /dev/null
+++ b/src/locator.rs
@@ -0,0 +1,215 @@
+use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
+use serde::{Deserialize, Serialize};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
+
+use crate::primitives::*;
+
+/// List of common eos token locations appearing on hugging face hub, ordered by priority.
+const COMMON_LOCATIONS: &[EosTokenLocation] = &[
+    // Most projects have `generation_config.json` that looks like:
+    // {
+    //   ...
+    //   "eos_token_id": 50256,
+    //   ...
+    // }
+    // So it's the first place we look for the eos token id.
+    //
+    // For example:
+    // - https://huggingface.co/openai-community/gpt2/blob/main/generation_config.json
+    EosTokenLocation {
+        file: "generation_config.json",
+        location: EosTokenField::Id,
+    },
+    // The ones that don't have `generation_config.json` usually have `tokenizer_config.json`:
+    // {
+    //   ...
+    //   "eos_token": "<|endoftext|>",
+    //   ...
+    // }
+    // Once we have the eos token content, we can get its id from the tokenizer.
+    //
+    // For example:
+    // - https://huggingface.co/microsoft/phi-2/blob/main/tokenizer_config.json
+    EosTokenLocation {
+        file: "tokenizer_config.json",
+        location: EosTokenField::Value,
+    },
+    // Sometimes `tokenizer_config.json` can have the following format as well:
+    // {
+    //   "eos_token": {
+    //     ...
+    //     "content": "</s>",
+    //     ...
+    //   },
+    // }
+    // Once we have the eos token content, we can get its id from the tokenizer.
+    //
+    // For example:
+    // - https://huggingface.co/hf-internal-testing/llama-tokenizer/blob/main/tokenizer_config.json
+    EosTokenLocation {
+        file: "tokenizer_config.json",
+        location: EosTokenField::Object,
+    },
+];
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Id {
+    eos_token_id: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Value {
+    eos_token: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Object {
+    eos_token: Content,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Content {
+    content: String,
+}
+
+/// Kind of the json field which will be checked for eos token id.
+enum EosTokenField {
+    Id,
+    Value,
+    Object,
+}
+
+/// Location of the end of sentence token id in a config file.
+struct EosTokenLocation {
+    file: &'static str,
+    location: EosTokenField,
+}
+
+pub(crate) struct EosTokenLocator;
+
+impl EosTokenLocator {
+    pub(crate) fn locate(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId> {
+        COMMON_LOCATIONS
+            .iter()
+            .find_map(|location| location.lookup(model, tokenizer, parameters))
+    }
+}
+
+impl EosTokenLocation {
+    /// Finds eos token within defined location in related config file.
+    fn lookup(
+        &self,
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId> {
+        let file_path = Self::download_config(model, self.file, parameters).ok()?;
+        let file = std::fs::File::open(file_path).ok()?;
+
+        match self.location {
+            EosTokenField::Id => {
+                let config: Id = serde_json::from_reader(file).ok()?;
+                u32::try_from(config.eos_token_id).ok()
+            }
+            EosTokenField::Value => {
+                let config: Value = serde_json::from_reader(file).ok()?;
+                tokenizer.token_to_id(&config.eos_token)
+            }
+            EosTokenField::Object => {
+                let config: Object = serde_json::from_reader(file).ok()?;
+                tokenizer.token_to_id(&config.eos_token.content)
+            }
+        }
+    }
+
+    /// Downloads a config file from Hugging Face Hub.
+    fn download_config(
+        project: &str,
+        file: &str,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> tokenizers::Result<std::path::PathBuf> {
+        // Adapted from
+        // https://github.com/huggingface/tokenizers/blob/9b77c054ef4297c7057fa8db875368c7c02f1bfc/tokenizers/src/utils/from_pretrained.rs#L26
+
+        let params = parameters.clone().unwrap_or_default();
+
+        Self::validate(project)?;
+        Self::validate(&params.revision)?;
+
+        let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
+        let api = ApiBuilder::new()
+            .with_token(params.auth_token)
+            .build()?
+            .repo(repo);
+
+        Ok(api.get(file)?)
+    }
+
+    fn validate(input: &str) -> tokenizers::Result<()> {
+        let valid_chars = ['-', '_', '.', '/'];
+
+        if !input
+            .chars()
+            .all(|c: char| c.is_alphanumeric() || valid_chars.contains(&c))
+        {
+            return Err(format!(
+                "Input {input} contains invalid characters, expected only alphanumeric or {}",
+                valid_chars
+                    .iter()
+                    .map(|x| format!("'{}'", x))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )
+            .into());
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn common_locations() {
+        for (model, expected_token_id, expected_token) in &[
+            ("openai-community/gpt2", 50256, "<|endoftext|>"),
+            ("microsoft/phi-2", 50256, "<|endoftext|>"),
+            ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
+        ] {
+            let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+            let located =
+                EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");
+
+            assert_eq!(located, *expected_token_id);
+            assert_eq!(
+                tokenizer.id_to_token(located).expect("Token is not found"),
+                expected_token.to_string()
+            );
+        }
+    }
+
+    #[test]
+    fn bad_location() {
+        let bad_location = EosTokenLocation {
+            file: "tokenizer_config.json",
+            location: EosTokenField::Id,
+        };
+        let model = "microsoft/phi-2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        let token_id = bad_location.lookup(model, &tokenizer, &None);
+        assert!(token_id.is_none());
+
+        let bad_file = EosTokenLocation {
+            file: "generation_config.json",
+            location: EosTokenField::Value,
+        };
+        let token_id = bad_file.lookup(model, &tokenizer, &None);
+        assert!(token_id.is_none());
+    }
+}
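For a sense of how the locator introduced above is meant to be driven, here is a minimal sketch (not part of the patch; since `locate` is crate-private it would be called from inside the library, and the model name is simply the one used in the tests):

    use tokenizers::Tokenizer;

    fn eos_id_for(model: &str) -> Option<u32> {
        // The locator needs the tokenizer to resolve eos token text into an id.
        let tokenizer = Tokenizer::from_pretrained(model, None).ok()?;
        // Tries each config location from COMMON_LOCATIONS in priority order.
        EosTokenLocator::locate(model, &tokenizer, &None)
    }

    // Per the tests above, eos_id_for("openai-community/gpt2") yields Some(50256).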
From c3b4430fe1dd2dc8489cd927b2c053f068be164f Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Thu, 7 Nov 2024 14:23:11 +0000
Subject: [PATCH 02/14] Introduce token processor

---
 Cargo.toml       |   1 +
 src/lib.rs       |  26 ++++
 src/locator.rs   |   3 +-
 src/processor.rs | 315 +++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 344 insertions(+), 1 deletion(-)
 create mode 100644 src/processor.rs

diff --git a/Cargo.toml b/Cargo.toml
index eea078c..aaf2e17 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,7 @@ license = "Apache-2.0"
 repository = "https://github.com/dottxt-ai/outlines-core"

 [dependencies]
+once_cell = "1.20"
 anyhow = "1.0.86"
 thiserror = "1.0"
 pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }

diff --git a/src/lib.rs b/src/lib.rs
index 576d996..2cba82a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,7 @@ pub mod regex;
 pub mod vocabulary;

 mod locator;
+mod processor;

 #[cfg(feature = "python-bindings")]
 mod python_bindings;
@@ -18,6 +19,31 @@ pub enum Error {
     IndexError,
 }

+#[derive(Error, Debug)]
+pub enum VocabularyError {
+    #[error("Unable to create tokenizer for {model}, source {source}")]
+    UnableToCreateTokenizer {
+        model: String,
+        source: tokenizers::Error,
+    },
+    #[error("Unable to locate EOS token for {model}")]
+    UnableToLocateEosTokenId { model: String },
+    #[error("Unable to process token")]
+    TokenProcessorError(#[from] TokenProcessorError),
+}
+
+#[derive(Error, Debug)]
+pub enum TokenProcessorError {
+    #[error("Tokenizer is not supported")]
+    UnsupportedTokenizer,
+    #[error("Decoder unpacking failed")]
+    DecoderUnpackingFailed,
+    #[error("Token processing failed for byte level processor")]
+    ByteProcessorFailed,
+    #[error("Token processing failed for byte fallback level processor")]
+    ByteFallbackProcessorFailed,
+}
+
 #[cfg(feature = "python-bindings")]
 impl From<Error> for pyo3::PyErr {
     fn from(e: Error) -> Self {

diff --git a/src/locator.rs b/src/locator.rs
index 272e02e..61cc581 100644
--- a/src/locator.rs
+++ b/src/locator.rs
@@ -72,7 +72,7 @@ struct Content {
     content: String,
 }

-/// Kind of the json field which will be checked for eos token id.
+/// Which part in config's json to check for eos token id.
 enum EosTokenField {
     Id,
     Value,
@@ -88,6 +88,7 @@ struct EosTokenLocation {
 pub(crate) struct EosTokenLocator;

 impl EosTokenLocator {
+    /// Locates eos token id by searching in defined common locations.
     pub(crate) fn locate(
         model: &str,
         tokenizer: &Tokenizer,

diff --git a/src/processor.rs b/src/processor.rs
new file mode 100644
index 0000000..5048b11
--- /dev/null
+++ b/src/processor.rs
@@ -0,0 +1,315 @@
+use std::collections::HashMap;
+
+use once_cell::sync::Lazy;
+use serde::Deserialize;
+use tokenizers::normalizers::Replace;
+use tokenizers::{DecoderWrapper, Tokenizer};
+
+use crate::TokenProcessorError;
+
+pub type Result<T> = std::result::Result<T, TokenProcessorError>;
+
+/// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
+/// utf-8 characters. For example, b` \xf0` can be one token. These tokenizers map each
+/// byte to a valid UTF-8 character. And we need to map back those characters into bytes.
+///
+/// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O"
+/// or
+/// "Ġal" = [U+0120, U+0061, U+006C] should be interpreted as [0x20, 0x61, 0x6C] = " al"
+///
+/// We'll use the following mapping for this transition:
+/// ---
+/// 'Ā' == '\u{0100}' -> 0x00 == 0
+/// 'ā' == '\u{0101}' -> 0x01 == 1
+/// 'Ă' == '\u{0102}' -> 0x02 == 2
+/// ...
+/// 'Ğ' == '\u{011E}' -> 0x1E == 30
+/// 'ğ' == '\u{011F}' -> 0x1F == 31
+/// 'Ġ' == '\u{0120}' -> 0x20 == 32
+/// ---
+/// '!' == '\u{0021}' -> 0x21 == 33
+/// '"' == '\u{0022}' -> 0x22 == 34
+/// '#' == '\u{0023}' -> 0x23 == 35
+/// ...
+/// '|' == '\u{007C}' -> 0x7C == 124
+/// '}' == '\u{007D}' -> 0x7D == 125
+/// '~' == '\u{007E}' -> 0x7E == 126
+/// ---
+/// 'ġ' == '\u{0121}' -> 0x7F == 127
+/// 'Ģ' == '\u{0122}' -> 0x80 == 128
+/// 'ģ' == '\u{0123}' -> 0x81 == 129
+/// ...
+/// 'ŀ' == '\u{0140}' -> 0x9E == 158
+/// 'Ł' == '\u{0141}' -> 0x9F == 159
+/// 'ł' == '\u{0142}' -> 0xA0 == 160
+/// ---
+/// '¡' == '\u{00A1}' -> 0xA1 == 161
+/// '¢' == '\u{00A2}' -> 0xA2 == 162
+/// '£' == '\u{00A3}' -> 0xA3 == 163
+/// ...
+/// 'ª' == '\u{00AA}' -> 0xAA == 170
+/// '«' == '\u{00AB}' -> 0xAB == 171
+/// '¬' == '\u{00AC}' -> 0xAC == 172
+/// ---
+/// 'Ń' == '\u{0143}' -> 0xAD == 173
+/// ---
+/// '®' == '\u{00AE}' -> 0xAE == 174
+/// '¯' == '\u{00AF}' -> 0xAF == 175
+/// '°' == '\u{00B0}' -> 0xB0 == 176
+/// ...
+/// 'ý' == '\u{00FD}' -> 0xFD == 253
+/// 'þ' == '\u{00FE}' -> 0xFE == 254
+/// 'ÿ' == '\u{00FF}' -> 0xFF == 255
+/// ---
+static CHAR_MAP: Lazy<HashMap<char, u8>> = Lazy::new(|| {
+    let mut char_map = HashMap::with_capacity(256);
+    let mut key = 0x100u32;
+    for byte in 0..=255u8 {
+        let char = byte as char;
+        if matches!(
+            char, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}',
+        ) {
+            char_map.insert(char, byte);
+        } else if let Some(ch) = char::from_u32(key) {
+            char_map.insert(ch, byte);
+            key += 1;
+        }
+    }
+    char_map
+});
+
+/// Token processor to adjust tokens according to the tokenizer's level.
+#[derive(Debug)]
+pub(crate) struct TokenProcessor {
+    level: TokenProcessorLevel,
+}
+
+/// Recognized tokenizer's levels.
+#[derive(Debug, Clone, PartialEq)]
+pub enum TokenProcessorLevel {
+    /// Matches byte level tokenizer (e.g., gpt2).
+    Byte,
+    /// Matches byte fallback tokenizer (e.g., llama), which have <0x__> tokens for
+    /// all __ >= 0x80 to represent incomplete UTF-8 sequences.
+    ByteFallback(Mods),
+}
+
+impl std::fmt::Display for TokenProcessorLevel {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            Self::Byte => write!(f, "Byte Level"),
+            Self::ByteFallback(mods) => write!(f, "Byte Fallback Level with mods: {:?}", mods),
+        }
+    }
+}
+
+/// Modifications to be applied by `ByteFallback` `TokenProcessorLevel`.
+#[derive(Debug, Clone, PartialEq)]
+pub struct Mods {
+    spacechar: char,
+}
+
+/// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level.
+static DEFAULT_MODS: Mods = Mods { spacechar: ' ' };
+
+impl Mods {
+    /// Apply default modifications.
+    fn apply_default(&self, token: String) -> String {
+        let to = DEFAULT_MODS.spacechar.to_string();
+        token.replace(self.spacechar, &to)
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ReplaceDecoder {
+    content: String,
+    pattern: ReplacePattern,
+}
+
+impl ReplaceDecoder {
+    fn space_replacement(&self) -> Option<char> {
+        if self.content != " " {
+            return None;
+        }
+        match &self.pattern {
+            ReplacePattern::String(pattern) => {
+                let mut chars = pattern.chars();
+                let char = chars.next();
+                if let Some(replacement) = char {
+                    if chars.next().is_none() {
+                        return Some(replacement);
+                    }
+                }
+                None
+            }
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub enum ReplacePattern {
+    String(String),
+}
+
+impl TokenProcessor {
+    /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders.
+    pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {
+        match tokenizer.get_decoder() {
+            None => Err(TokenProcessorError::UnsupportedTokenizer),
+            Some(decoder) => match decoder {
+                DecoderWrapper::ByteLevel(_) => Ok(Self {
+                    level: TokenProcessorLevel::Byte,
+                }),
+                DecoderWrapper::Sequence(decoding_sequence) => {
+                    let mut is_byte_fallback = false;
+                    let mut spacechar = ' ';
+
+                    for decoder in decoding_sequence.get_decoders() {
+                        match decoder {
+                            DecoderWrapper::ByteFallback(_) => {
+                                is_byte_fallback = true;
+                            }
+                            DecoderWrapper::Replace(replace) => {
+                                // `Replace` decoder would replace a pattern in the output with something else,
+                                // which we need to know.
+                                let decoder = Self::unpack_decoder(replace)?;
+                                if let Some(replacement) = decoder.space_replacement() {
+                                    spacechar = replacement;
+                                }
+                            }
+                            _ => {}
+                        }
+                    }
+
+                    if is_byte_fallback {
+                        Ok(Self {
+                            level: TokenProcessorLevel::ByteFallback(Mods { spacechar }),
+                        })
+                    } else {
+                        Err(TokenProcessorError::UnsupportedTokenizer)
+                    }
+                }
+                _ => Err(TokenProcessorError::UnsupportedTokenizer),
+            },
+        }
+    }
+
+    /// Process each token based on the level ofTokenProcesso.
+    pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
+        match &self.level {
+            TokenProcessorLevel::Byte => {
+                let mut bytes = vec![];
+                for char in token.chars() {
+                    match CHAR_MAP.get(&char) {
+                        None => return Err(TokenProcessorError::ByteProcessorFailed),
+                        Some(b) => bytes.push(*b),
+                    }
+                }
+                Ok(bytes)
+            }
+            TokenProcessorLevel::ByteFallback(mods) => {
+                // If the token is of form `<0x__>`:
+                if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
+                    // Get to a single byte specified in the __ part and parse it in base 16 to a byte.
+                    match u8::from_str_radix(&token[3..5], 16) {
+                        Ok(byte) => Ok([byte].to_vec()),
+                        Err(_) => Err(TokenProcessorError::ByteFallbackProcessorFailed),
+                    }
+                } else {
+                    Ok(mods.apply_default(token).as_bytes().to_vec())
+                }
+            }
+        }
+    }
+
+    /// Since all fields of `Replace` are private with no getters, we'll have to unpack it into our own.
+    fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
+        match serde_json::to_value(decoder) {
+            Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+            Ok(value) => match serde_json::from_value(value) {
+                Ok(d) => Ok(d),
+                Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+            },
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn byte_level_processor() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        assert_eq!(processor.level, TokenProcessorLevel::Byte);
+
+        for (ch, byte) in [
+            ('Ā', 0x00),
+            ('ā', 0x01),
+            ('Ă', 0x02),
+            ('Ğ', 0x1E),
+            ('ğ', 0x1F),
+            ('Ġ', 0x20),
+            ('!', 0x21),
+            ('"', 0x22),
+            ('#', 0x23),
+            ('|', 0x7C),
+            ('}', 0x7D),
+            ('~', 0x7E),
+            ('ġ', 0x7F),
+            ('Ģ', 0x80),
+            ('ģ', 0x81),
+            ('ŀ', 0x9E),
+            ('Ł', 0x9F),
+            ('ł', 0xA0),
+            ('¡', 0xA1),
+            ('¢', 0xA2),
+            ('£', 0xA3),
+            ('ª', 0xAA),
+            ('«', 0xAB),
+            ('¬', 0xAC),
+            ('Ń', 0xAD),
+            ('®', 0xAE),
+            ('¯', 0xAF),
+            ('°', 0xB0),
+            ('ý', 0xFD),
+            ('þ', 0xFE),
+            ('ÿ', 0xFF),
+        ] {
+            let processed = processor.process(ch.to_string()).expect("Not processed");
+            assert_eq!(processed, [byte]);
+        }
+    }
+
+    #[test]
+    fn byte_fallback_level_processor() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+        let spacechar = '▁';
+        let mods = Mods { spacechar };
+
+        assert_eq!(processor.level, TokenProcessorLevel::ByteFallback(mods));
+
+        for (input, expected) in [
+            ("abc", vec![0x61, 0x62, 0x63]),
+            ("<0x61>", vec![0x61]),
+            ("<0x61>a", vec![0x3C, 0x30, 0x78, 0x36, 0x31, 0x3E, 0x61]),
+            (&spacechar.to_string(), vec![0x20]),
+            (
+                &format!("{}{}abc", spacechar, spacechar),
+                vec![0x20, 0x20, 0x61, 0x62, 0x63],
+            ),
+            (
+                &format!("{}{}{}", spacechar, spacechar, spacechar),
+                vec![0x20, 0x20, 0x20],
+            ),
+        ] {
+            let processed = processor.process(input.to_string()).expect("Not processed");
+            assert_eq!(processed, expected);
+        }
+    }
+}
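The `CHAR_MAP` above inverts GPT-2's byte-to-unicode table: printable bytes map to themselves, everything else is shifted into the U+0100 range. A self-contained sketch of the forward table, useful for sanity-checking the processor against the documented pairs (illustrative, not part of the patch):

    use std::collections::HashMap;

    fn byte_to_char() -> HashMap<u8, char> {
        let mut map = HashMap::with_capacity(256);
        let mut key = 0x100u32;
        for byte in 0..=255u8 {
            let ch = byte as char;
            if matches!(ch, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}') {
                // Printable bytes keep their own character.
                map.insert(byte, ch);
            } else {
                // Non-printable bytes take the next code point from U+0100 upwards.
                map.insert(byte, char::from_u32(key).expect("valid code point"));
                key += 1;
            }
        }
        map
    }

    // byte_to_char()[&0x20] == 'Ġ', which is why "ĠO" round-trips to b" O".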
From e85ce9082ae38a1a3c30afcca5ffc9ed38af0cd1 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Fri, 8 Nov 2024 17:35:53 +0000
Subject: [PATCH 03/14] Extend vocabulary with eos token id & pretrained models

---
 src/prelude.rs    |   6 --
 src/regex.rs      |   1 +
 src/vocabulary.rs | 201 ++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 187 insertions(+), 21 deletions(-)

diff --git a/src/prelude.rs b/src/prelude.rs
index e196e47..d42516b 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -2,9 +2,3 @@ pub use super::{
     primitives::{State, Token, TokenId, TransitionKey},
     vocabulary::Vocabulary,
 };
-
-pub(crate) use std::{
-    collections::{HashMap, HashSet},
-    fmt::{self, Display},
-    ops::Deref,
-};

diff --git a/src/regex.rs b/src/regex.rs
index a41bf86..b565819 100644
--- a/src/regex.rs
+++ b/src/regex.rs
@@ -1,4 +1,5 @@
 use crate::prelude::*;
+use std::collections::{HashMap, HashSet};

 pub fn walk_fsm(
     fsm_transitions: &HashMap<(State, TransitionKey), State>,

diff --git a/src/vocabulary.rs b/src/vocabulary.rs
index f03df8f..438e94b 100644
--- a/src/vocabulary.rs
+++ b/src/vocabulary.rs
@@ -1,4 +1,12 @@
+use std::collections::HashMap;
+
+use tokenizers::normalizers::Sequence;
+use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
+
+use crate::locator::EosTokenLocator;
 use crate::prelude::*;
+use crate::processor::TokenProcessor;
+use crate::VocabularyError;

 /// Vocabulary of an LLM.
 ///
@@ -7,19 +15,116 @@ use crate::prelude::*;
 /// ```rust
 /// # use outlines_core::prelude::*;
 /// #
-/// let vocabulary = Vocabulary::new()
+/// let vocabulary = Vocabulary::new(None)
 ///     .insert("blah", 0)
 ///     .insert("1a", 1)
 ///     .insert("2", 2)
 ///     .insert("0", 3);
 /// ```
 #[derive(Clone, Debug, Default)]
-pub struct Vocabulary(pub(crate) HashMap<Token, Vec<TokenId>>);
+pub struct Vocabulary {
+    // TODO: Option is temp for back compatibility
+    eos_token_id: Option<TokenId>,
+    map: HashMap<Token, Vec<TokenId>>,
+}

 impl Vocabulary {
     /// Creates an empty vocabulary.
-    pub fn new() -> Vocabulary {
-        Vocabulary::default()
+    pub fn new(eos_token_id: Option<TokenId>) -> Self {
+        Self {
+            eos_token_id,
+            map: HashMap::new(),
+        }
+    }
+
+    /// Creates the vocabulary of pre-trained model from Hugging Face Hub.
+    pub fn from_pretrained(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self, VocabularyError> {
+        let mut tokenizer =
+            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
+                VocabularyError::UnableToCreateTokenizer {
+                    model: model.to_string(),
+                    source: error,
+                }
+            })?;
+        Self::filter_normalizers(&mut tokenizer);
+
+        let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
+        let Some(eos_token_id) = eos_token_id else {
+            return Err(VocabularyError::UnableToLocateEosTokenId {
+                model: model.to_string(),
+            });
+        };
+
+        Vocabulary::try_from((&mut tokenizer, eos_token_id))
+    }
+
+    /// Per provided token returns vector of `TokenId`s if available in vocabulary.
+    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
+        self.map.get(token)
+    }
+
+    /// Gets the identifier of the special end of sentence token.
+    pub fn eos_token_id(&self) -> Option<TokenId> {
+        self.eos_token_id
+    }
+
+    fn filter_normalizers(tokenizer: &mut Tokenizer) {
+        // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
+        // In `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
+        // e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
+        //
+        // We don't want to deal with the special characters, so we remove `Prepend` normalizers.
+        if let Some(normalizer) = tokenizer.get_normalizer() {
+            match normalizer {
+                NormalizerWrapper::Sequence(normalization_sequence) => {
+                    let new_sequence = Sequence::new(
+                        normalization_sequence
+                            .get_normalizers()
+                            .iter()
+                            .filter_map(|normalizer| match normalizer {
+                                NormalizerWrapper::Prepend(_) => None,
+                                _ => Some(normalizer.clone()),
+                            })
+                            .collect(),
+                    );
+                    tokenizer.with_normalizer(new_sequence.into());
+                }
+                NormalizerWrapper::Prepend(_) => {
+                    tokenizer.with_normalizer(None::<NormalizerWrapper>);
+                }
+                _ => {}
+            }
+        }
+    }
+}
+
+impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
+    type Error = VocabularyError;
+
+    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self, Self::Error> {
+        let (tokenizer, eos_token_id) = value;
+
+        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
+            if !added_token.special {
+                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            }
+        }
+
+        let processor = TokenProcessor::new(tokenizer)?;
+        for (token, token_id) in tokenizer.get_vocab(false) {
+            let token_bytes = processor.process(token)?;
+            // TODO: lossy is temp:
+            // - in python it was handled by byte_symbol function
+            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
+            let processed_token = String::from_utf8_lossy(&token_bytes);
+            vocabulary = vocabulary.insert(processed_token, token_id);
+        }
+
+        Ok(vocabulary)
+    }
+}
@@ -43,8 +148,9 @@ impl Vocabulary {
     /// Inserts a token to the vocabulary with the specified identifier, in place.
     pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
+        // TODO: return error if eos token id is inserted
         let token = token.into();
-        self.0.entry(token).or_default().push(id);
+        self.map.entry(token).or_default().push(id);
     }

     /// Extends the vocabulary with tokens and their identifiers, in place.
@@ -54,21 +160,21 @@ impl Vocabulary {
     ) {
         for (token, ids) in tokens_and_ids.into_iter() {
             let token = token.into();
-            self.0.entry(token).or_default().extend(ids);
+            self.map.entry(token).or_default().extend(ids);
         }
     }
 }

-impl Deref for Vocabulary {
+impl std::ops::Deref for Vocabulary {
     type Target = HashMap<Token, Vec<TokenId>>;

     fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
-        &self.0
+        &self.map
     }
 }

-impl Display for Vocabulary {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+impl std::fmt::Display for Vocabulary {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         for (index, (token, token_ids)) in self.iter().enumerate() {
             if index != (self.len() - 1) {
                 writeln!(f, "{:?} -> {:?}", token, token_ids)?;
@@ -82,7 +188,10 @@ impl Display for Vocabulary {

 impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
     fn from(map: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
-        Vocabulary(map)
+        Vocabulary {
+            eos_token_id: None,
+            map,
+        }
     }
 }

@@ -92,17 +201,17 @@ where
     I: IntoIterator<Item = TokenId>,
 {
     fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
-        Vocabulary::new().extend(tokens_and_ids)
+        Vocabulary::new(None).extend(tokens_and_ids)
     }
 }

 #[cfg(test)]
 mod tests {
-    use crate::prelude::*;
+    use super::*;

     #[test]
     fn insert() {
-        let vocabulary = Vocabulary::new()
+        let vocabulary = Vocabulary::new(None)
             .insert("blah", 0)
             .insert("1a", 1)
             .insert("2", 2)
@@ -117,7 +226,7 @@ mod tests {

     #[test]
     fn extend() {
-        let vocabulary = Vocabulary::new().extend([
+        let vocabulary = Vocabulary::new(None).extend([
             ("blah", vec![0]),
             ("1a", vec![1]),
             ("2", vec![2]),
@@ -130,4 +239,66 @@ mod tests {
         assert_eq!(vocabulary["2"], &[2]);
         assert_eq!(vocabulary["0"], &[3]);
     }
+
+    #[test]
+    fn pretrained_from_gpt2() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
+
+        let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
+        assert_eq!(v_eos, 50256);
+        assert_eq!(
+            tokenizer.id_to_token(v_eos).expect("Token not found"),
+            "<|endoftext|>"
+        );
+
+        let token = "Ġal";
+        assert!(vocabulary.token_to_ids(token).is_none());
+        assert!(tokenizer.token_to_id(token).is_some());
+
+        for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] {
+            let v_ids = vocabulary.token_to_ids(v_token);
+            assert!(v_ids.is_some());
+            for v_id in v_ids.unwrap() {
+                let t_token = tokenizer
+                    .id_to_token(*v_id)
+                    .expect("Token id not found in tokenizer");
+                assert_eq!(&t_token, t_token_expected);
+            }
+        }
+    }
+
+    #[test]
+    fn pretrained_from_llama() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
+
+        let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
+        assert_eq!(v_eos, 2);
+        assert_eq!(
+            tokenizer.id_to_token(v_eos).expect("Token not found"),
+            "</s>"
+        );
+
+        for (v_token, t_token_expected) in [
+            ("abc", "abc"),
+            (" al", "▁al"),
+            (" O", "▁O"),
+            ("   ", "▁▁▁"),
+            // TODO: won't pass since first we need to change token's type to bytes
+            // ("<0xFF>", "ÿ"),
+            // ("<0x20>", "▁"),
+        ] {
+            let v_ids = vocabulary.token_to_ids(v_token);
+            assert!(v_ids.is_some());
+            for v_id in v_ids.unwrap() {
+                let t_token = tokenizer
+                    .id_to_token(*v_id)
+                    .expect("Token id not found in tokenizer");
+                assert_eq!(&t_token, t_token_expected);
+            }
+        }
+    }
+}
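The public entry point added by this patch is `Vocabulary::from_pretrained`. A minimal sketch of the intended call pattern, using the same model as the tests above (error handling boxed for brevity):

    use outlines_core::prelude::*;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Downloads the config from the Hub, locates the eos token id and
        // remaps multibyte tokens into their byte-level form.
        let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
        assert_eq!(vocabulary.eos_token_id(), Some(50256));
        // " O" is the byte-level spelling of the tokenizer's "ĠO".
        assert!(vocabulary.token_to_ids(" O").is_some());
        Ok(())
    }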
From 5eb350dec74281128c0fcb6d76261c5e72523c8f Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Fri, 8 Nov 2024 17:47:26 +0000
Subject: [PATCH 04/14] Move vocabulary to a dedicated module

---
 src/lib.rs                               | 3 ---
 src/{ => vocabulary}/locator.rs          | 0
 src/{vocabulary.rs => vocabulary/mod.rs} | 8 ++++++--
 src/{ => vocabulary}/processor.rs        | 0
 4 files changed, 6 insertions(+), 5 deletions(-)
 rename src/{ => vocabulary}/locator.rs (100%)
 rename src/{vocabulary.rs => vocabulary/mod.rs} (99%)
 rename src/{ => vocabulary}/processor.rs (100%)

diff --git a/src/lib.rs b/src/lib.rs
index 2cba82a..e397800 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,9 +5,6 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;

-mod locator;
-mod processor;
-
 #[cfg(feature = "python-bindings")]
 mod python_bindings;

diff --git a/src/locator.rs b/src/vocabulary/locator.rs
similarity index 100%
rename from src/locator.rs
rename to src/vocabulary/locator.rs

diff --git a/src/vocabulary.rs b/src/vocabulary/mod.rs
similarity index 99%
rename from src/vocabulary.rs
rename to src/vocabulary/mod.rs
index 438e94b..2851752 100644
--- a/src/vocabulary.rs
+++ b/src/vocabulary/mod.rs
@@ -3,11 +3,15 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

-use crate::locator::EosTokenLocator;
 use crate::prelude::*;
-use crate::processor::TokenProcessor;
 use crate::VocabularyError;

+use locator::EosTokenLocator;
+use processor::TokenProcessor;
+
+mod locator;
+mod processor;
+
 /// Vocabulary of an LLM.
 ///
 /// ## Examples

diff --git a/src/processor.rs b/src/vocabulary/processor.rs
similarity index 100%
rename from src/processor.rs
rename to src/vocabulary/processor.rs

From 9b42797904161e6e9810dbd940f847a7bff23304 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Mon, 11 Nov 2024 19:31:38 +0000
Subject: [PATCH 05/14] Add more tests

---
 src/lib.rs                  | 16 +++++++--
 src/vocabulary/mod.rs       | 65 +++++++++++++++++++++++++++++++++++--
 src/vocabulary/processor.rs | 42 +++++++++++++++++++++++-
 3 files changed, 116 insertions(+), 7 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index e397800..695c529 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,18 +10,28 @@ mod python_bindings;

 use thiserror::Error;

-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
     IndexError,
 }

 #[derive(Error, Debug)]
+#[error("Tokenizer error")]
+pub struct TokenizerError(tokenizers::Error);
+
+impl PartialEq for TokenizerError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
+#[derive(Error, Debug, PartialEq)]
 pub enum VocabularyError {
     #[error("Unable to create tokenizer for {model}, source {source}")]
     UnableToCreateTokenizer {
         model: String,
-        source: tokenizers::Error,
+        source: TokenizerError,
     },
     #[error("Unable to locate EOS token for {model}")]
     UnableToLocateEosTokenId { model: String },
@@ -29,7 +39,7 @@ pub enum VocabularyError {
     TokenProcessorError(#[from] TokenProcessorError),
 }

-#[derive(Error, Debug)]
+#[derive(Error, Debug, PartialEq)]
 pub enum TokenProcessorError {
     #[error("Tokenizer is not supported")]
     UnsupportedTokenizer,

diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 2851752..fef311f 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -3,8 +3,7 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

-use crate::prelude::*;
-use crate::VocabularyError;
+use crate::{prelude::*, TokenizerError, VocabularyError};

 use locator::EosTokenLocator;
 use processor::TokenProcessor;
@@ -50,7 +49,7 @@ impl Vocabulary {
             Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
                 VocabularyError::UnableToCreateTokenizer {
                     model: model.to_string(),
-                    source: error,
+                    source: TokenizerError(error),
                 }
             })?;
         Self::filter_normalizers(&mut tokenizer);
@@ -305,4 +304,64 @@ mod tests {
         }
     }
+
+    #[test]
+    fn token_processor_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(e) = vocabulary {
+            assert_eq!(
+                e,
+                VocabularyError::TokenProcessorError(
+                    crate::TokenProcessorError::UnsupportedTokenizer
+                )
+            )
+        }
+    }
+
+    #[test]
+    fn tokenizer_error() {
+        let model = "hf-internal-testing/some-non-existent-model";
+        let vocabulary = Vocabulary::from_pretrained(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(VocabularyError::UnableToCreateTokenizer { model, source }) = vocabulary {
+            assert_eq!(model, model.to_string());
+            assert_eq!(source.to_string(), "Tokenizer error".to_string());
+        }
+    }
+
+    #[test]
+    fn prepend_normalizers_filtered_out() {
+        use tokenizers::normalizers::{Prepend, Sequence};
+
+        let prepend = Prepend::new("_".to_string());
+        let prepend_normalizer = NormalizerWrapper::Prepend(prepend);
+        let sequence = Sequence::new(vec![prepend_normalizer.clone()]);
+        let sequence_normalizer = NormalizerWrapper::Sequence(sequence);
+
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        for normalizer in [prepend_normalizer, sequence_normalizer] {
+            let mut normalized_t = tokenizer.clone();
+            normalized_t.with_normalizer(Some(normalizer));
+            Vocabulary::filter_normalizers(&mut normalized_t);
+            if let Some(n) = normalized_t.get_normalizer() {
+                match n {
+                    NormalizerWrapper::Sequence(seq) => {
+                        for n in seq.get_normalizers() {
+                            if let NormalizerWrapper::Prepend(_) = n {
+                                unreachable!()
+                            }
+                        }
+                    }
+                    NormalizerWrapper::Prepend(_) => unreachable!(),
+                    _ => {}
+                }
+            }
+        }
+    }
 }

diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 5048b11..ce149f8 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -194,7 +194,7 @@ impl TokenProcessor {
         }
     }

-    /// Process each token based on the level ofTokenProcesso.
+    /// Process each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => {
@@ -312,4 +312,44 @@ mod tests {
             assert_eq!(processed, expected);
         }
     }
+
+    #[test]
+    fn unsupported_tokenizer_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+
+        let result = TokenProcessor::new(&tokenizer);
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
+        }
+    }
+
+    #[test]
+    fn byte_processor_error() {
+        let model = "openai-community/gpt2";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] {
+            let result = processor.process(token.to_string());
+            assert!(result.is_err());
+            if let Err(e) = result {
+                assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
+            }
+        }
+    }
+
+    #[test]
+    fn byte_fallback_processor_error() {
+        let model = "hf-internal-testing/llama-tokenizer";
+        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
+
+        let result = processor.process("<0x6y>".to_string());
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
+        }
+    }
 }

From d09ea690f6f81d04dceda70c4ce99b69014664c2 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Mon, 11 Nov 2024 20:54:45 +0000
Subject: [PATCH 06/14] Improve documentation and visibilities

---
 src/vocabulary/locator.rs   | 12 ++++++++----
 src/vocabulary/mod.rs       | 11 ++++++-----
 src/vocabulary/processor.rs | 23 +++++++++++++----------
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index 61cc581..b3c6f06 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -52,16 +52,20 @@ const COMMON_LOCATIONS: &[EosTokenLocation] = &[
     },
 ];

+/// `Id` kind of `EosTokenField`, when `eos_token_id` provided as an id.
 #[derive(Debug, Serialize, Deserialize)]
 struct Id {
     eos_token_id: u64,
 }

+/// `Value` kind of `EosTokenField`, when `eos_token` provided as a text, so that its id
+/// will be fetched from the tokenizer.
 #[derive(Debug, Serialize, Deserialize)]
 struct Value {
     eos_token: String,
 }

+/// `Object` kind of `EosTokenField`, when `eos_token` provided as a `Content`.
 #[derive(Debug, Serialize, Deserialize)]
 struct Object {
     eos_token: Content,
@@ -72,14 +76,14 @@ struct Content {
     content: String,
 }

-/// Which part in config's json to check for eos token id.
+/// Specifies in which part in config's json to check for eos token id.
 enum EosTokenField {
     Id,
     Value,
     Object,
 }

-/// Location of the end of sentence token id in a config file.
+/// Defines location of the end of sentence token id in the config file.
 struct EosTokenLocation {
     file: &'static str,
     location: EosTokenField,
@@ -101,7 +105,7 @@ impl EosTokenLocator {
 }

 impl EosTokenLocation {
-    /// Finds eos token within defined location in related config file.
+    /// Finds eos token within defined location in a related config file.
     fn lookup(
         &self,
         model: &str,
@@ -127,7 +131,7 @@ impl EosTokenLocation {
         }
     }

-    /// Downloads a config file from Hugging Face Hub.
+    /// Downloads related config file from Hugging Face Hub.
     fn download_config(
         project: &str,
         file: &str,

diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index fef311f..b62c22e 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -52,7 +52,7 @@ impl Vocabulary {
                     source: TokenizerError(error),
                 }
             })?;
-        Self::filter_normalizers(&mut tokenizer);
+        Self::filter_prepend_normalizers(&mut tokenizer);

         let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
@@ -64,17 +64,18 @@ impl Vocabulary {
         Vocabulary::try_from((&mut tokenizer, eos_token_id))
     }

-    /// Per provided token returns vector of `TokenId`s if available in vocabulary.
+    /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
     pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
         self.map.get(token)
     }

-    /// Gets the identifier of the special end of sentence token.
+    /// Gets the identifier of the special end of the sentence token.
     pub fn eos_token_id(&self) -> Option<TokenId> {
         self.eos_token_id
     }

-    fn filter_normalizers(tokenizer: &mut Tokenizer) {
+    /// Filters out `Prepend` kind of tokenizer's normalizers.
+    fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
         // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
         // In `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
         // e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
@@ -348,7 +349,7 @@ mod tests {
         for normalizer in [prepend_normalizer, sequence_normalizer] {
             let mut normalized_t = tokenizer.clone();
             normalized_t.with_normalizer(Some(normalizer));
-            Vocabulary::filter_normalizers(&mut normalized_t);
+            Vocabulary::filter_prepend_normalizers(&mut normalized_t);
             if let Some(n) = normalized_t.get_normalizer() {
                 match n {
                     NormalizerWrapper::Sequence(seq) => {

diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index ce149f8..9488a78 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -7,11 +7,12 @@ use tokenizers::{DecoderWrapper, Tokenizer};

 use crate::TokenProcessorError;

-pub type Result<T> = std::result::Result<T, TokenProcessorError>;
+type Result<T> = std::result::Result<T, TokenProcessorError>;

 /// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
-/// utf-8 characters. For example, b` \xf0` can be one token. These tokenizers map each
-/// byte to a valid UTF-8 character. And we need to map back those characters into bytes.
+/// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each
+/// byte to a valid UTF-8 character, `TokenProcessor` of `ByteFallback` level will be used
+/// to map back these type of characters into bytes, based on `CHAR_MAP`.
 ///
 /// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O"
@@ -84,9 +85,9 @@ pub(crate) struct TokenProcessor {
     level: TokenProcessorLevel,
 }

-/// Recognized tokenizer's levels.
+/// Recognizes different tokenizer's levels.
 #[derive(Debug, Clone, PartialEq)]
-pub enum TokenProcessorLevel {
+pub(crate) enum TokenProcessorLevel {
     /// Matches byte level tokenizer (e.g., gpt2).
     Byte,
@@ -103,9 +104,9 @@ impl std::fmt::Display for TokenProcessorLevel {
     }
 }

-/// Modifications to be applied by `ByteFallback` `TokenProcessorLevel`.
+/// Modifications to be applied by `TokenProcessor` of `ByteFallback` level.
 #[derive(Debug, Clone, PartialEq)]
-pub struct Mods {
+pub(crate) struct Mods {
     spacechar: char,
 }
@@ -120,6 +121,7 @@ impl Mods {
     }
 }

+/// Local structure to be deserialized into from HF's `ReplaceDecoder` in order to get a replace pattern.
 #[derive(Debug, Deserialize)]
 struct ReplaceDecoder {
     content: String,
     pattern: ReplacePattern,
@@ -147,7 +149,7 @@ impl ReplaceDecoder {
 }

 #[derive(Debug, Deserialize)]
-pub enum ReplacePattern {
+enum ReplacePattern {
     String(String),
 }

@@ -194,7 +196,7 @@ impl TokenProcessor {
         }
     }

-    /// Process each token based on the level of `TokenProcessor`.
+    /// Operates on each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => {
@@ -222,7 +224,8 @@ impl TokenProcessor {
         }
     }

-    /// Since all fields of `Replace` are private with no getters, we'll have to unpack it into our own.
+    /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
+    /// into local `ReplaceDecoder` structure.
     fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
         match serde_json::to_value(decoder) {
             Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
             Ok(value) => match serde_json::from_value(value) {
                 Ok(d) => Ok(d),
                 Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
             },
         }
     }
 }
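The `unpack_decoder` helper documented above relies on a general serde trick: a type whose fields are private can still be mirrored by serializing it to a `serde_json::Value` and deserializing into a local twin with the same shape. A minimal standalone illustration (the `Foreign`/`Mirror` names are hypothetical):

    use serde::{Deserialize, Serialize};

    #[derive(Serialize)]
    struct Foreign {
        secret: String, // imagine this field lives in another crate, with no getter
    }

    #[derive(Deserialize)]
    struct Mirror {
        secret: String, // same field name and type as the foreign struct
    }

    fn unpack(foreign: &Foreign) -> Option<Mirror> {
        // Round-trip through JSON to copy the data without any accessors.
        let value = serde_json::to_value(foreign).ok()?;
        serde_json::from_value(value).ok()
    }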
From c7af9810d7058440e55cbe50dc10088ded650e1e Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Tue, 12 Nov 2024 14:09:52 +0000
Subject: [PATCH 07/14] Separate and simplify errors

---
 src/error.rs                | 27 +++++++++++++++++++
 src/index.rs                |  5 ++--
 src/lib.rs                  | 54 +++----------------------------
 src/vocabulary/mod.rs       | 26 +++++++-----------
 src/vocabulary/processor.rs | 41 +++++++++++++---------------
 5 files changed, 62 insertions(+), 91 deletions(-)
 create mode 100644 src/error.rs

diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..ff977a7
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,27 @@
+use thiserror::Error;
+
+#[derive(Error, Debug, PartialEq)]
+pub enum Error {
+    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
+    IndexError,
+    #[error("Unable to create tokenizer for {model}")]
+    UnableToCreateTokenizer { model: String },
+    #[error("Unable to locate EOS token for {model}")]
+    UnableToLocateEosTokenId { model: String },
+    #[error("Tokenizer is not supported by token processor")]
+    UnsupportedByTokenProcessor,
+    #[error("Decoder unpacking failed for token processor")]
+    DecoderUnpackingFailed,
+    #[error("Token processing failed for byte level processor")]
+    ByteProcessorFailed,
+    #[error("Token processing failed for byte fallback level processor")]
+    ByteFallbackProcessorFailed,
+}
+
+#[cfg(feature = "python-bindings")]
+impl From<Error> for pyo3::PyErr {
+    fn from(e: Error) -> Self {
+        use pyo3::{exceptions::PyValueError, PyErr};
+        PyErr::new::<PyValueError, _>(e.to_string())
+    }
+}

diff --git a/src/index.rs b/src/index.rs
index 587cd76..cc1187e 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -2,10 +2,9 @@
 use crate::prelude::{State, TransitionKey};
 use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
 use crate::vocabulary::Vocabulary;
+use crate::{Error, Result};
 use std::collections::{HashMap, HashSet};

-pub type Result<T> = std::result::Result<T, crate::Error>;
-
 #[derive(Debug)]
 pub struct FSMInfo {
     pub(crate) initial: State,
@@ -101,7 +100,7 @@ impl Index {
                 eos_token_id,
             })
         } else {
-            Err(crate::Error::IndexError)
+            Err(Error::IndexError)
         }
     }

diff --git a/src/lib.rs b/src/lib.rs
index 695c529..4c45de4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod error;
 pub mod index;
 pub mod json_schema;
 pub mod prelude;
@@ -5,56 +6,9 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;

-#[cfg(feature = "python-bindings")]
-mod python_bindings;
-
-use thiserror::Error;
-
-#[derive(Error, Debug, PartialEq)]
-pub enum Error {
-    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
-    IndexError,
-}
+use error::Error;

-#[derive(Error, Debug)]
-#[error("Tokenizer error")]
-pub struct TokenizerError(tokenizers::Error);
-
-impl PartialEq for TokenizerError {
-    fn eq(&self, other: &Self) -> bool {
-        self.0.to_string() == other.0.to_string()
-    }
-}
+pub type Result<T, E = Error> = std::result::Result<T, E>;

-#[derive(Error, Debug, PartialEq)]
-pub enum VocabularyError {
-    #[error("Unable to create tokenizer for {model}, source {source}")]
-    UnableToCreateTokenizer {
-        model: String,
-        source: TokenizerError,
-    },
-    #[error("Unable to locate EOS token for {model}")]
-    UnableToLocateEosTokenId { model: String },
-    #[error("Unable to process token")]
-    TokenProcessorError(#[from] TokenProcessorError),
-}
-
-#[derive(Error, Debug, PartialEq)]
-pub enum TokenProcessorError {
-    #[error("Tokenizer is not supported")]
-    UnsupportedTokenizer,
-    #[error("Decoder unpacking failed")]
-    DecoderUnpackingFailed,
-    #[error("Token processing failed for byte level processor")]
-    ByteProcessorFailed,
-    #[error("Token processing failed for byte fallback level processor")]
-    ByteFallbackProcessorFailed,
-}
-
 #[cfg(feature = "python-bindings")]
-impl From<Error> for pyo3::PyErr {
-    fn from(e: Error) -> Self {
-        use pyo3::{exceptions::PyValueError, PyErr};
-        PyErr::new::<PyValueError, _>(e.to_string())
-    }
-}
+mod python_bindings;

diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index b62c22e..45aaef5 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -3,7 +3,8 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

-use crate::{prelude::*, TokenizerError, VocabularyError};
+use crate::prelude::*;
+use crate::{Error, Result};

 use locator::EosTokenLocator;
 use processor::TokenProcessor;
@@ -44,19 +45,18 @@ impl Vocabulary {
     pub fn from_pretrained(
         model: &str,
         parameters: Option<FromPretrainedParameters>,
-    ) -> Result<Self, VocabularyError> {
+    ) -> Result<Self> {
         let mut tokenizer =
-            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
-                VocabularyError::UnableToCreateTokenizer {
+            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
+                Error::UnableToCreateTokenizer {
                     model: model.to_string(),
-                    source: TokenizerError(error),
                 }
             })?;
         Self::filter_prepend_normalizers(&mut tokenizer);

         let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
-            return Err(VocabularyError::UnableToLocateEosTokenId {
+            return Err(Error::UnableToLocateEosTokenId {
                 model: model.to_string(),
             });
         };
@@ -106,9 +106,9 @@ impl Vocabulary {
 }

 impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
-    type Error = VocabularyError;
+    type Error = Error;

-    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self, Self::Error> {
+    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self> {
         let (tokenizer, eos_token_id) = value;
@@ -313,12 +313,7 @@ mod tests {
         assert!(vocabulary.is_err());
         if let Err(e) = vocabulary {
-            assert_eq!(
-                e,
-                VocabularyError::TokenProcessorError(
-                    crate::TokenProcessorError::UnsupportedTokenizer
-                )
-            )
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
         }
     }
@@ -328,9 +323,8 @@ mod tests {
         let vocabulary = Vocabulary::from_pretrained(model, None);

         assert!(vocabulary.is_err());
-        if let Err(VocabularyError::UnableToCreateTokenizer { model, source }) = vocabulary {
+        if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary {
             assert_eq!(model, model.to_string());
-            assert_eq!(source.to_string(), "Tokenizer error".to_string());
         }
     }

diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 9488a78..cec32f5 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -5,9 +5,7 @@
 use serde::Deserialize;
 use tokenizers::normalizers::Replace;
 use tokenizers::{DecoderWrapper, Tokenizer};

-use crate::TokenProcessorError;
-
-type Result<T> = std::result::Result<T, TokenProcessorError>;
+use crate::{Error, Result};

 /// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
 /// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each
@@ -157,7 +155,7 @@ impl TokenProcessor {
     /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders.
     pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {
         match tokenizer.get_decoder() {
-            None => Err(TokenProcessorError::UnsupportedTokenizer),
+            None => Err(Error::UnsupportedByTokenProcessor),
             Some(decoder) => match decoder {
                 DecoderWrapper::ByteLevel(_) => Ok(Self {
                     level: TokenProcessorLevel::Byte,
                 }),
@@ -188,10 +186,10 @@ impl TokenProcessor {
                             level: TokenProcessorLevel::ByteFallback(Mods { spacechar }),
                         })
                     } else {
-                        Err(TokenProcessorError::UnsupportedTokenizer)
+                        Err(Error::UnsupportedByTokenProcessor)
                     }
                 }
-                _ => Err(TokenProcessorError::UnsupportedTokenizer),
+                _ => Err(Error::UnsupportedByTokenProcessor),
             },
         }
     }

     /// Operates on each token based on the level of `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
-            TokenProcessorLevel::Byte => {
-                let mut bytes = vec![];
-                for char in token.chars() {
-                    match CHAR_MAP.get(&char) {
-                        None => return Err(TokenProcessorError::ByteProcessorFailed),
-                        Some(b) => bytes.push(*b),
-                    }
-                }
-                Ok(bytes)
-            }
+            TokenProcessorLevel::Byte => token
+                .chars()
+                .map(|char| {
+                    CHAR_MAP
+                        .get(&char)
+                        .copied()
+                        .ok_or(Error::ByteProcessorFailed)
+                })
+                .collect(),
             TokenProcessorLevel::ByteFallback(mods) => {
                 // If the token is of form `<0x__>`:
                 if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
                     // Get to a single byte specified in the __ part and parse it in base 16 to a byte.
                     match u8::from_str_radix(&token[3..5], 16) {
                         Ok(byte) => Ok([byte].to_vec()),
-                        Err(_) => Err(TokenProcessorError::ByteFallbackProcessorFailed),
+                        Err(_) => Err(Error::ByteFallbackProcessorFailed),
                     }
                 } else {
                     Ok(mods.apply_default(token).as_bytes().to_vec())
                 }
             }
         }
     }

     /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
     /// into local `ReplaceDecoder` structure.
     fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
         match serde_json::to_value(decoder) {
-            Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+            Err(_) => Err(Error::DecoderUnpackingFailed),
             Ok(value) => match serde_json::from_value(value) {
                 Ok(d) => Ok(d),
-                Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
+                Err(_) => Err(Error::DecoderUnpackingFailed),
             },
         }
     }
@@ -324,7 +321,7 @@ mod tests {
         let result = TokenProcessor::new(&tokenizer);
         assert!(result.is_err());
         if let Err(e) = result {
-            assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
         }
     }
@@ -338,7 +335,7 @@ mod tests {
             let result = processor.process(token.to_string());
             assert!(result.is_err());
             if let Err(e) = result {
-                assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
+                assert_eq!(e, Error::ByteProcessorFailed)
             }
         }
     }
@@ -352,7 +349,7 @@ mod tests {
         let result = processor.process("<0x6y>".to_string());
         assert!(result.is_err());
         if let Err(e) = result {
-            assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
+            assert_eq!(e, Error::ByteFallbackProcessorFailed)
         }
     }
 }
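With errors consolidated into one enum, a caller can now match every failure mode in a single place. A hedged sketch of what downstream handling might look like after this patch (the model name is illustrative):

    use outlines_core::error::Error;
    use outlines_core::prelude::*;

    fn load(model: &str) {
        match Vocabulary::from_pretrained(model, None) {
            Ok(v) => println!("eos token id: {:?}", v.eos_token_id()),
            Err(Error::UnableToCreateTokenizer { model }) => {
                eprintln!("no tokenizer for {model}")
            }
            // Locator and processor failures arrive through the same enum.
            Err(e) => eprintln!("{e}"),
        }
    }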
From fb833aeccedb8d99fb94b8c908aeecc5a6fd0e47 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Tue, 12 Nov 2024 17:02:11 +0000
Subject: [PATCH 08/14] Python: ignore deprecation warnings

ERROR tests/fsm/test_regex.py - RuntimeError: Failed to import
transformers.models.auto.tokenization_auto because of the following error
(look up to see its traceback): Failed to import
transformers.generation.utils because of the following error (look up to
see its traceback): numpy.core is deprecated and has been renamed to
numpy._core. The numpy._core namespace contains private NumPy internals
and its use is discouraged, as NumPy internals can change without warning
in any release. In practice, most real-world usage of numpy.core is to
access functionality in the public NumPy API. If that is the case, use
the public NumPy API. If not, you are using NumPy internals. If you would
still like to access an internal attribute, use numpy._core.multiarray.
---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 3d1d0d4..5709098 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,7 @@ filterwarnings = [
     "error",
     "ignore::pydantic.warnings.PydanticDeprecatedSince20",
     "ignore::UserWarning",
+    "ignore::DeprecationWarning",
 ]
 addopts = [
     "--import-mode=importlib"
 ]

From 334dab0dbb20447d690e5983f210a0ce20d5d35b Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Wed, 13 Nov 2024 10:39:56 +0000
Subject: [PATCH 09/14] Apply suggestions from CR

---
 Cargo.toml                |  7 ++++---
 src/vocabulary/locator.rs | 26 +++++++++++---------------
 src/vocabulary/mod.rs     | 27 +++++++++++++--------------
 3 files changed, 28 insertions(+), 32 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index aaf2e17..947620f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,14 +9,15 @@ repository = "https://github.com/dottxt-ai/outlines-core"
 [dependencies]
 once_cell = "1.20"
 anyhow = "1.0.86"
-thiserror = "1.0"
+thiserror = "2.0"
 pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
 regex = "1.10.6"
 serde-pyobject = "0.4.0"
 serde_json = { version = "1.0", features = ["preserve_order"] }
-serde = {version = "1", features = ["derive"]}
+serde = {version = "1.0", features = ["derive"]}
+# Fragile dependencies, minor updates often break the code
 hf-hub = "=0.3.2"
-tokenizers = { version = "=0.20.0", features = ["http"] }
+tokenizers = { version = "=0.20.3", features = ["http"] }

 [features]
 python-bindings = ["pyo3"]

diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index b3c6f06..31a3548 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -89,19 +89,15 @@ struct EosTokenLocation {
     location: EosTokenField,
 }

-pub(crate) struct EosTokenLocator;
-
-impl EosTokenLocator {
-    /// Locates eos token id by searching in defined common locations.
-    pub(crate) fn locate(
-        model: &str,
-        tokenizer: &Tokenizer,
-        parameters: &Option<FromPretrainedParameters>,
-    ) -> Option<TokenId> {
-        COMMON_LOCATIONS
-            .iter()
-            .find_map(|location| location.lookup(model, tokenizer, parameters))
-    }
+/// Locates eos token id by searching in defined common locations.
+pub(crate) fn locate_eos_token_id(
+    model: &str,
+    tokenizer: &Tokenizer,
+    parameters: &Option<FromPretrainedParameters>,
+) -> Option<TokenId> {
+    COMMON_LOCATIONS
+        .iter()
+        .find_map(|location| location.lookup(model, tokenizer, parameters))
 }

 impl EosTokenLocation {
@@ -147,7 +143,7 @@ impl EosTokenLocation {
         let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
         let api = ApiBuilder::new()
-            .with_token(params.auth_token)
+            .with_token(params.token)
             .build()?
             .repo(repo);

@@ -188,7 +184,7 @@ mod tests {
         ] {
             let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
             let located =
-                EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");
+                locate_eos_token_id(model, &tokenizer, &None).expect("Token id is not located");

             assert_eq!(located, *expected_token_id);
             assert_eq!(

diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 45aaef5..7048731 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -6,7 +6,6 @@ use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 use crate::prelude::*;
 use crate::{Error, Result};

-use locator::EosTokenLocator;
 use processor::TokenProcessor;

 mod locator;
@@ -29,7 +28,7 @@ mod processor;
 pub struct Vocabulary {
     // TODO: Option is temp for back compatibility
     eos_token_id: Option<TokenId>,
-    map: HashMap<Token, Vec<TokenId>>,
+    tokens: HashMap<Token, Vec<TokenId>>,
 }

 impl Vocabulary {
     /// Creates an empty vocabulary.
     pub fn new(eos_token_id: Option<TokenId>) -> Self {
         Self {
             eos_token_id,
-            map: HashMap::new(),
+            tokens: HashMap::new(),
         }
     }

@@ -53,19 +53,19 @@ impl Vocabulary {
-        let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
+        let eos_token_id = locator::locate_eos_token_id(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
             return Err(Error::UnableToLocateEosTokenId {
                 model: model.to_string(),
             });
         };

-        Vocabulary::try_from((&mut tokenizer, eos_token_id))
+        Vocabulary::try_from((tokenizer, eos_token_id))
     }

     /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
     pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
-        self.map.get(token)
+        self.tokens.get(token)
     }

@@ -106,9 +105,9 @@ impl Vocabulary {
 }

-impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
+impl TryFrom<(Tokenizer, u32)> for Vocabulary {
     type Error = Error;

-    fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self> {
+    fn try_from(value: (Tokenizer, u32)) -> Result<Self> {
         let (tokenizer, eos_token_id) = value;
@@ -118,7 +117,7 @@ impl TryFrom<(Tokenizer, u32)> for Vocabulary {
             }
         }

-        let processor = TokenProcessor::new(tokenizer)?;
+        let processor = TokenProcessor::new(&tokenizer)?;
         for (token, token_id) in tokenizer.get_vocab(false) {
             let token_bytes = processor.process(token)?;
@@ -154,7 +153,7 @@ impl Vocabulary {
     pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
         // TODO: return error if eos token id is inserted
         let token = token.into();
-        self.map.entry(token).or_default().push(id);
+        self.tokens.entry(token).or_default().push(id);
     }

     /// Extends the vocabulary with tokens and their identifiers, in place.
@@ -164,7 +163,7 @@ impl Vocabulary {
     ) {
         for (token, ids) in tokens_and_ids.into_iter() {
             let token = token.into();
-            self.map.entry(token).or_default().extend(ids);
+            self.tokens.entry(token).or_default().extend(ids);
         }
     }
 }
@@ -173,7 +172,7 @@ impl std::ops::Deref for Vocabulary {
     type Target = HashMap<Token, Vec<TokenId>>;

     fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
-        &self.map
+        &self.tokens
     }
 }

@@ -191,10 +190,10 @@ impl std::fmt::Display for Vocabulary {
 }

 impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
-    fn from(map: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
+    fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
         Vocabulary {
             eos_token_id: None,
-            map,
+            tokens,
         }
     }
 }

From 527f9520ce982997ad8129d0d434ceeb9f971b8a Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Wed, 13 Nov 2024 11:30:42 +0000
Subject: [PATCH 10/14] Add test rust coverage cmd to Makefile

---
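Usage notes: `make test-rust-cov` goes through `check-tarpaulin` first, which
installs cargo-tarpaulin only when the binary is missing, then writes an LCOV
report into rust-coverage/ (git-ignored below), so the same target works
locally and in CI. The Python bindings are excluded from the report since
they are exercised by the Python test suite instead.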
 .gitignore |  2 +-
 Makefile   | 23 ++++++++++++++++++++++-
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index aeb0f0e..9dbb7b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ benchmarks/results
 # Remove doc build folders
 .cache/
 build/
-
+rust-coverage/
 target/
 *.so
 *.pyd
diff --git a/Makefile b/Makefile
index f329306..6cd0637 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,9 @@
 # Optional target to test/benchmark.
 TARGET ?=
+TARPAULIN_INSTALLED := $(shell command -v cargo-tarpaulin > /dev/null && echo 1 || echo 0)

 .ONESHELL:
-.PHONY: venv setup install install-release build-extension-debug build-extension-release watch-extension watch-extension-release pcc test test-rust test-python bench pybench doc dist clean check-clean-git
+.PHONY: venv setup install install-release build-extension-debug build-extension-release watch-extension watch-extension-release pcc test test-rust test-python bench pybench doc dist clean check-clean-git check-tarpaulin test-rust-cov
 .SILENT:

 # Create a fresh virtual environment with the latest pip.
@@ -59,6 +60,26 @@ test-python: build-extension-debug
 		--cov=outlines_core \
 		--cov-report=term-missing:skip-covered

+# Check if tarpaulin needs to be installed first.
+check-tarpaulin:
+ifeq ($(TARPAULIN_INSTALLED), 0)
+	@echo "cargo-tarpaulin is not found, installing..."
+	cargo install cargo-tarpaulin
+else
+	@echo "cargo-tarpaulin is already installed"
+endif
+
+# Run rust tests with coverage report.
+test-rust-cov: check-tarpaulin
+	RUSTFLAGS="-C instrument-coverage" cargo tarpaulin \
+	--out=Lcov \
+	--output-dir=rust-coverage \
+	--engine=llvm \
+	--exclude-files=src/python_bindings/* \
+	--no-dead-code \
+	--workspace \
+	--verbose
+
 # Run rust benchmarks.
 bench:
 ifeq ($(TARGET),)

From 2420742cfe73db3764352f26942332920906af11 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Wed, 13 Nov 2024 19:20:09 +0000
Subject: [PATCH 11/14] Improve test coverage

---
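Coverage notes: `unpack_decoder` is serde plumbing that tarpaulin cannot
attribute meaningfully, so it is marked `#[cfg(not(tarpaulin_include))]`;
everything else gets dedicated tests below. Since `tarpaulin_include` is not
a cfg rustc knows about, the new `[lints.rust]` table registers it through
`check-cfg`, keeping the `unexpected_cfgs` lint quiet. A minimal sketch of
how the marker behaves (the helper name is hypothetical, purely for
illustration):

    // In ordinary builds this cfg is not set, so the function compiles as
    // usual; under `cargo tarpaulin` the marker excludes it from coverage
    // accounting.
    #[cfg(not(tarpaulin_include))]
    fn helper_excluded_from_coverage() -> u32 {
        42
    }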
found"),
@@ -358,4 +389,18 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn other_normalizers_being_kept() {
+        use tokenizers::normalizers::BertNormalizer;
+
+        let model = "hf-internal-testing/llama-tokenizer";
+        let normalizer = NormalizerWrapper::BertNormalizer(BertNormalizer::default());
+        let mut tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
+        tokenizer.with_normalizer(Some(normalizer));
+
+        Vocabulary::filter_prepend_normalizers(&mut tokenizer);
+
+        assert!(tokenizer.get_normalizer().is_some());
+    }
 }
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index cec32f5..74d700d 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -93,15 +93,6 @@ pub(crate) enum TokenProcessorLevel {
     ByteFallback(Mods),
 }

-impl std::fmt::Display for TokenProcessorLevel {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        match self {
-            Self::Byte => write!(f, "Byte Level"),
-            Self::ByteFallback(mods) => write!(f, "Byte Fallback Level with mods: {:?}", mods),
-        }
-    }
-}
-
 /// Modifications to be applied by `TokenProcessor` of `ByteFallback` level.
 #[derive(Debug, Clone, PartialEq)]
 pub(crate) struct Mods {
@@ -223,6 +214,7 @@ impl TokenProcessor {

     /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
     /// into local `ReplaceDecoder` structure.
+    #[cfg(not(tarpaulin_include))]
    fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
         match serde_json::to_value(decoder) {
             Err(_) => Err(Error::DecoderUnpackingFailed),
@@ -352,4 +344,59 @@ mod tests {
             assert_eq!(e, Error::ByteFallbackProcessorFailed)
         }
     }
+
+    #[test]
+    fn only_get_spacechar_replacement() {
+        let one_char = "_".to_string();
+        let pattern = ReplacePattern::String(one_char);
+        let not_spacechar = "-".to_string();
+        let decoder = ReplaceDecoder {
+            content: not_spacechar,
+            pattern,
+        };
+        assert!(decoder.space_replacement().is_none());
+    }
+
+    #[test]
+    fn only_one_pattern_char_for_spacechar_replacement() {
+        let two_chars = "_*".to_string();
+        let pattern = ReplacePattern::String(two_chars);
+        let spacechar = " ".to_string();
+        let decoder = ReplaceDecoder {
+            content: spacechar,
+            pattern,
+        };
+        assert!(decoder.space_replacement().is_none());
+    }
+
+    #[test]
+    fn tokenizer_without_decoders_is_unsupported() {
+        use tokenizers::models::bpe::BPE;
+
+        let tokenizer = Tokenizer::new(BPE::default());
+        let result = TokenProcessor::new(&tokenizer);
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        }
+    }
+
+    #[test]
+    fn tokenizer_without_supported_decoders_in_sequence_is_unsupported() {
+        use tokenizers::decoders::sequence::Sequence;
+        use tokenizers::decoders::wordpiece::WordPiece;
+        use tokenizers::models::bpe::BPE;
+
+        let mut tokenizer = Tokenizer::new(BPE::default());
+        let decoder = WordPiece::default();
+        let sequence = Sequence::new(vec![DecoderWrapper::WordPiece(decoder)]);
+        let decoder_sequence = DecoderWrapper::Sequence(sequence);
+        tokenizer.with_decoder(Some(decoder_sequence));
+
+        let result = TokenProcessor::new(&tokenizer);
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        }
+    }
 }

From 5e0177a8d2215fe8ba35c9ae637f85518094d463 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Thu, 14 Nov 2024 15:59:29 +0000
Subject: [PATCH 12/14] Locator as a trait

---
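Design notes: turning the locator into a trait makes the EOS token lookup
pluggable without changing the public API; `from_pretrained` simply delegates
to `from_pretrained_with_locator::<HFLocator>`. Because `locate_eos_token_id`
takes no `self`, locators stay zero-sized and are resolved at compile time.
A custom locator is a few lines; a sketch under the names introduced below
(`FixedLocator` is hypothetical, not part of this patch):

    struct FixedLocator;

    impl Locator for FixedLocator {
        fn locate_eos_token_id(
            _model: &str,
            _tokenizer: &Tokenizer,
            _parameters: &Option<FromPretrainedParameters>,
        ) -> Option<TokenId> {
            // E.g. hardcode a known EOS id for tests, skipping any Hub access.
            Some(50256)
        }
    }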
 src/vocabulary/locator.rs | 35 ++++++++++++++++++++-----------
 src/vocabulary/mod.rs     | 34 +++++++++++++++++++++++++++++++++-
 2 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index 4b41778..782b621 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -89,15 +89,28 @@ struct EosTokenLocation {
     location: EosTokenField,
 }

-/// Locates eos token id by searching in defined common locations.
-pub(crate) fn locate_eos_token_id(
-    model: &str,
-    tokenizer: &Tokenizer,
-    parameters: &Option<FromPretrainedParameters>,
-) -> Option<TokenId> {
-    COMMON_LOCATIONS
-        .iter()
-        .find_map(|location| location.lookup(model, tokenizer, parameters))
+/// Locates eos token id.
+pub(crate) trait Locator {
+    fn locate_eos_token_id(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId>;
+}
+
+/// Locates eos token id by searching in defined common locations on the Hugging Face Hub.
+pub(crate) struct HFLocator;
+
+impl Locator for HFLocator {
+    fn locate_eos_token_id(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<TokenId> {
+        COMMON_LOCATIONS
+            .iter()
+            .find_map(|location| location.lookup(model, tokenizer, parameters))
+    }
 }

 impl EosTokenLocation {
@@ -186,8 +199,8 @@ mod tests {
             ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
         ] {
             let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
-            let located =
-                locate_eos_token_id(model, &tokenizer, &None).expect("Token id is not located");
+            let located = HFLocator::locate_eos_token_id(model, &tokenizer, &None)
+                .expect("Token id is not located");

             assert_eq!(located, *expected_token_id);
             assert_eq!(
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index e3e24e5..eca4232 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -6,6 +6,7 @@ use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 use crate::prelude::*;
 use crate::{Error, Result};

+use locator::{HFLocator, Locator};
 use processor::TokenProcessor;

 mod locator;
@@ -44,6 +45,15 @@ impl Vocabulary {
     pub fn from_pretrained(
         model: &str,
         parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self> {
+        Self::from_pretrained_with_locator::<HFLocator>(model, parameters)
+    }
+
+    #[doc(hidden)]
+    #[inline(always)]
+    fn from_pretrained_with_locator<L: Locator>(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
     ) -> Result<Self> {
         let mut tokenizer =
             Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
@@ -53,7 +63,7 @@ impl Vocabulary {
         })?;
         Self::filter_prepend_normalizers(&mut tokenizer);

-        let eos_token_id = locator::locate_eos_token_id(model, &tokenizer, &parameters);
+        let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
             return Err(Error::UnableToLocateEosTokenId {
                 model: model.to_string(),
@@ -358,6 +368,28 @@ mod tests {
         }
     }

+    struct NoneLocator;
+    impl Locator for NoneLocator {
+        fn locate_eos_token_id(
+            _model: &str,
+            _tokenizer: &Tokenizer,
+            _parameters: &Option<FromPretrainedParameters>,
+        ) -> Option<TokenId> {
+            None
+        }
+    }
+
+    #[test]
+    fn unable_to_locate_eos_token_id_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let vocabulary = Vocabulary::from_pretrained_with_locator::<NoneLocator>(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary {
+            assert_eq!(model, model.to_string());
+        }
+    }
+
     #[test]
     fn prepend_normalizers_filtered_out() {
         use tokenizers::normalizers::{Prepend, Sequence};

From b85523ec676baffd5f2c7654fb3de1cb3d41fc65 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Mon, 18 Nov 2024 16:19:50 +0000
Subject: [PATCH 13/14] Separate tokenizers errors, test supported pretrained
 models

---
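Error-handling notes: `tokenizers::Error` is a boxed `dyn Error`, so it
neither implements nor can derive `PartialEq`, while our `Error` enum must
stay comparable for test assertions. The newtype below bridges that by
comparing rendered messages, which is a deliberate approximation: two errors
are "equal" iff they display identically. A condensed sketch of the pattern
(the full version is in the diff):

    #[derive(Debug)]
    pub struct TokenizersError(pub tokenizers::Error);

    impl PartialEq for TokenizersError {
        fn eq(&self, other: &Self) -> bool {
            // Message equality stands in for structural equality.
            self.0.to_string() == other.0.to_string()
        }
    }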
 src/error.rs                |  21 ++++-
 src/vocabulary/mod.rs       | 121 ++++++++++++++++++++++--------------
 src/vocabulary/processor.rs |  30 ++++-----
 3 files changed, 110 insertions(+), 62 deletions(-)

diff --git a/src/error.rs b/src/error.rs
index ff977a7..652fa74 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,11 +1,28 @@
 use thiserror::Error;

+#[derive(Error, Debug)]
+pub struct TokenizersError(pub tokenizers::Error);
+
+impl PartialEq for TokenizersError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
+impl std::fmt::Display for TokenizersError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 #[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
     IndexError,
-    #[error("Unable to create tokenizer for {model}")]
-    UnableToCreateTokenizer { model: String },
+    #[error(transparent)]
+    TokenizersError(#[from] TokenizersError),
+    #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
+    UnsupportedTokenizer { model: String, reason: String },
     #[error("Unable to locate EOS token for {model}")]
     UnableToLocateEosTokenId { model: String },
     #[error("Tokenizer is not supported by token processor")]
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index eca4232..719c904 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

-use crate::prelude::*;
+use crate::{error, prelude::*};
 use crate::{Error, Result};

 use locator::{HFLocator, Locator};
@@ -55,22 +55,44 @@ impl Vocabulary {
         model: &str,
         parameters: Option<FromPretrainedParameters>,
     ) -> Result<Self> {
-        let mut tokenizer =
-            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
-                Error::UnableToCreateTokenizer {
-                    model: model.to_string(),
-                }
-            })?;
+        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())
+            .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?;
         Self::filter_prepend_normalizers(&mut tokenizer);

+        // Locate eos_token_id in defined locations.
         let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
-            return Err(Error::UnableToLocateEosTokenId {
+            return Err(Error::UnsupportedTokenizer {
                 model: model.to_string(),
+                reason: "EOS token id".to_string(),
             });
         };

-        Vocabulary::try_from((tokenizer, eos_token_id))
+        // Start building the vocabulary from eos_token_id and added tokens.
+        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
+            if !added_token.special {
+                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            }
+        }
+
+        // Process each vocabulary token according to the tokenizer's level.
+        let Ok(processor) = TokenProcessor::new(&tokenizer) else {
+            return Err(Error::UnsupportedTokenizer {
+                model: model.to_string(),
+                reason: "Token processor".to_string(),
+            });
+        };
+        for (token, token_id) in tokenizer.get_vocab(false) {
+            let token_bytes = processor.process(token)?;
+            // TODO: lossy is temp:
+            // - in python it was handled by byte_symbol function
+            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
+            let processed_token = String::from_utf8_lossy(&token_bytes);
+            vocabulary = vocabulary.insert(processed_token, token_id);
+        }
+
+        Ok(vocabulary)
     }

     /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
@@ -114,33 +136,6 @@ impl Vocabulary {
     }
 }

-impl TryFrom<(Tokenizer, u32)> for Vocabulary {
-    type Error = Error;
-
-    fn try_from(value: (Tokenizer, u32)) -> Result<Self> {
-        let (tokenizer, eos_token_id) = value;
-
-        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
-        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
-            if !added_token.special {
-                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
-            }
-        }
-
-        let processor = TokenProcessor::new(&tokenizer)?;
-        for (token, token_id) in tokenizer.get_vocab(false) {
-            let token_bytes = processor.process(token)?;
-            // TODO: lossy is temp:
-            // - in python it was handled by byte_symbol function
-            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
-            let processed_token = String::from_utf8_lossy(&token_bytes);
-            vocabulary = vocabulary.insert(processed_token, token_id);
-        }
-
-        Ok(vocabulary)
-    }
-}
-
 impl Vocabulary {
     /// Inserts a token to the vocabulary with the specified identifier.
     pub fn insert(mut self, token: impl Into<Token>, id: TokenId) -> Vocabulary {
@@ -278,6 +273,34 @@ mod tests {
         assert!(!vocabulary.tokens.is_empty());
     }

+    #[test]
+    fn supported_pretrained_models() {
+        // Support is expected for these:
+        for model in [
+            // GPT 2
+            "openai-community/gpt2",
+            // Llama 2
+            "hf-internal-testing/Llama-2-7B-GPTQ",
+            // Llama 3
+            // OpenCoder: shares llama tokenizers
+            "hf-internal-testing/llama-3-8b-internal",
+            // Qwen
+            "Qwen/Qwen2-7B-Instruct",
+            // Salamandra
+            "BSC-LT/salamandra-2b",
+        ] {
+            let vocabulary = Vocabulary::from_pretrained(model, None);
+            match vocabulary {
+                Ok(v) => {
+                    assert!(v.eos_token_id().is_some());
+                    assert_eq!(v.eos_token_id, v.eos_token_id());
+                    assert!(!v.tokens.is_empty());
+                }
+                Err(_) => unreachable!(),
+            }
+        }
+    }
+
     #[test]
     fn pretrained_from_gpt2() {
         let model = "openai-community/gpt2";
@@ -285,6 +308,7 @@ mod tests {
         let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

         let v_eos = vocabulary.eos_token_id;
+        assert_eq!(v_eos, vocabulary.eos_token_id());
         assert!(v_eos.is_some());

         let v_eos = v_eos.unwrap();
@@ -317,6 +341,7 @@ mod tests {
         let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

         let v_eos = vocabulary.eos_token_id;
+        assert_eq!(v_eos, vocabulary.eos_token_id());
         assert!(v_eos.is_some());

         let v_eos = v_eos.unwrap();
@@ -351,9 +376,12 @@ mod tests {
         let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
         let vocabulary = Vocabulary::from_pretrained(model, None);

-        assert!(vocabulary.is_err());
-        if let Err(e) = vocabulary {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match vocabulary {
+            Err(Error::UnsupportedTokenizer { model, reason }) => {
+                assert_eq!(model, model.to_string());
+                assert_eq!(&reason, "Token processor");
+            }
+            _ => unreachable!(),
         }
     }

@@ -362,9 +390,9 @@ mod tests {
         let model = "hf-internal-testing/some-non-existent-model";
         let vocabulary = Vocabulary::from_pretrained(model, None);

-        assert!(vocabulary.is_err());
-        if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary {
-            assert_eq!(model, model.to_string());
+        match vocabulary {
+            Err(Error::TokenizersError(e)) => assert!(!e.to_string().is_empty()),
+            _ => unreachable!(),
         }
     }

@@ -384,9 +412,12 @@ mod tests {
         let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
         let vocabulary = Vocabulary::from_pretrained_with_locator::<NoneLocator>(model, None);

-        assert!(vocabulary.is_err());
-        if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary {
-            assert_eq!(model, model.to_string());
+        match vocabulary {
+            Err(Error::UnsupportedTokenizer { model, reason }) => {
+                assert_eq!(model, model.to_string());
+                assert_eq!(&reason, "EOS token id");
+            }
+            _ => unreachable!(),
         }
     }

diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 74d700d..55b6cde 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -311,9 +311,9 @@ mod tests {
         let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");

         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }

@@ -325,9 +325,9 @@ mod tests {
         for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] {
             let result = processor.process(token.to_string());
-            assert!(result.is_err());
-            if let Err(e) = result {
-                assert_eq!(e, Error::ByteProcessorFailed)
+            match result {
+                Err(Error::ByteProcessorFailed) => {}
+                _ => unreachable!(),
             }
         }
     }

@@ -339,9 +339,9 @@ mod tests {
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");

         let result = processor.process("<0x6y>".to_string());
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::ByteFallbackProcessorFailed)
+        match result {
+            Err(Error::ByteFallbackProcessorFailed) => {}
+            _ => unreachable!(),
         }
     }

@@ -375,9 +375,9 @@ mod tests {
         let tokenizer = Tokenizer::new(BPE::default());

         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }

@@ -394,9 +394,9 @@ mod tests {
         tokenizer.with_decoder(Some(decoder_sequence));

         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }
 }

From 741f59c40f65d291f4f2800642798e426056b48b Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Tue, 19 Nov 2024 11:28:09 +0000
Subject: [PATCH 14/14] Apply CR suggestions

---
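Follow-up notes: the `Result` alias moves next to `Error` and both are
re-exported from the crate root, so downstream code needs a single import;
`#[error("{0}")]` replaces the hand-written `Display` for `TokenizersError`;
and the `DEFAULT_MODS` static becomes `impl Default for Mods`. A sketch of
the intended caller-side ergonomics (assuming the re-exports below):

    use outlines_core::vocabulary::Vocabulary;
    use outlines_core::Result;

    // One import covers both the Ok and the Err side of the call.
    fn load_vocabulary() -> Result<Vocabulary> {
        Vocabulary::from_pretrained("openai-community/gpt2", None)
    }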
 src/error.rs                |  9 +++------
 src/lib.rs                  |  4 +---
 src/vocabulary/locator.rs   |  4 ++++
 src/vocabulary/processor.rs | 24 ++++++++++++++----------
 4 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/src/error.rs b/src/error.rs
index 652fa74..f589731 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,6 +1,9 @@
 use thiserror::Error;

+pub type Result<T> = std::result::Result<T, Error>;
+
 #[derive(Error, Debug)]
+#[error("{0}")]
 pub struct TokenizersError(pub tokenizers::Error);

 impl PartialEq for TokenizersError {
@@ -9,12 +12,6 @@ impl PartialEq for TokenizersError {
     }
 }

-impl std::fmt::Display for TokenizersError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
 #[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
diff --git a/src/lib.rs b/src/lib.rs
index 4c45de4..6155b71 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,9 +6,7 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;

-use error::Error;
-
-pub type Result<T> = std::result::Result<T, Error>;
+pub use error::{Error, Result};

 #[cfg(feature = "python-bindings")]
 mod python_bindings;
diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index 782b621..d3f8bcf 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -4,6 +4,7 @@ use tokenizers::{FromPretrainedParameters, Tokenizer};

 use crate::primitives::*;

+/// Mapping of characters to bytes for GPT-2 like tokenizers.
 /// List of common eos token locations appearing on hugging face hub, ordered by priority.
 const COMMON_LOCATIONS: &[EosTokenLocation] = &[
     // Most projects have `generation_config.json` that looks like:
@@ -71,6 +72,7 @@ struct Object {
     eos_token: Content,
 }

+/// `eos_token` provided in a `Content`.
 #[derive(Debug, Serialize, Deserialize)]
 struct Content {
     content: String,
@@ -91,6 +93,7 @@ struct EosTokenLocation {

 /// Locates eos token id.
 pub(crate) trait Locator {
+    /// Locates eos token id in defined locations by `Locator`.
     fn locate_eos_token_id(
         model: &str,
         tokenizer: &Tokenizer,
@@ -102,6 +105,7 @@ pub(crate) trait Locator {
 pub(crate) struct HFLocator;

 impl Locator for HFLocator {
+    /// Locates eos token id in defined locations.
     fn locate_eos_token_id(
         model: &str,
         tokenizer: &Tokenizer,
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index 55b6cde..7426f24 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -77,12 +77,6 @@ static CHAR_MAP: Lazy<HashMap<char, u8>> = Lazy::new(|| {
     char_map
 });

-/// Token processor to adjust tokens according to the tokenizer's level.
-#[derive(Debug)]
-pub(crate) struct TokenProcessor {
-    level: TokenProcessorLevel,
-}
-
 /// Recognizes different tokenizer's levels.
 #[derive(Debug, Clone, PartialEq)]
 pub(crate) enum TokenProcessorLevel {
@@ -99,13 +93,17 @@ pub(crate) struct Mods {
     spacechar: char,
 }

-/// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level.
-static DEFAULT_MODS: Mods = Mods { spacechar: ' ' };
+impl Default for Mods {
+    /// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level.
+    fn default() -> Self {
+        Self { spacechar: ' ' }
+    }
+}

 impl Mods {
-    /// Apply default modifications.
+    /// Apply default modifications to each token.
     fn apply_default(&self, token: String) -> String {
-        let to = DEFAULT_MODS.spacechar.to_string();
+        let to = Self::default().spacechar.to_string();
         token.replace(self.spacechar, &to)
     }
 }
@@ -142,6 +140,12 @@ enum ReplacePattern {
     String(String),
 }

+/// Token processor to adjust tokens according to the tokenizer's level.
+#[derive(Debug)]
+pub(crate) struct TokenProcessor {
+    level: TokenProcessorLevel,
+}
+
 impl TokenProcessor {
     /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders.
     pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {