Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more support for tiktoken based tokenizers #1493

Merged
merged 19 commits into from
Apr 15, 2024
4 changes: 4 additions & 0 deletions bindings/python/py_src/tokenizers/models/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class BPE(Model):

byte_fallback (:obj:`bool`, `optional`):
Whether to use spm byte-fallback trick (defaults to False)

ignore_merges (:obj:`bool`, `optional`):
Whether or not to match tokens with the vocab before using merges.
"""
def __init__(
self,
Expand All @@ -124,6 +127,7 @@ class BPE(Model):
end_of_word_suffix=None,
fuse_unk=None,
byte_fallback=False,
ignore_merges=False,
):
pass

Expand Down
14 changes: 13 additions & 1 deletion bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ impl PyModel {
///
/// byte_fallback (:obj:`bool`, `optional`):
/// Whether to use spm byte-fallback trick (defaults to False)
///
/// ignore_merges (:obj:`bool`, `optional`):
/// Whether or not to match tokens with the vocab before using merges.
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
pub struct PyBPE {}

Expand All @@ -279,6 +282,7 @@ impl PyBPE {
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
"ignore_merges" => builder = builder.ignore_merges(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
Expand Down Expand Up @@ -396,11 +400,19 @@ impl PyBPE {
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback);
}
#[getter]
fn get_ignore_merges(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, ignore_merges)
}

#[setter]
fn set_ignore_merges(self_: PyRef<Self>, ignore_merges: bool) {
setter!(self_, BPE, ignore_merges, ignore_merges);
}
#[new]
#[pyo3(
signature = (vocab=None, merges=None, **kwargs),
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)")]
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False, ignore_merges=False)")]
fn new(
py: Python<'_>,
vocab: Option<PyVocab>,
Expand Down
117 changes: 117 additions & 0 deletions tokenizers/src/models/bpe/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct Config {
end_of_word_suffix: Option<String>,
fuse_unk: bool,
byte_fallback: bool,
ignore_merges: bool,
}

/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
Expand All @@ -49,6 +50,7 @@ impl Default for BpeBuilder {
end_of_word_suffix: None,
fuse_unk: false,
byte_fallback: false,
ignore_merges: false,
},
}
}
Expand Down Expand Up @@ -123,6 +125,12 @@ impl BpeBuilder {
self.config.byte_fallback = byte_fallback;
self
}
/// Set the `ignore_merges` option.
#[must_use]
pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
self.config.ignore_merges = ignore_merges;
self
}

/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(mut self) -> Result<BPE> {
Expand Down Expand Up @@ -190,6 +198,7 @@ impl BpeBuilder {
end_of_word_suffix: self.config.end_of_word_suffix,
fuse_unk: self.config.fuse_unk,
byte_fallback: self.config.byte_fallback,
ignore_merges: self.config.ignore_merges,
})
}
}
Expand Down Expand Up @@ -219,6 +228,8 @@ pub struct BPE {
/// Byte fallback from sentence pieces, instead of UNK, uses `"<0x00>"`
/// for each byte in the unk token
pub byte_fallback: bool,
/// Whether or not to direct output words if they are part of the vocab.
pub ignore_merges: bool,
}

impl std::fmt::Debug for BPE {
Expand All @@ -232,6 +243,7 @@ impl std::fmt::Debug for BPE {
.field("byte_fallback", &self.byte_fallback)
.field("vocab", &self.vocab.len())
.field("merges", &self.merges.len())
.field("ignore_merges", &self.ignore_merges)
.finish()
}
}
Expand All @@ -258,6 +270,7 @@ impl Clone for BPE {
end_of_word_suffix: self.end_of_word_suffix.clone(),
fuse_unk: self.fuse_unk,
byte_fallback: self.byte_fallback,
ignore_merges: self.ignore_merges,
}
}
}
Expand Down Expand Up @@ -449,6 +462,17 @@ impl BPE {
fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
Ok(self.word_to_tokens(hit).collect())
} else if let Some(id) = self.vocab.get(sequence) {
if self.ignore_merges {
Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))])
} else {
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
if let Some(ref cache) = self.cache {
cache.set(sequence.to_owned(), word);
}
Ok(ret)
ArthurZucker marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
let word = self.merge_word(sequence)?;
let ret = self.word_to_tokens(&word).collect();
Expand Down Expand Up @@ -862,4 +886,97 @@ mod tests {
let tokens = bpe.tokenize("\n").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
}

#[test]
fn test_ignore_merges() {
// 0x0A == '\n' in bytes
let vocab: Vocab = [
(".:.:".into(), 0),
("Ġbelirtilen".into(), 1),
(".".into(), 2),
(":".into(), 3),
("bel".into(), 4),
("irtilen".into(), 5),
("Ġ".into(), 6),
(".:".into(), 7),
("belirtilen".into(), 8),
(".:.".into(), 9),
("be".into(), 10),
("l".into(), 11),
("ir".into(), 12),
("ti".into(), 13),
("en".into(), 14),
("irtil".into(), 15),
("irti".into(), 16),
("i".into(), 17),
("r".into(), 18),
("t".into(), 19),
("b".into(), 20),
("e".into(), 21),
("n".into(), 22),
]
.iter()
.cloned()
.collect();
let mut bpe = BpeBuilder::default()
.vocab_and_merges(
vocab,
vec![
(".".into(), ":".into()),
("b".into(), "e".into()),
("be".into(), "l".into()),
("i".into(), "r".into()),
("t".into(), "i".into()),
("ir".into(), "ti".into()),
("e".into(), "n".into()),
("irti".into(), "l".into()),
],
)
.ignore_merges(true)
.build()
.unwrap();
let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]);

let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]);

bpe.ignore_merges = false;

let tokens = bpe.tokenize(".:.:").unwrap();
assert_eq!(
tokens,
vec![
Token::new(7u32, ".:".into(), (0, 2)),
Token::new(7u32, ".:".into(), (2, 4))
]
);

let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
assert_eq!(
tokens,
vec![
Token {
id: 6,
value: "Ġ".into(),
offsets: (0, 2)
},
Token {
id: 4,
value: "bel".into(),
offsets: (2, 5)
},
Token {
id: 15,
value: "irtil".into(),
offsets: (5, 10)
},
Token {
id: 14,
value: "en".into(),
offsets: (10, 12)
}
]
)
}
}
Loading