Allow default unigram unk token for GGUF (#363)
* Allow default unk token for gguf

* Clippy
EricLBuehler authored May 30, 2024
1 parent e9ee6ed commit 9f2937c
Showing 1 changed file with 13 additions and 8 deletions.

mistralrs-core/src/pipeline/gguf_tokenizer.rs
@@ -62,9 +62,10 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult>
         .collect::<Vec<_>>()
     });
 
-    let unk = content.metadata["tokenizer.ggml.unknown_token_id"]
-        .to_u32()
-        .expect("GGUF unk token is not u32");
+    let unk = content
+        .metadata
+        .get("tokenizer.ggml.unknown_token_id")
+        .map(|t| t.to_u32().expect("GGUF unk token is not u32"));
 
     let eos = content.metadata["tokenizer.ggml.eos_token_id"]
         .to_u32()
@@ -76,20 +77,24 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult>

     let bos_str = tokens[bos as usize].clone();
     let eos_str = tokens[eos as usize].clone();
-    let unk_str = tokens[unk as usize].clone();
+    let unk_str;
 
     let (tokenizer, ty) = match model.as_str() {
         "llama" | "replit" => {
-            // unigram
+            // This is a `unigram` tokenizer
             let scores = scores
                 .as_ref()
                 .expect("Expect `tokenizer.ggml.scores` for `llama` unigram tokenizer.");
             let mut vocab = Vec::new();
             for (token, score) in tokens.iter().zip(scores) {
                 vocab.push((token.clone(), *score as f64));
             }
-            let unigram =
-                Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?;
+
+            // Unigram (sentencepiece) default UNK is 0
+            let unk = unk.map(|x| x as usize).unwrap_or(0);
+            unk_str = tokens[unk].clone();
+
+            let unigram = Unigram::from(vocab, Some(unk), true).map_err(anyhow::Error::msg)?;
             let mut tokenizer = Tokenizer::new(ModelWrapper::Unigram(unigram));
             tokenizer.with_decoder(decoders::sequence::Sequence::new(vec![
                 DecoderWrapper::Replace(Replace::new("▁", " ").map_err(anyhow::Error::msg)?),
@@ -104,7 +109,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult>

             tokenizer.add_special_tokens(&[AddedToken::from(tokens[bos as usize].clone(), true)]);
             tokenizer.add_special_tokens(&[AddedToken::from(tokens[eos as usize].clone(), true)]);
-            tokenizer.add_special_tokens(&[AddedToken::from(tokens[unk as usize].clone(), true)]);
+            tokenizer.add_special_tokens(&[AddedToken::from(tokens[unk].clone(), true)]);
 
             (tokenizer, "unigram")
         }
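Downstream, the resolved id feeds `Unigram::from(vocab, unk_id, byte_fallback)` from the `tokenizers` crate. Here is a toy construction under the same defaulting, assuming `tokenizers` and `anyhow` as dependencies (the vocabulary, scores, and special-token strings are invented for illustration):

```rust
use tokenizers::models::unigram::Unigram;
use tokenizers::{AddedToken, ModelWrapper, Tokenizer};

fn main() -> anyhow::Result<()> {
    // Toy (token, log-probability) vocab with <unk> at id 0, matching the
    // sentencepiece convention the commit relies on.
    let vocab: Vec<(String, f64)> = vec![
        ("<unk>".into(), 0.0),
        ("<s>".into(), 0.0),
        ("</s>".into(), 0.0),
        ("▁hello".into(), -1.0),
        ("▁world".into(), -1.5),
    ];

    // Pretend the GGUF metadata carried no unk id: default to 0.
    let unk: Option<u32> = None;
    let unk_id = unk.map(|x| x as usize).unwrap_or(0);

    let unigram = Unigram::from(vocab, Some(unk_id), true).map_err(anyhow::Error::msg)?;
    let mut tokenizer = Tokenizer::new(ModelWrapper::Unigram(unigram));

    // Register specials by string, as the patched hunk does with
    // tokens[bos], tokens[eos], and tokens[unk].
    tokenizer.add_special_tokens(&[AddedToken::from("<s>", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("</s>", true)]);
    tokenizer.add_special_tokens(&[AddedToken::from("<unk>", true)]);

    let encoding = tokenizer
        .encode("▁hello▁world", false)
        .map_err(anyhow::Error::msg)?;
    println!("{:?}", encoding.get_ids()); // expected: [3, 4]
    Ok(())
}
```

Defaulting to id 0 rather than erroring matches the SentencePiece layout in which `<unk>` occupies the first vocabulary slot, so GGUF files that drop the key still load.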
