Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved interface for split_by_token #18

Merged
merged 1 commit into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions tiktoken-rs/src/vendor_tiktoken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,7 @@ impl CoreBPE {
/// use tiktoken_rs::cl100k_base;
/// let bpe = cl100k_base().unwrap();
/// let tokenized: Result<Vec<_>, _> = bpe
/// .split_by_token_with_special_tokens("This is a test with a lot of spaces")
/// .collect();
/// .split_by_token("This is a test with a lot of spaces", true);
/// let tokenized = tokenized.unwrap();
/// assert_eq!(
/// tokenized,
Expand All @@ -580,6 +579,7 @@ impl CoreBPE {
/// # Arguments
///
/// * text: A string slice containing the text to be tokenized.
/// * use_special_tokens: A boolean indicating whether to use the special tokens in the BPE model.
///
/// # Returns
///
Expand All @@ -592,18 +592,47 @@ impl CoreBPE {
///
/// * The input text cannot be converted into a valid UTF-8 string during the decoding process.
///
pub fn split_by_token_with_special_tokens<'a>(
pub fn split_by_token<'a>(
    &'a self,
    text: &'a str,
    use_special_tokens: bool,
) -> Result<Vec<String>> {
    // Build the per-token iterator, then gather it into a vector. Collecting
    // into a `Result` yields the first decoding error instead of a vector if
    // any token fails to convert to a `String`.
    let pieces = self.split_by_token_iter(text, use_special_tokens);
    pieces.collect()
}

/// Iterator for decoding and splitting a String.
/// See `split_by_token` for more details.
///
/// # Arguments
///
/// * text: A string slice containing the text to be tokenized.
/// * use_special_tokens: When `true` the text is encoded with
///   `encode_with_special_tokens`; when `false` it is encoded with
///   `encode_ordinary`.
///
/// # Returns
///
/// An iterator yielding one `Result<String>` per item produced by
/// `_decode_native_and_split`: `Ok` with the decoded text, or `Err` when that
/// item's bytes are not valid UTF-8.
pub fn split_by_token_iter<'a>(
    &'a self,
    text: &'a str,
    use_special_tokens: bool,
) -> impl Iterator<Item = Result<String>> + 'a {
    // Encode the text exactly once, choosing the encoder from the flag.
    // (Previously the text was also encoded unconditionally beforehand and
    // that result was immediately shadowed, i.e. computed and thrown away.)
    let encoded = if use_special_tokens {
        self.encode_with_special_tokens(text)
    } else {
        self.encode_ordinary(text)
    };

    // Convert each token's raw bytes into a `String`, turning a UTF-8
    // failure into an `anyhow` error that carries the original message.
    self._decode_native_and_split(encoded)
        .map(|token| String::from_utf8(token).map_err(|e| anyhow!(e.to_string())))
}

/// Tokenize a string and return the decoded tokens, with
/// `use_special_tokens` set to `false`.
///
/// Shorthand for `split_by_token(text, false)`.
pub fn split_by_token_ordinary<'a>(&'a self, text: &'a str) -> Result<Vec<String>> {
    let use_special_tokens = false;
    self.split_by_token(text, use_special_tokens)
}

/// Iterator form of `split_by_token_ordinary`.
///
/// Shorthand for `split_by_token_iter(text, false)`.
pub fn split_by_token_ordinary_iter<'a>(
    &'a self,
    text: &'a str,
) -> impl Iterator<Item = Result<String>> + 'a {
    let use_special_tokens = false;
    self.split_by_token_iter(text, use_special_tokens)
}
}

#[cfg(feature = "python")]
Expand Down
2 changes: 1 addition & 1 deletion tiktoken-rs/tests/tiktoken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn cl100k_base_test() {
fn cl100k_split_test() {
let bpe = cl100k_base().unwrap();
let tokenized: Result<Vec<_>, _> = bpe
.split_by_token_with_special_tokens("This is a test with a lot of spaces")
.split_by_token_iter("This is a test with a lot of spaces", true)
.collect();
let tokenized = tokenized.unwrap();
assert_eq!(
Expand Down