Skip to content

Commit

Permalink
Refactor: Move C string utility functions to utils module
Browse files Browse the repository at this point in the history
  • Loading branch information
kojix2 committed Oct 14, 2024
1 parent 97dd8fd commit dce7e54
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
40 changes: 4 additions & 36 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,14 @@ use std::os::raw::c_char;
use tiktoken_rs;
use tiktoken_rs::CoreBPE;

mod utils;
use utils::c_str_to_string;

#[no_mangle]
pub extern "C" fn tiktoken_init_logger() {
SimpleLogger::new().init().unwrap();
}

fn get_string_from_c_char(ptr: *const c_char) -> Result<String, std::str::Utf8Error> {
let c_str = unsafe { CStr::from_ptr(ptr) };
let str_slice = c_str.to_str()?;
Ok(str_slice.to_string())
}

fn c_str_to_string(ptr: *const c_char) -> Option<String> {
if ptr.is_null() {
return None;
}

let c_str = match get_string_from_c_char(ptr) {
Ok(str) => str,
Err(_) => {
warn!("Invalid UTF-8 sequence provided!");
return None;
}
};

Some(c_str)
}

#[no_mangle]
pub extern "C" fn tiktoken_r50k_base() -> *mut CoreBPE {
let bpe = tiktoken_rs::r50k_base();
Expand Down Expand Up @@ -470,6 +451,7 @@ pub extern "C" fn tiktoken_c_version() -> *const c_char {
mod tests {
use super::*;
use std::ffi::CString;
use utils::get_string_from_c_char;

#[test]
fn test_tiktoken_c_version() {
Expand All @@ -478,20 +460,6 @@ mod tests {
assert_eq!(version, env!("CARGO_PKG_VERSION"));
}

#[test]
fn test_get_string_from_c_char() {
let c_str = CString::new("I am a cat.").unwrap();
let str = get_string_from_c_char(c_str.as_ptr()).unwrap();
assert_eq!(str, "I am a cat.");
}

#[test]
fn test_c_str_to_string() {
let c_str = CString::new("I am a cat.").unwrap();
let str = c_str_to_string(c_str.as_ptr()).unwrap();
assert_eq!(str, "I am a cat.");
}

#[test]
fn test_c50k_base() {
let corebpe = tiktoken_r50k_base();
Expand Down
45 changes: 45 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use log::warn;
use std::ffi::CStr;
use std::os::raw::c_char;

pub fn get_string_from_c_char(ptr: *const c_char) -> Result<String, std::str::Utf8Error> {
let c_str = unsafe { CStr::from_ptr(ptr) };
let str_slice = c_str.to_str()?;
Ok(str_slice.to_string())
}

pub fn c_str_to_string(ptr: *const c_char) -> Option<String> {
if ptr.is_null() {
return None;
}

let c_str = match get_string_from_c_char(ptr) {
Ok(str) => str,
Err(_) => {
warn!("Invalid UTF-8 sequence provided!");
return None;
}
};

Some(c_str)
}

#[cfg(test)]
mod tests {
use super::*;
use std::ffi::CString;

#[test]
fn test_get_string_from_c_char() {
let c_str = CString::new("I am a cat.").unwrap();
let str = get_string_from_c_char(c_str.as_ptr()).unwrap();
assert_eq!(str, "I am a cat.");
}

#[test]
fn test_c_str_to_string() {
let c_str = CString::new("I am a cat.").unwrap();
let str = c_str_to_string(c_str.as_ptr()).unwrap();
assert_eq!(str, "I am a cat.");
}
}

0 comments on commit dce7e54

Please sign in to comment.