Skip to content

Commit

Permalink
More spots needed to compile
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller committed Jan 17, 2024
1 parent 3ddcb2d commit fd24c27
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
12 changes: 6 additions & 6 deletions bindings/python/src/trainers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ impl PyBpeTrainer {
}

#[getter]
-fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+fn get_min_frequency(self_: PyRef<Self>) -> u64 {
getter!(self_, BpeTrainer, min_frequency)
}

#[setter]
-fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
setter!(self_, BpeTrainer, min_frequency, freq);
}

Expand Down Expand Up @@ -397,12 +397,12 @@ impl PyWordPieceTrainer {
}

#[getter]
-fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+fn get_min_frequency(self_: PyRef<Self>) -> u64 {
getter!(self_, WordPieceTrainer, min_frequency())
}

#[setter]
-fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
setter!(self_, WordPieceTrainer, @set_min_frequency, freq);
}

Expand Down Expand Up @@ -589,12 +589,12 @@ impl PyWordLevelTrainer {
}

#[getter]
-fn get_min_frequency(self_: PyRef<Self>) -> u32 {
+fn get_min_frequency(self_: PyRef<Self>) -> u64 {
getter!(self_, WordLevelTrainer, min_frequency)
}

#[setter]
-fn set_min_frequency(self_: PyRef<Self>, freq: u32) {
+fn set_min_frequency(self_: PyRef<Self>, freq: u64) {
setter!(self_, WordLevelTrainer, min_frequency, freq);
}

Expand Down
12 changes: 6 additions & 6 deletions tokenizers/src/models/wordlevel/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashMap;
pub struct WordLevelTrainer {
/// The minimum frequency a word must have to be part of the vocabulary
#[builder(default = "0")]
-pub min_frequency: u32,
+pub min_frequency: u64,
/// The target vocabulary size
#[builder(default = "30_000")]
pub vocab_size: usize,
Expand All @@ -22,7 +22,7 @@ pub struct WordLevelTrainer {
pub special_tokens: Vec<AddedToken>,

#[builder(default, private)]
-words: HashMap<String, u32>,
+words: HashMap<String, u64>,
}

impl Default for WordLevelTrainer {
Expand All @@ -38,14 +38,14 @@ impl WordLevelTrainer {

fn do_train(
&self,
-word_counts: &HashMap<String, u32>,
+word_counts: &HashMap<String, u64>,
model: &mut WordLevel,
) -> Result<Vec<AddedToken>> {
let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();

// Sort the word counts by descending count first and then by word, so that
// the ordering stays deterministic when counts are equal.
-let cmp = |l: &(&String, &u32), r: &(&String, &u32)| -> Ordering {
+let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
let count_comp: Ordering = l.1.cmp(r.1);
if count_comp != Ordering::Equal {
return count_comp.reverse();
Expand Down Expand Up @@ -100,7 +100,7 @@ impl Trainer for WordLevelTrainer {
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
-let words: Result<HashMap<String, u32>> = iterator
+let words: Result<HashMap<String, u64>> = iterator
.maybe_par_bridge()
.map(|sequence| {
let words = process(sequence.as_ref())?;
Expand Down Expand Up @@ -132,7 +132,7 @@ mod tests {

#[test]
fn test_train() {
-let word_counts: HashMap<String, u32> = [
+let word_counts: HashMap<String, u64> = [
("the".into(), 25),
("roses".into(), 22),
("are".into(), 24),
Expand Down

0 comments on commit fd24c27

Please sign in to comment.