Skip to content

Commit

Permalink
Avoid calling byte_pair_encode for existing tokens
Browse files Browse the repository at this point in the history
This was byte_pair_encode can be optimized further, assuming we'll always have at least 2 tokens
  • Loading branch information
Lőrinc authored and hauntsaninja committed Feb 9, 2024
1 parent 6e4851a commit b4c687e
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ fn hash_current_thread() -> usize {
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can layout a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x = unsafe {
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0
};
u64::from(x) as usize
}
Expand Down Expand Up @@ -214,11 +214,10 @@ impl CoreBPE {
let mut ret = vec![];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
ret.push(*token);
continue;
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
ret.extend(&byte_pair_encode(piece, &self.encoder));
}
ret
}
Expand Down Expand Up @@ -516,7 +515,10 @@ impl CoreBPE {
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
match self.encoder.get(&unstable_bytes) {
Some(token) => tokens.push(*token),
None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)),
}
}
tokens
}
Expand Down Expand Up @@ -597,15 +599,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
mod tests {
use rustc_hash::FxHashMap as HashMap;

use crate::byte_pair_split;
use crate::{byte_pair_split, Rank};

#[test]
fn very_simple_test() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 1);
ranks.insert(b"cd".to_vec(), 2);
fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
HashMap::from_iter([
(b"ab".to_vec(), 0),
(b"cd".to_vec(), 1),
])
}

#[test]
fn test_simple_characters() {
let ranks = setup_ranks();
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}

#[test]
fn test_repeated_characters() {
let ranks = setup_ranks();
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}
}

0 comments on commit b4c687e

Please sign in to comment.