Skip to content

Commit

Permalink
Refactor: Move BPE functions to a separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
kojix2 committed Oct 14, 2024
1 parent dce7e54 commit 1f4f47b
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 135 deletions.
142 changes: 142 additions & 0 deletions src/corebpe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use log::warn;
use std::ffi::{c_char, CStr};
use tiktoken_rs::CoreBPE;

// get_bpe_from_tokenizer is not yet implemented.
// Use tiktoken_r50k_base(), tiktoken_p50k_base(), tiktoken_p50k_edit(), tiktoken_cl100k_base(), and tiktoken_o200k_base()
// instead.

#[no_mangle]
pub extern "C" fn tiktoken_r50k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::r50k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_p50k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::p50k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_p50k_edit() -> *mut CoreBPE {
let bpe = tiktoken_rs::p50k_edit();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_cl100k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::cl100k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_o200k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::o200k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

/// Frees a `CoreBPE` that was handed out as a raw pointer.
///
/// Passing a null pointer is a no-op.
#[no_mangle]
pub extern "C" fn tiktoken_destroy_corebpe(ptr: *mut CoreBPE) {
    if !ptr.is_null() {
        // SAFETY: assumes `ptr` originated from `Box::into_raw` and has not
        // been freed already; the caller is responsible for upholding this.
        drop(unsafe { Box::from_raw(ptr) });
    }
}

/// Resolves the tokenizer for the model named by the C string `model`.
///
/// Returns an owning raw pointer on success. Returns null, after logging a
/// warning, when `model` is null, is not valid UTF-8, or when
/// `tiktoken_rs::get_bpe_from_model` reports an error.
#[no_mangle]
pub extern "C" fn tiktoken_get_bpe_from_model(model: *const c_char) -> *mut CoreBPE {
    if model.is_null() {
        warn!("Null pointer provided for model!");
        return std::ptr::null_mut();
    }

    // SAFETY: `model` is non-null (checked above); the caller must pass a
    // NUL-terminated string that stays valid for the duration of this call.
    let c_name = unsafe { CStr::from_ptr(model) };
    let name = match c_name.to_str() {
        Ok(name) => name,
        Err(_) => {
            warn!("Invalid UTF-8 sequence provided for model!");
            return std::ptr::null_mut();
        }
    };

    tiktoken_rs::get_bpe_from_model(name)
        .map(|bpe| Box::into_raw(Box::new(bpe)))
        .unwrap_or_else(|_| {
            warn!("Failed to get BPE from model!");
            std::ptr::null_mut()
        })
}

/// Smoke tests for the exported constructors: each one must return a non-null
/// handle that can then be released with `tiktoken_destroy_corebpe`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    // Previously named `test_c50k_base`, although it exercises `r50k_base`.
    #[test]
    fn test_r50k_base() {
        let corebpe = tiktoken_r50k_base();
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    #[test]
    fn test_p50k_base() {
        let corebpe = tiktoken_p50k_base();
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    #[test]
    fn test_p50k_edit() {
        let corebpe = tiktoken_p50k_edit();
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    #[test]
    fn test_cl100k_base() {
        let corebpe = tiktoken_cl100k_base();
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    #[test]
    fn test_o200k_base() {
        let corebpe = tiktoken_o200k_base();
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    // A known model name must resolve to a live tokenizer handle.
    #[test]
    fn test_get_bpe_from_model() {
        let model = CString::new("gpt-4").unwrap();
        let corebpe = tiktoken_get_bpe_from_model(model.as_ptr());
        assert!(!corebpe.is_null());
        tiktoken_destroy_corebpe(corebpe);
    }

    // An unknown model name must yield null, so there is nothing to free.
    #[test]
    fn test_get_bpe_from_model_invalid_model() {
        let model = CString::new("cat-gpt").unwrap();
        let corebpe = tiktoken_get_bpe_from_model(model.as_ptr());
        assert!(corebpe.is_null());
    }
}
138 changes: 3 additions & 135 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use log::warn;
use simple_logger::SimpleLogger;
use std::ffi::CStr;
use std::os::raw::c_char;
use std::ffi::{c_char, CStr};
use tiktoken_rs;
use tiktoken_rs::CoreBPE;

mod corebpe;
mod utils;
use utils::c_str_to_string;

Expand All @@ -13,89 +13,6 @@ pub extern "C" fn tiktoken_init_logger() {
SimpleLogger::new().init().unwrap();
}

#[no_mangle]
pub extern "C" fn tiktoken_r50k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::r50k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_p50k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::p50k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_p50k_edit() -> *mut CoreBPE {
let bpe = tiktoken_rs::p50k_edit();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_cl100k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::cl100k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

#[no_mangle]
pub extern "C" fn tiktoken_o200k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::o200k_base();
let corebpe = bpe.unwrap();
let boxed = Box::new(corebpe);
Box::into_raw(boxed)
}

/// Frees a `CoreBPE` that was handed out as a raw pointer.
///
/// Passing a null pointer is a no-op.
#[no_mangle]
pub extern "C" fn tiktoken_destroy_corebpe(ptr: *mut CoreBPE) {
    if !ptr.is_null() {
        // SAFETY: assumes `ptr` originated from `Box::into_raw` and has not
        // been freed already; the caller is responsible for upholding this.
        drop(unsafe { Box::from_raw(ptr) });
    }
}

// get_bpe_from_tokenizer is not yet implemented.
// Use tiktoken_r50k_base(), tiktoken_p50k_base(), tiktoken_p50k_edit(), tiktoken_cl100k_base(), and tiktoken_o200k_base()
// instead.

/// Resolves the tokenizer for the model named by the C string `model`.
///
/// Returns an owning raw pointer on success. Returns null, after logging a
/// warning, when `model` is null, is not valid UTF-8, or when
/// `tiktoken_rs::get_bpe_from_model` reports an error.
#[no_mangle]
pub extern "C" fn tiktoken_get_bpe_from_model(model: *const c_char) -> *mut CoreBPE {
    if model.is_null() {
        warn!("Null pointer provided for model!");
        return std::ptr::null_mut();
    }

    // SAFETY: `model` is non-null (checked above); the caller must pass a
    // NUL-terminated string that stays valid for the duration of this call.
    let c_name = unsafe { CStr::from_ptr(model) };
    let name = match c_name.to_str() {
        Ok(name) => name,
        Err(_) => {
            warn!("Invalid UTF-8 sequence provided for model!");
            return std::ptr::null_mut();
        }
    };

    tiktoken_rs::get_bpe_from_model(name)
        .map(|bpe| Box::into_raw(Box::new(bpe)))
        .unwrap_or_else(|_| {
            warn!("Failed to get BPE from model!");
            std::ptr::null_mut()
        })
}

#[no_mangle]
pub extern "C" fn tiktoken_get_completion_max_tokens(
model: *const c_char,
Expand Down Expand Up @@ -450,6 +367,7 @@ pub extern "C" fn tiktoken_c_version() -> *const c_char {
#[cfg(test)]
mod tests {
use super::*;
use corebpe::{tiktoken_destroy_corebpe, tiktoken_get_bpe_from_model};
use std::ffi::CString;
use utils::get_string_from_c_char;

Expand All @@ -460,56 +378,6 @@ mod tests {
assert_eq!(version, env!("CARGO_PKG_VERSION"));
}

// Renamed from `test_c50k_base`: the encoding under test is `r50k_base`.
#[test]
fn test_r50k_base() {
    let corebpe = tiktoken_r50k_base();
    assert!(!corebpe.is_null());
    tiktoken_destroy_corebpe(corebpe);
}

// The `p50k_base` constructor must return a live handle, which is then released.
#[test]
fn test_p50k_base() {
    let handle = tiktoken_p50k_base();
    assert!(!handle.is_null());
    tiktoken_destroy_corebpe(handle);
}

// The `p50k_edit` constructor must return a live handle, which is then released.
#[test]
fn test_p50k_edit() {
    let handle = tiktoken_p50k_edit();
    assert!(!handle.is_null());
    tiktoken_destroy_corebpe(handle);
}

// The `cl100k_base` constructor must return a live handle, which is then released.
#[test]
fn test_cl100k_base() {
    let handle = tiktoken_cl100k_base();
    assert!(!handle.is_null());
    tiktoken_destroy_corebpe(handle);
}

// The `o200k_base` constructor must return a live handle, which is then released.
#[test]
fn test_o200k_base() {
    let handle = tiktoken_o200k_base();
    assert!(!handle.is_null());
    tiktoken_destroy_corebpe(handle);
}

// A known model name must resolve to a live tokenizer handle.
#[test]
fn test_get_bpe_from_model() {
    let known = CString::new("gpt-4").unwrap();
    let handle = tiktoken_get_bpe_from_model(known.as_ptr());
    assert!(!handle.is_null());
    tiktoken_destroy_corebpe(handle);
}

// An unknown model name must yield a null handle, so there is nothing to free.
#[test]
fn test_get_bpe_from_model_invalid_model() {
    let unknown = CString::new("cat-gpt").unwrap();
    let handle = tiktoken_get_bpe_from_model(unknown.as_ptr());
    assert!(handle.is_null());
}

#[test]
fn test_get_completion_max_tokens() {
let model = CString::new("gpt-4").unwrap();
Expand Down

0 comments on commit 1f4f47b

Please sign in to comment.