Support references to tokenizers (#37)
benbrandt authored Aug 10, 2023
1 parent 1b36cac commit 2af8bd8
Showing 5 changed files with 51 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## v0.4.3
+
+### What's New
+
+- Support `impl ChunkSizer` for `&Tokenizer` and `&CoreBPE`, allowing for generating chunks based off of a reference to a tokenizer as well, instead of requiring ownership.
+
 ## v0.4.2
 
 ### What's New
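
The changelog entry above, sketched as usage: a `TextSplitter` can now borrow a Hugging Face `Tokenizer` rather than taking ownership of it. This is a minimal sketch, not part of the commit; the tokenizer file name, sample text, and chunk size are assumptions for illustration.

use text_splitter::TextSplitter;
use tokenizers::Tokenizer;

fn main() {
    // Hypothetical tokenizer file; any byte-level Hugging Face tokenizer works.
    let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();

    // New in v0.4.3: pass a reference instead of moving the tokenizer.
    let splitter = TextSplitter::new(&tokenizer);
    let text = "Some document text to split.";
    let chunks = splitter.chunks(text, 100).collect::<Vec<_>>();
    assert_eq!(chunks.join(""), text);

    // The tokenizer is still usable afterwards because it was only borrowed.
    let _encoding = tokenizer.encode("still owned here", false).unwrap();
}
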
8 changes: 4 additions & 4 deletions Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "text-splitter"
-version = "0.4.2"
+version = "0.4.3"
 authors = ["Ben Brandt <benjamin.j.brandt@gmail.com>"]
 edition = "2021"
 description = "Split text into semantic chunks, up to a desired chunk size. Supports calculating length by characters and tokens (when used with large language models)."
@@ -18,11 +18,11 @@ rustdoc-args = ["--cfg", "docsrs"]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-auto_enums = "0.8.1"
-either = "1.8.1"
+auto_enums = "0.8.2"
+either = "1.9.0"
 itertools = "0.11.0"
 once_cell = "1.18.0"
-regex = "1.9.1"
+regex = "1.9.3"
 tiktoken-rs = { version = ">=0.2.0, <0.6.0", optional = true }
 tokenizers = { version = ">=0.13.3, <0.14.0", default_features = false, features = [
     "onig",
23 changes: 20 additions & 3 deletions src/huggingface.rs
Expand Up @@ -10,8 +10,25 @@ impl ChunkSizer for Tokenizer {
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, chunk: &str) -> usize {
self.encode(chunk, false)
.map(|enc| enc.len())
.expect("Unable to tokenize the following string {str}")
chunk_size(self, chunk)
}
}

impl ChunkSizer for &Tokenizer {
/// Returns the number of tokens in a given text after tokenization.
///
/// # Panics
///
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, chunk: &str) -> usize {
chunk_size(self, chunk)
}
}

fn chunk_size(tokenizer: &Tokenizer, chunk: &str) -> usize {
tokenizer
.encode(chunk, false)
.map(|enc| enc.len())
.expect("Unable to tokenize the following string {str}")
}
18 changes: 17 additions & 1 deletion src/tiktoken.rs
Expand Up @@ -10,6 +10,22 @@ impl ChunkSizer for CoreBPE {
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, text: &str) -> usize {
self.encode_ordinary(text).len()
chunk_size(self, text)
}
}

impl ChunkSizer for &CoreBPE {
/// Returns the number of tokens in a given text after tokenization.
///
/// # Panics
///
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, text: &str) -> usize {
chunk_size(self, text)
}
}

fn chunk_size(bpe: &CoreBPE, text: &str) -> usize {
bpe.encode_ordinary(text).len()
}
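
The tiktoken implementation mirrors the Hugging Face one. A minimal sketch of the equivalent usage with a borrowed `CoreBPE`; the `cl100k_base` encoding, sample text, and chunk size are assumptions for illustration, not part of the commit.

use text_splitter::TextSplitter;
use tiktoken_rs::cl100k_base;

fn main() {
    // cl100k_base is just one example of a CoreBPE; any encoding works the same way.
    let bpe = cl100k_base().unwrap();

    // Borrow the CoreBPE rather than moving it into the splitter.
    let splitter = TextSplitter::new(&bpe).with_trim_chunks(true);
    let text = "Some document text to split.";
    let chunks: Vec<_> = splitter.chunks(text, 50).collect();

    // `bpe` remains available for other token counting after splitting.
    let total_tokens = bpe.encode_ordinary(text).len();
    println!("{} chunks, {} tokens total", chunks.len(), total_tokens);
}
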
8 changes: 4 additions & 4 deletions tests/text_splitter_snapshots.rs
Expand Up @@ -85,7 +85,7 @@ fn huggingface_default() {
let text = fs::read_to_string(path).unwrap();

for chunk_size in [10, 100, 1000] {
let splitter = TextSplitter::new(HUGGINGFACE_TOKENIZER.clone());
let splitter = TextSplitter::new(&*HUGGINGFACE_TOKENIZER);
let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();

assert_eq!(chunks.join(""), text);
Expand All @@ -103,7 +103,7 @@ fn huggingface_trim() {
let text = fs::read_to_string(path).unwrap();

for chunk_size in [10, 100, 1000] {
let splitter = TextSplitter::new(HUGGINGFACE_TOKENIZER.clone()).with_trim_chunks(true);
let splitter = TextSplitter::new(&*HUGGINGFACE_TOKENIZER).with_trim_chunks(true);
let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();

for chunk in chunks.iter() {
Expand All @@ -122,7 +122,7 @@ fn tiktoken_default() {
let text = fs::read_to_string(path).unwrap();

for chunk_size in [10, 100, 1000] {
let splitter = TextSplitter::new(TIKTOKEN_TOKENIZER.clone());
let splitter = TextSplitter::new(&*TIKTOKEN_TOKENIZER);
let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();

assert_eq!(chunks.join(""), text);
Expand All @@ -140,7 +140,7 @@ fn tiktoken_trim() {
let text = fs::read_to_string(path).unwrap();

for chunk_size in [10, 100, 1000] {
let splitter = TextSplitter::new(TIKTOKEN_TOKENIZER.clone()).with_trim_chunks(true);
let splitter = TextSplitter::new(&*TIKTOKEN_TOKENIZER).with_trim_chunks(true);
let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();

for chunk in chunks.iter() {
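
The test changes above replace `.clone()` with `&*HUGGINGFACE_TOKENIZER` and `&*TIKTOKEN_TOKENIZER`: dereferencing the `once_cell::sync::Lazy` statics yields plain `&Tokenizer` and `&CoreBPE` references, which now implement `ChunkSizer`, so the shared tokenizers no longer need to be cloned for each chunk size. A minimal sketch of that pattern, assuming a hypothetical `tokenizer.json` file:

use once_cell::sync::Lazy;
use text_splitter::TextSplitter;
use tokenizers::Tokenizer;

// Shared tokenizer, initialized once and reused across tests.
static TOKENIZER: Lazy<Tokenizer> =
    Lazy::new(|| Tokenizer::from_file("tokenizer.json").unwrap());

#[test]
fn reuses_shared_tokenizer() {
    let text = "Some document text to split.";
    for chunk_size in [10, 100] {
        // `&*TOKENIZER` derefs the Lazy to a plain `&Tokenizer`; no clone needed.
        let splitter = TextSplitter::new(&*TOKENIZER);
        let chunks = splitter.chunks(text, chunk_size).collect::<Vec<_>>();
        assert_eq!(chunks.join(""), text);
    }
}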
