diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 1c1c9310a..707dc7230 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -183,12 +183,12 @@ impl PyBpeTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, BpeTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, BpeTrainer, min_frequency, freq);
     }

@@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordPieceTrainer, min_frequency())
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
     }

@@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
     }

     #[getter]
-    fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+    fn get_min_frequency(self_: PyRef<Self>) -> u64 {
         getter!(self_, WordLevelTrainer, min_frequency)
     }

     #[setter]
-    fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+    fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
         setter!(self_, WordLevelTrainer, min_frequency, freq);
     }
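The getters and setters above now move `u64` across the FFI boundary. A minimal sketch of the same getter/setter shape, assuming plain PyO3 without the repo's `getter!`/`setter!` macros (`TrainerStub` is a hypothetical stand-in for `PyBpeTrainer`):

    use pyo3::prelude::*;

    /// Hypothetical stand-in for PyBpeTrainer; only models the widened field.
    #[pyclass]
    struct TrainerStub {
        min_frequency: u64,
    }

    #[pymethods]
    impl TrainerStub {
        #[getter]
        fn get_min_frequency(&self) -> u64 {
            self.min_frequency
        }

        #[setter]
        fn set_min_frequency(&mut self, freq: u64) {
            // PyO3 extracts the incoming Python int as u64; negative values or
            // values above u64::MAX raise OverflowError on the Python side.
            self.min_frequency = freq;
        }
    }

Nothing changes syntactically for Python callers; values between 2**32 and 2**64 - 1, which previously failed to convert, are now accepted.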
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 3821cdab4..303fdbc81 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -11,7 +11,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
 #[derive(Debug, Eq)]
 struct Merge {
     pair: Pair,
-    count: u32,
+    count: u64,
     pos: HashSet<usize>,
 }
 impl PartialEq for Merge {
@@ -36,7 +36,7 @@ impl Ord for Merge {
 }

 struct Config {
-    min_frequency: u32,
+    min_frequency: u64,
     vocab_size: usize,
     show_progress: bool,
     special_tokens: Vec<AddedToken>,
@@ -79,7 +79,7 @@ impl BpeTrainerBuilder {

     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.config.min_frequency = frequency;
         self
     }
@@ -176,7 +176,7 @@ impl BpeTrainerBuilder {
 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     pub vocab_size: usize,
     /// Whether to show progress while training
@@ -195,7 +195,7 @@ pub struct BpeTrainer {
     /// An optional parameter to limit the max length of any single token
     pub max_token_length: Option<usize>,

-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for BpeTrainer {
@@ -205,7 +205,7 @@ impl Default for BpeTrainer {
 }

 impl BpeTrainer {
-    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
+    pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
         Self {
             min_frequency,
             vocab_size,
@@ -263,7 +263,7 @@ impl BpeTrainer {
     /// Compute the initial alphabet and limit it if relevant
     fn compute_alphabet(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
     ) {
@@ -322,13 +322,13 @@ impl BpeTrainer {
     /// Tokenize words and add subwords to the vocabulary when relevant
     fn tokenize_words(
         &self,
-        wc: &HashMap<String, u32>,
+        wc: &HashMap<String, u64>,
         w2id: &mut HashMap<String, u32>,
         id2w: &mut Vec<String>,
         p: &Option<ProgressBar>,
-    ) -> (Vec<Word>, Vec<u32>) {
+    ) -> (Vec<Word>, Vec<u64>) {
         let mut words: Vec<Word> = Vec::with_capacity(wc.len());
-        let mut counts: Vec<u32> = Vec::with_capacity(wc.len());
+        let mut counts: Vec<u64> = Vec::with_capacity(wc.len());

         for (word, count) in wc {
             let mut current_word = Word::new();
@@ -373,7 +373,7 @@ impl BpeTrainer {
     fn count_pairs(
         &self,
         words: &[Word],
-        counts: &[u32],
+        counts: &[u64],
         p: &Option<ProgressBar>,
     ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
         words
@@ -431,7 +431,7 @@ impl BpeTrainer {

     pub fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut BPE,
     ) -> Result<Vec<AddedToken>> {
         let mut word_to_id: HashMap<String, u32> = HashMap::with_capacity(self.vocab_size);
@@ -470,7 +470,7 @@ impl BpeTrainer {
             if count > 0 {
                 queue.push(Merge {
                     pair,
-                    count: count as u32,
+                    count: count as u64,
                     pos,
                 });
             }
@@ -493,8 +493,8 @@ impl BpeTrainer {
             }

             let mut top = queue.pop().unwrap();
-            if top.count != pair_counts[&top.pair] as u32 {
-                top.count = pair_counts[&top.pair] as u32;
+            if top.count != pair_counts[&top.pair] as u64 {
+                top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
                 continue;
             }
@@ -573,7 +573,7 @@ impl BpeTrainer {
                 if count > 0 {
                     queue.push(Merge {
                         pair,
-                        count: count as u32,
+                        count: count as u64,
                         pos,
                     });
                 }
@@ -632,7 +632,7 @@ impl Trainer for BpeTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
@@ -665,7 +665,7 @@ mod tests {

     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("roses".into(), 1),
             ("are".into(), 2),
             ("red".into(), 1),
@@ -744,7 +744,7 @@ mod tests {
         */

         let max_token_length = 16;
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
             ("singlelongtokenwithoutcasechange", 2),
             ("singleLongTokenWithCamelCaseChange", 2),
             ("Longsingletokenwithpunctu@t!onwithin", 2),
@@ -784,7 +784,7 @@ mod tests {
         // directly compares tokens with known expected values.
         // maybe unstable depending on specific settings or changes.
         */
-        let long_word_counts: HashMap<String, u32> = [
+        let long_word_counts: HashMap<String, u64> = [
            ("sin", 2),
            ("Sin", 2),
            ("Lon", 2),
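The widening matters because these counts are aggregated over the whole corpus: on corpora of several billion words, a frequent token or pair can exceed u32::MAX (4,294,967,295). A hedged sketch of training through the widened API, assuming the `BpeTrainer::builder()` and `do_train` entry points shown in the diff:

    use std::collections::HashMap;
    use tokenizers::models::bpe::{BpeTrainer, BPE};

    fn main() -> tokenizers::Result<()> {
        // A count above u32::MAX would previously not have fit in the map.
        let mut word_counts: HashMap<String, u64> = HashMap::new();
        word_counts.insert("the".into(), 5_000_000_000);
        word_counts.insert("roses".into(), 3);

        let trainer = BpeTrainer::builder()
            .min_frequency(2) // now a u64, matching the counts
            .vocab_size(30_000)
            .show_progress(false)
            .build();

        let mut model = BPE::default();
        let _special_tokens = trainer.do_train(&word_counts, &mut model)?;
        Ok(())
    }

Note that `count_pairs` still returns `HashMap<Pair, i32>` internally (counts can go negative during merge bookkeeping), which is why the `count as u64` casts above remain.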
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
index d4048b15d..c52ad08d7 100644
--- a/tokenizers/src/models/wordlevel/trainer.rs
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -10,7 +10,7 @@ use std::collections::HashMap;
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
     #[builder(default = "0")]
-    pub min_frequency: u32,
+    pub min_frequency: u64,
     /// The target vocabulary size
     #[builder(default = "30_000")]
     pub vocab_size: usize,
@@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
     pub special_tokens: Vec<AddedToken>,

     #[builder(default, private)]
-    words: HashMap<String, u32>,
+    words: HashMap<String, u64>,
 }

 impl Default for WordLevelTrainer {
@@ -38,14 +38,14 @@ impl WordLevelTrainer {

     fn do_train(
         &self,
-        word_counts: &HashMap<String, u32>,
+        word_counts: &HashMap<String, u64>,
         model: &mut WordLevel,
     ) -> Result<Vec<AddedToken>> {
         let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();

         //sort the word counts first by inverse counts and then by word, in order
         //to keep the sorting deterministic in case of equal counts
-        let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
+        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
             let count_comp: Ordering = l.1.cmp(r.1);
             if count_comp != Ordering::Equal {
                 return count_comp.reverse();
@@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<HashMap<String, u64>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
@@ -132,7 +132,7 @@ mod tests {

     #[test]
     fn test_train() {
-        let word_counts: HashMap<String, u32> = [
+        let word_counts: HashMap<String, u64> = [
             ("the".into(), 25),
             ("roses".into(), 22),
             ("are".into(), 24),
diff --git a/tokenizers/src/models/wordpiece/trainer.rs b/tokenizers/src/models/wordpiece/trainer.rs
index 1adcc2be4..58a5abc8f 100644
--- a/tokenizers/src/models/wordpiece/trainer.rs
+++ b/tokenizers/src/models/wordpiece/trainer.rs
@@ -26,7 +26,7 @@ impl WordPieceTrainerBuilder {

     /// Set the expected minimum frequency
     #[must_use]
-    pub fn min_frequency(mut self, frequency: u32) -> Self {
+    pub fn min_frequency(mut self, frequency: u64) -> Self {
         self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
         self
     }
@@ -94,11 +94,11 @@ pub struct WordPieceTrainer {
 }

 impl WordPieceTrainer {
-    pub fn min_frequency(&self) -> u32 {
+    pub fn min_frequency(&self) -> u64 {
         self.bpe_trainer.min_frequency
     }

-    pub fn set_min_frequency(&mut self, freq: u32) {
+    pub fn set_min_frequency(&mut self, freq: u64) {
         self.bpe_trainer.min_frequency = freq;
     }
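`WordPieceTrainer` holds no frequency field of its own; as the accessors above show, it forwards to the inner `bpe_trainer`, so the widening is the same one-field change seen in the BPE trainer. A usage sketch, assuming the builder mirrors the `BpeTrainerBuilder` shown earlier:

    use tokenizers::models::wordpiece::WordPieceTrainer;

    fn main() {
        let mut trainer = WordPieceTrainer::builder()
            .min_frequency(3) // forwarded to the inner BpeTrainer as u64
            .build();

        // A threshold beyond the old u32 range now round-trips intact.
        trainer.set_min_frequency(u64::from(u32::MAX) + 1);
        assert_eq!(trainer.min_frequency(), 4_294_967_296);
    }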