From cda94cd13b6d252cc0876a58cfc7356518a0ced0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 12 Sep 2024 21:09:11 +0100 Subject: [PATCH] adding Batson --- Cargo.toml | 5 + crates/batson/Cargo.toml | 42 +++ crates/batson/README.md | 16 ++ crates/batson/benches/main.rs | 213 +++++++++++++++ crates/batson/examples/read_file.rs | 53 ++++ crates/batson/src/array.rs | 389 ++++++++++++++++++++++++++++ crates/batson/src/decoder.rs | 219 ++++++++++++++++ crates/batson/src/encoder.rs | 184 +++++++++++++ crates/batson/src/errors.rs | 92 +++++++ crates/batson/src/get.rs | 272 +++++++++++++++++++ crates/batson/src/header.rs | 352 +++++++++++++++++++++++++ crates/batson/src/json_writer.rs | 118 +++++++++ crates/batson/src/lib.rs | 76 ++++++ crates/batson/src/object.rs | 300 +++++++++++++++++++++ crates/batson/tests/main.rs | 238 +++++++++++++++++ crates/jiter-python/Cargo.toml | 2 +- crates/jiter/Cargo.toml | 6 +- 17 files changed, 2573 insertions(+), 4 deletions(-) create mode 100644 crates/batson/Cargo.toml create mode 100644 crates/batson/README.md create mode 100644 crates/batson/benches/main.rs create mode 100644 crates/batson/examples/read_file.rs create mode 100644 crates/batson/src/array.rs create mode 100644 crates/batson/src/decoder.rs create mode 100644 crates/batson/src/encoder.rs create mode 100644 crates/batson/src/errors.rs create mode 100644 crates/batson/src/get.rs create mode 100644 crates/batson/src/header.rs create mode 100644 crates/batson/src/json_writer.rs create mode 100644 crates/batson/src/lib.rs create mode 100644 crates/batson/src/object.rs create mode 100644 crates/batson/tests/main.rs diff --git a/Cargo.toml b/Cargo.toml index b3ac383..fe6afc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "crates/jiter", "crates/jiter-python", + "crates/batson", "crates/fuzz", ] resolver = "2" @@ -28,5 +29,9 @@ inherits = "release" debug = true [workspace.dependencies] +jiter = { path = "crates/jiter", version = "0.5.0" } pyo3 = { version = "0.22.0" } pyo3-build-config = { version = "0.22.0" } +bencher = "0.1.5" +paste = "1.0.7" +codspeed-bencher-compat = "2.7.1" diff --git a/crates/batson/Cargo.toml b/crates/batson/Cargo.toml new file mode 100644 index 0000000..719de34 --- /dev/null +++ b/crates/batson/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "batson" +description = "Binary Alternative To (J)SON. Designed to be very fast to query." 
+readme = "../../README.md" +version = {workspace = true} +edition = {workspace = true} +authors = {workspace = true} +license = {workspace = true} +keywords = {workspace = true} +categories = {workspace = true} +homepage = {workspace = true} +repository = {workspace = true} + +[dependencies] +bytemuck = { version = "1.17.1", features = ["aarch64_simd", "derive", "align_offset"] } +jiter = { workspace = true } +serde = "1.0.210" +serde_json = "1.0.128" +simdutf8 = { version = "0.1.4", features = ["aarch64_neon"] } +smallvec = "2.0.0-alpha.7" + +[dev-dependencies] +bencher = { workspace = true } +paste = { workspace = true } +codspeed-bencher-compat = { workspace = true } + +[[bench]] +name = "main" +harness = false + +[lints.clippy] +dbg_macro = "deny" +print_stdout = "deny" +print_stderr = "deny" +# in general we lint against the pedantic group, but we will whitelist +# certain lints which we don't want to enforce (for now) +pedantic = { level = "deny", priority = -1 } +missing_errors_doc = "allow" +cast_possible_truncation = "allow" # TODO remove +cast_sign_loss = "allow" # TODO remove +cast_possible_wrap = "allow" # TODO remove +checked_conversions = "allow" # TODO remove diff --git a/crates/batson/README.md b/crates/batson/README.md new file mode 100644 index 0000000..92bdd24 --- /dev/null +++ b/crates/batson/README.md @@ -0,0 +1,16 @@ +# batson + +Binary Alternative To (J)SON. Designed to be very fast to query. + +Inspired by Postgres' [JSONB type](https://github.com/postgres/postgres/commit/d9134d0a355cfa447adc80db4505d5931084278a?diff=unified&w=0) and Snowflake's [VARIANT type](https://www.youtube.com/watch?v=jtjOfggD4YY). + +For a relatively small JSON document (3KB), batson is 14 to 126x faster than Jiter, and 106 to 588x faster than Serde. + +``` +test medium_get_str_found_batson ... bench: 51 ns/iter (+/- 1) +test medium_get_str_found_jiter ... bench: 755 ns/iter (+/- 66) +test medium_get_str_found_serde ... bench: 5,420 ns/iter (+/- 93) +test medium_get_str_missing_batson ... bench: 9 ns/iter (+/- 0) +test medium_get_str_missing_jiter ... bench: 1,135 ns/iter (+/- 46) +test medium_get_str_missing_serde ... 
bench: 5,292 ns/iter (+/- 324) +``` diff --git a/crates/batson/benches/main.rs b/crates/batson/benches/main.rs new file mode 100644 index 0000000..e280f0b --- /dev/null +++ b/crates/batson/benches/main.rs @@ -0,0 +1,213 @@ +use codspeed_bencher_compat::{benchmark_group, benchmark_main, Bencher}; +use std::hint::black_box; + +use std::fs::File; +use std::io::Read; + +use batson::get::{get_str, BatsonPath}; +use batson::{batson_to_json_string, encode_from_json}; +use jiter::JsonValue; + +fn read_file(path: &str) -> String { + let mut file = File::open(path).unwrap(); + let mut contents = String::new(); + file.read_to_string(&mut contents).unwrap(); + contents +} + +/// taken from +mod jiter_find { + use jiter::{Jiter, Peek}; + + #[derive(Debug)] + pub enum JsonPath<'s> { + Key(&'s str), + Index(usize), + None, + } + + impl From for JsonPath<'_> { + fn from(index: u64) -> Self { + JsonPath::Index(usize::try_from(index).unwrap()) + } + } + + impl From for JsonPath<'_> { + fn from(index: i32) -> Self { + match usize::try_from(index) { + Ok(i) => Self::Index(i), + Err(_) => Self::None, + } + } + } + + impl<'s> From<&'s str> for JsonPath<'s> { + fn from(key: &'s str) -> Self { + JsonPath::Key(key) + } + } + + pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { + let json_str = opt_json?; + let mut jiter = Jiter::new(json_str.as_bytes()); + let mut peek = jiter.peek().ok()?; + for element in path { + match element { + JsonPath::Key(key) if peek == Peek::Object => { + let mut next_key = jiter.known_object().ok()??; + + while next_key != *key { + jiter.next_skip().ok()?; + next_key = jiter.next_key().ok()??; + } + + peek = jiter.peek().ok()?; + } + JsonPath::Index(index) if peek == Peek::Array => { + let mut array_item = jiter.known_array().ok()??; + + for _ in 0..*index { + jiter.known_skip(array_item).ok()?; + array_item = jiter.array_step().ok()??; + } + + peek = array_item; + } + _ => { + return None; + } + } + } + Some((jiter, peek)) + } + + pub fn get_str(json_data: Option<&str>, path: &[JsonPath]) -> Option { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { + match peek { + Peek::String => Some(jiter.known_str().ok()?.to_owned()), + _ => None, + } + } else { + None + } + } +} + +mod serde_find { + use batson::get::BatsonPath; + use serde_json::Value; + + pub fn get_str(json_data: &[u8], path: &[BatsonPath]) -> Option { + let json_value: Value = serde_json::from_slice(json_data).ok()?; + let mut current = &json_value; + for key in path { + current = match (key, current) { + (BatsonPath::Key(k), Value::Object(map)) => map.get(*k)?, + (BatsonPath::Index(i), Value::Array(vec)) => vec.get(*i)?, + _ => return None, + } + } + match current { + Value::String(s) => Some(s.clone()), + _ => None, + } + } +} + +fn json_to_batson(json: &[u8]) -> Vec { + let json_value = JsonValue::parse(json, false).unwrap(); + encode_from_json(&json_value).unwrap() +} + +fn medium_get_str_found_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = get_str(black_box(&batson_data), &path); + black_box(v) + }); +} + +fn medium_get_str_found_jiter(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + 
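// jiter has no pre-built index, so each iteration re-parses the JSON text and walks the path from scratch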
let v = jiter_find::get_str(black_box(Some(&json)), &path); + black_box(v) + }); +} + +fn medium_get_str_found_serde(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = serde_find::get_str(black_box(json_data), &path).unwrap(); + black_box(v) + }); +} + +fn medium_get_str_missing_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = get_str(black_box(&batson_data), &path); + black_box(v) + }); +} + +fn medium_get_str_missing_jiter(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = jiter_find::get_str(black_box(Some(&json)), &path); + black_box(v) + }); +} + +fn medium_get_str_missing_serde(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = serde_find::get_str(black_box(json_data), &path); + black_box(v) + }); +} + +fn medium_convert_batson_to_json(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + bench.iter(|| { + let v = batson_to_json_string(black_box(&batson_data)).unwrap(); + black_box(v) + }); +} + +fn medium_convert_json_to_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json = json.as_bytes(); + bench.iter(|| { + let json_value = JsonValue::parse(json, false).unwrap(); + let b = encode_from_json(&json_value).unwrap(); + black_box(b) + }); +} + +benchmark_group!( + benches, + medium_get_str_found_batson, + medium_get_str_found_jiter, + medium_get_str_found_serde, + medium_get_str_missing_batson, + medium_get_str_missing_jiter, + medium_get_str_missing_serde, + medium_convert_batson_to_json, + medium_convert_json_to_batson +); +benchmark_main!(benches); diff --git a/crates/batson/examples/read_file.rs b/crates/batson/examples/read_file.rs new file mode 100644 index 0000000..2aa6891 --- /dev/null +++ b/crates/batson/examples/read_file.rs @@ -0,0 +1,53 @@ +use batson::get::BatsonPath; +use batson::{batson_to_json_string, encode_from_json}; +use jiter::JsonValue; +use std::fs::File; +use std::io::Read; + +fn main() { + let filename = std::env::args().nth(1).expect( + r#" +No arguments provided! 
+ +Usage: +cargo run --example read_file file.json [path] +"#, + ); + + let mut file = File::open(&filename).expect("failed to open file"); + let mut json = Vec::new(); + file.read_to_end(&mut json).expect("failed to read file"); + + let json_value = JsonValue::parse(&json, false).expect("invalid JSON"); + let batson = encode_from_json(&json_value).expect("failed to construct batson data"); + println!("json length: {}", json.len()); + println!("batson length: {}", batson.len()); + + let output_json = batson_to_json_string(&batson).expect("failed to convert batson to JSON"); + println!("output json length: {}", output_json.len()); + + if let Some(path) = std::env::args().nth(2) { + let path: Vec = path.split('.').map(to_batson_path).collect(); + let start = std::time::Instant::now(); + let value = batson::get::get_str(&batson, &path).expect("failed to get value"); + let elapsed = start.elapsed(); + println!("Found value: {value:?} (time taken: {elapsed:?})"); + } + + println!("reloading to check round-trip"); + let json_value = JsonValue::parse(output_json.as_bytes(), false).expect("invalid JSON"); + let batson = encode_from_json(&json_value).expect("failed to construct batson data"); + let output_json2 = batson_to_json_string(&batson).expect("failed to convert batson to JSON"); + println!("JSON unchanged after re-encoding: {:?}", output_json == output_json2); + + println!("\n\noutput json:\n{}", output_json); +} + +fn to_batson_path(s: &str) -> BatsonPath { + if s.chars().all(char::is_numeric) { + let index: usize = s.parse().unwrap(); + index.into() + } else { + s.into() + } +} diff --git a/crates/batson/src/array.rs b/crates/batson/src/array.rs new file mode 100644 index 0000000..b29d028 --- /dev/null +++ b/crates/batson/src/array.rs @@ -0,0 +1,389 @@ +use std::sync::Arc; + +use jiter::{JsonArray, JsonValue}; +use smallvec::SmallVec; + +use crate::decoder::Decoder; +use crate::encoder::Encoder; +use crate::errors::{DecodeResult, EncodeResult, ToJsonResult}; +use crate::header::{Category, Header, Length, NumberHint, Primitive}; +use crate::json_writer::JsonWriter; + +#[cfg(target_endian = "big")] +compile_error!("big-endian architectures are not yet supported as we use `bytemuck` for zero-copy header decoding."); + +/// Batson heterogeneous representation +#[derive(Debug)] +pub(crate) struct HetArray<'b> { + positions: &'b [u32], +} + +impl<'b> HetArray<'b> { + pub fn decode_header(d: &mut Decoder<'b>, length: Length) -> DecodeResult { + if matches!(length, Length::Empty) { + Ok(Self { positions: &[] }) + } else { + let length = length.decode(d)?; + let positions = d.take_slice_as(length)?; + Ok(Self { positions }) + } + } + + pub fn len(&self) -> usize { + self.positions.len() + } + + pub fn get(&self, d: &mut Decoder<'b>, index: usize) -> bool { + if let Some(position) = self.positions.get(index) { + d.index = *position as usize; + true + } else { + false + } + } + + pub fn to_json(&self, d: &mut Decoder<'b>) -> DecodeResult> { + self.positions + .iter() + .map(|_| d.take_value()) + .collect::>>() + .map(Arc::new) + } + + pub fn write_json(&self, d: &mut Decoder<'b>, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut steps = 0..self.len(); + writer.start_array(); + if steps.next().is_some() { + d.write_json(writer)?; + for _ in steps { + writer.comma(); + d.write_json(writer)?; + } + } + writer.end_array(); + Ok(()) + } +} + +pub(crate) fn header_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + u8_array_get(d, length, index)? 
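// each element of a header array is a single header byte; decode it into a full Header if the index was found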
+ .map(|b| Header::decode(b, d)) + .transpose() +} + +pub(crate) fn header_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let length = length.decode(d)?; + d.take_slice(length)? + .iter() + .map(|b| Header::decode(*b, d).map(|h| h.as_value(d))) + .collect::>() + .map(Arc::new) +} + +pub(crate) fn header_array_write_to_json(d: &mut Decoder, length: Length, writer: &mut JsonWriter) -> ToJsonResult<()> { + let length = length.decode(d)?; + let s = d.take_slice(length)?; + let mut iter = s.iter(); + + writer.start_array(); + if let Some(b) = iter.next() { + let h = Header::decode(*b, d)?; + h.write_json_header_only(writer)?; + for b in iter { + writer.comma(); + let h = Header::decode(*b, d)?; + h.write_json_header_only(writer)?; + } + } + writer.end_array(); + Ok(()) +} + +pub(crate) fn u8_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + let length = length.decode(d)?; + let v = d.take_slice(length)?; + Ok(v.get(index).copied()) +} + +pub(crate) fn u8_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let v = u8_array_slice(d, length)? + .iter() + .map(|b| JsonValue::Int(i64::from(*b))) + .collect(); + Ok(Arc::new(v)) +} + +pub(crate) fn u8_array_slice<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult<&'b [u8]> { + let length = length.decode(d)?; + d.take_slice(length) +} + +pub(crate) fn i64_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + let length = length.decode(d)?; + d.align::(); + let s: &[i64] = d.take_slice_as(length)?; + Ok(s.get(index).copied()) +} + +pub(crate) fn i64_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let s = i64_array_slice(d, length)?; + let v = s.iter().copied().map(JsonValue::Int).collect(); + Ok(Arc::new(v)) +} + +pub(crate) fn i64_array_slice<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult<&'b [i64]> { + let length = length.decode(d)?; + d.take_slice_as(length) +} + +pub(crate) fn encode_array(encoder: &mut Encoder, array: &JsonArray) -> EncodeResult<()> { + if array.is_empty() { + // shortcut but also no alignment! 
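// an empty array collapses to a single header byte: Category::HetArray with Length::Empty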
+ encoder.encode_length(Category::HetArray, 0) + } else if let Some(packed_array) = PackedArray::new(array) { + match packed_array { + PackedArray::Header(array) => { + encoder.push(Category::HeaderArray.encode_with(array.len() as u8)); + // no alignment necessary, it's a vec of u8 + encoder.extend(&array); + } + PackedArray::I64(array) => { + encoder.push(Category::I64Array.encode_with(array.len() as u8)); + encoder.align::(); + encoder.extend(bytemuck::cast_slice(&array)); + } + PackedArray::U8(array) => { + encoder.push(Category::U8Array.encode_with(array.len() as u8)); + // no alignment necessary, it's a vec of u8 + encoder.extend(&array); + } + } + Ok(()) + } else { + encoder.encode_length(Category::HetArray, array.len())?; + + let mut positions: Vec = Vec::with_capacity(array.len()); + encoder.align::(); + let positions_start = encoder.ring_fence(array.len() * size_of::()); + + for value in array.iter() { + positions.push(encoder.position() as u32); + encoder.encode_value(value)?; + } + encoder.set_range(positions_start, bytemuck::cast_slice(&positions)); + Ok(()) + } +} + +#[derive(Debug)] +enum PackedArray { + Header(Vec), + U8(Vec), + I64(Vec), +} + +impl PackedArray { + fn new(array: &JsonArray) -> Option { + let mut header_only: Option> = Some(Vec::with_capacity(array.len())); + let mut u8_only: Option> = Some(Vec::with_capacity(array.len())); + let mut i64_only: Option> = Some(Vec::with_capacity(array.len())); + + macro_rules! push_len { + ($cat: expr, $is_empty: expr) => {{ + u8_only = None; + i64_only = None; + if $is_empty { + header_only.as_mut()?.push($cat.encode_with(Length::Empty as u8)); + } else { + header_only = None; + } + }}; + } + + for element in array.iter() { + match element { + JsonValue::Null => { + u8_only = None; + i64_only = None; + header_only + .as_mut()? 
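// null can only be represented in a header-only array; the `?` bails out of `new` once that option is gone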
+ .push(Category::Primitive.encode_with(Primitive::Null as u8)); + } + JsonValue::Bool(b) => { + u8_only = None; + i64_only = None; + let right: Primitive = (*b).into(); + header_only.as_mut()?.push(Category::Primitive.encode_with(right as u8)); + } + JsonValue::Int(i) => { + if let Some(i64_only) = &mut i64_only { + i64_only.push(*i); + } + // if u8_only is still alive, push to it if we can + if let Some(u8_only_) = &mut u8_only { + if let Ok(u8) = u8::try_from(*i) { + u8_only_.push(u8); + } else { + u8_only = None; + } + } + // if header_only is still alive, push to it if we can + if let Some(h) = &mut header_only { + if let Some(n) = NumberHint::header_only_i64(*i) { + h.push(Category::Int.encode_with(n as u8)); + } else { + header_only = None; + } + } + } + JsonValue::BigInt(b) => todo!("BigInt {b:?}"), + JsonValue::Float(f) => { + u8_only = None; + i64_only = None; + if let Some(n) = NumberHint::header_only_f64(*f) { + header_only.as_mut()?.push(Category::Float.encode_with(n as u8)); + } else { + header_only = None; + } + } + JsonValue::Str(s) => push_len!(Category::Str, s.is_empty()), + // TODO could use a header only array if it's empty + JsonValue::Array(a) => push_len!(Category::HetArray, a.is_empty()), + JsonValue::Object(o) => push_len!(Category::Object, o.is_empty()), + } + if header_only.is_none() && i64_only.is_none() { + // stop early if neither work + return None; + } + } + // u8 array is preferable to header array as it's the pure binary representation + if let Some(u8_array) = u8_only { + Some(Self::U8(u8_array)) + } else if let Some(header_only) = header_only { + Some(Self::Header(header_only)) + } else { + i64_only.map(Self::I64) + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use smallvec::smallvec; + + use crate::compare_json_values; + use crate::decoder::Decoder; + use crate::encoder::Encoder; + use crate::header::Header; + + use super::*; + + /// hack while waiting for + macro_rules! 
assert_arrays_eq { + ($a: expr, $b: expr) => {{ + assert_eq!($a.len(), $b.len()); + for (a, b) in $a.iter().zip($b.iter()) { + assert!(compare_json_values(a, b)); + } + }}; + } + + #[test] + fn array_round_trip() { + let array = Arc::new(smallvec![JsonValue::Null, JsonValue::Int(123), JsonValue::Bool(false),]); + let mut encoder = Encoder::new(); + encoder.encode_array(&array).unwrap(); + let bytes: Vec = encoder.into(); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(3.into())); + + let het_array = HetArray::decode_header(&mut decoder, 3.into()).unwrap(); + assert_eq!(het_array.len(), 3); + assert_eq!(het_array.positions, &[16, 17, 19]); + let decode_array = het_array.to_json(&mut decoder).unwrap(); + assert_arrays_eq!(decode_array, array); + } + + #[test] + fn array_round_trip_empty() { + let array = Arc::new(smallvec![]); + let mut encoder = Encoder::new(); + encoder.encode_array(&array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 1); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(0.into())); + + let het_array = HetArray::decode_header(&mut decoder, 0.into()).unwrap(); + assert_eq!(het_array.len(), 0); + let decode_array = het_array.to_json(&mut decoder).unwrap(); + assert_arrays_eq!(decode_array, array); + } + + #[test] + fn header_array_round_trip() { + let array = Arc::new(smallvec![ + JsonValue::Null, + JsonValue::Bool(false), + JsonValue::Bool(true), + JsonValue::Int(7), + JsonValue::Float(4.0), + ]); + let mut encoder = Encoder::new(); + encoder.encode_array(&array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 6); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HeaderArray(5.into())); + + let header_array = header_array_to_json(&mut decoder, 5.into()).unwrap(); + assert_arrays_eq!(header_array, array); + } + + #[test] + fn u8_array_round_trip() { + let array = Arc::new(smallvec![JsonValue::Int(7), JsonValue::Int(4), JsonValue::Int(123),]); + let mut encoder = Encoder::new(); + encoder.encode_array(&array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 4); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::U8Array(3.into())); + + let mut decoder = Decoder::new(&bytes); + let v = decoder.take_value().unwrap(); + assert!(compare_json_values(&v, &JsonValue::Array(array))); + } + + #[test] + fn i64_array_round_trip() { + let array = Arc::new(smallvec![ + JsonValue::Int(7), + JsonValue::Int(i64::MAX), + JsonValue::Int(i64::MIN), + JsonValue::Int(1234), + JsonValue::Int(1_234_567_890), + ]); + let mut encoder = Encoder::new(); + encoder.encode_array(&array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 6 * 8); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::I64Array(5.into())); + + let i64_array = i64_array_to_json(&mut decoder, 5.into()).unwrap(); + assert_arrays_eq!(i64_array, array); + } +} diff --git a/crates/batson/src/decoder.rs b/crates/batson/src/decoder.rs new file mode 100644 index 0000000..9e953f3 --- /dev/null +++ b/crates/batson/src/decoder.rs @@ -0,0 +1,219 @@ +use std::fmt; + +use jiter::JsonValue; + +use crate::array::{ + header_array_to_json, header_array_write_to_json, i64_array_slice, 
i64_array_to_json, u8_array_slice, + u8_array_to_json, HetArray, +}; +use crate::errors::{DecodeError, DecodeErrorType, DecodeResult, ToJsonResult}; +use crate::header::{Header, Length}; +use crate::json_writer::JsonWriter; +use crate::object::Object; + +#[cfg(target_endian = "big")] +compile_error!("big-endian architectures are not yet supported as we use `bytemuck` for zero-copy header decoding."); +// see `decode_slice_as` for more information + +#[derive(Clone)] +pub(crate) struct Decoder<'b> { + bytes: &'b [u8], + pub index: usize, +} + +impl fmt::Debug for Decoder<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let upcoming = self.bytes.get(self.index..).unwrap_or_default(); + f.debug_struct("Decoder") + .field("total_length", &self.bytes.len()) + .field("upcoming_length", &upcoming.len()) + .field("index", &self.index) + .field("upcoming", &upcoming) + .finish() + } +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { bytes, index: 0 } + } + + pub fn take_header(&mut self) -> DecodeResult
<Header>
{ + let byte = self.next().ok_or_else(|| self.eof())?; + Header::decode(byte, self) + } + + pub fn align(&mut self) { + let align = align_of::(); + // I've checked and this is equivalent to: `self.index = self.index + align - (self.index % align)` + // is it actually faster? + self.index = (self.index + align - 1) & !(align - 1); + } + + pub fn take_value(&mut self) -> DecodeResult> { + match self.take_header()? { + Header::Null => Ok(JsonValue::Null), + Header::Bool(b) => Ok(JsonValue::Bool(b)), + Header::Int(n) => n.decode_i64(self).map(JsonValue::Int), + Header::IntBig(i) => todo!("decoding for bigint {i:?}"), + Header::Float(n) => n.decode_f64(self).map(JsonValue::Float), + Header::Str(l) => self.decode_str(l).map(|s| JsonValue::Str(s.into())), + Header::Object(length) => { + let obj = Object::decode_header(self, length)?; + obj.to_json(self).map(JsonValue::Object) + } + Header::HetArray(length) => { + let het = HetArray::decode_header(self, length)?; + het.to_json(self).map(JsonValue::Array) + } + Header::U8Array(length) => u8_array_to_json(self, length).map(JsonValue::Array), + Header::HeaderArray(length) => header_array_to_json(self, length).map(JsonValue::Array), + Header::I64Array(length) => i64_array_to_json(self, length).map(JsonValue::Array), + } + } + + pub fn write_json(&mut self, writer: &mut JsonWriter) -> ToJsonResult<()> { + match self.take_header()? { + Header::Null => writer.write_null(), + Header::Bool(b) => writer.write_value(b)?, + Header::Int(n) => { + let i = n.decode_i64(self)?; + writer.write_value(i)?; + } + Header::IntBig(i) => todo!("decoding for bigint {i:?}"), + Header::Float(n) => { + let f = n.decode_f64(self)?; + writer.write_value(f)?; + } + Header::Str(l) => { + let s = self.decode_str(l)?; + writer.write_value(s)?; + } + Header::Object(length) => { + let obj = Object::decode_header(self, length)?; + obj.write_json(self, writer)?; + } + Header::HetArray(length) => { + let het = HetArray::decode_header(self, length)?; + het.write_json(self, writer)?; + } + Header::U8Array(length) => { + let a = u8_array_slice(self, length)?; + writer.write_seq(a.iter())?; + } + Header::HeaderArray(length) => header_array_write_to_json(self, length, writer)?, + Header::I64Array(length) => { + let a = i64_array_slice(self, length)?; + writer.write_seq(a.iter())?; + } + }; + Ok(()) + } + + pub fn take_slice(&mut self, size: usize) -> DecodeResult<&'b [u8]> { + let end = self.index + size; + let s = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?; + self.index = end; + Ok(s) + } + + pub fn take_slice_as(&mut self, length: usize) -> DecodeResult<&'b [T]> { + self.align::(); + let size = length * size_of::(); + let end = self.index + size; + let s = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?; + + let t: &[T] = bytemuck::try_cast_slice(s).map_err(|e| self.error(DecodeErrorType::PodCastError(e)))?; + + self.index = end; + Ok(t) + } + + pub fn decode_str(&mut self, length: Length) -> DecodeResult<&'b str> { + let len = length.decode(self)?; + if len == 0 { + Ok("") + } else { + self.take_str(len) + } + } + + pub fn decode_bytes(&mut self, length: Length) -> DecodeResult<&'b [u8]> { + let len = length.decode(self)?; + if len == 0 { + Ok(b"") + } else { + self.take_slice(len) + } + } + + pub fn take_str(&mut self, length: usize) -> DecodeResult<&'b str> { + let end = self.index + length; + let slice = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?; + let s = simdutf8::basic::from_utf8(slice).map_err(|e| DecodeError::from_utf8_error(self.index, 
e))?; + self.index = end; + Ok(s) + } + + pub fn take_u8(&mut self) -> DecodeResult { + self.next().ok_or_else(|| self.eof()) + } + + pub fn take_u16(&mut self) -> DecodeResult { + let slice = self.take_slice(2)?; + Ok(u16::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn take_u32(&mut self) -> DecodeResult { + let slice = self.take_slice(4)?; + Ok(u32::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn take_i8(&mut self) -> DecodeResult { + match self.next() { + Some(byte) => Ok(byte as i8), + None => Err(self.eof()), + } + } + + pub fn take_i32(&mut self) -> DecodeResult { + let slice = self.take_slice(4)?; + Ok(i32::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn take_i64(&mut self) -> DecodeResult { + let slice = self.take_slice(8)?; + Ok(i64::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn take_f32(&mut self) -> DecodeResult { + let slice = self.take_slice(4)?; + Ok(f32::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn take_f64(&mut self) -> DecodeResult { + let slice = self.take_slice(8)?; + Ok(f64::from_le_bytes(slice.try_into().unwrap())) + } + + pub fn eof(&self) -> DecodeError { + self.error(DecodeErrorType::EOF) + } + + pub fn error(&self, error_type: DecodeErrorType) -> DecodeError { + DecodeError::new(self.index, error_type) + } +} + +impl<'b> Iterator for Decoder<'b> { + type Item = u8; + + fn next(&mut self) -> Option { + if let Some(byte) = self.bytes.get(self.index) { + self.index += 1; + Some(*byte) + } else { + None + } + } +} diff --git a/crates/batson/src/encoder.rs b/crates/batson/src/encoder.rs new file mode 100644 index 0000000..e334d23 --- /dev/null +++ b/crates/batson/src/encoder.rs @@ -0,0 +1,184 @@ +use jiter::{JsonArray, JsonObject, JsonValue}; + +use crate::array::encode_array; +use crate::errors::{EncodeError, EncodeResult}; +use crate::header::{Category, Length, NumberHint, Primitive}; +use crate::object::encode_object; + +#[derive(Debug)] +pub(crate) struct Encoder { + data: Vec, +} + +impl Encoder { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + pub fn align(&mut self) { + let align = align_of::(); + // same calculation as in `Decoder::align` + let new_len = (self.data.len() + align - 1) & !(align - 1); + self.data.resize(new_len, 0); + } + + pub fn ring_fence(&mut self, size: usize) -> usize { + let start = self.data.len(); + self.data.resize(start + size, 0); + start + } + + pub fn encode_value(&mut self, value: &JsonValue<'_>) -> EncodeResult<()> { + match value { + JsonValue::Null => self.encode_null(), + JsonValue::Bool(b) => self.encode_bool(*b), + JsonValue::Int(int) => self.encode_i64(*int), + JsonValue::BigInt(_) => todo!("encoding BigInt"), + JsonValue::Float(f) => self.encode_f64(*f), + JsonValue::Str(s) => self.encode_str(s.as_ref())?, + JsonValue::Array(array) => self.encode_array(array)?, + JsonValue::Object(obj) => self.encode_object(obj)?, + }; + Ok(()) + } + + pub fn position(&self) -> usize { + self.data.len() + } + + pub fn encode_null(&mut self) { + let h = Category::Primitive.encode_with(Primitive::Null as u8); + self.push(h); + } + + pub fn encode_bool(&mut self, bool: bool) { + let right: Primitive = bool.into(); + let h = Category::Primitive.encode_with(right as u8); + self.push(h); + } + + pub fn encode_i64(&mut self, value: i64) { + if (0..=10).contains(&value) { + self.push(Category::Int.encode_with(value as u8)); + } else if let Ok(size_8) = i8::try_from(value) { + self.push(Category::Int.encode_with(NumberHint::Size8 as u8)); + self.extend(&size_8.to_le_bytes()); + } else if 
let Ok(size_32) = i32::try_from(value) { + self.push(Category::Int.encode_with(NumberHint::Size32 as u8)); + self.extend(&size_32.to_le_bytes()); + } else { + self.push(Category::Int.encode_with(NumberHint::Size64 as u8)); + self.extend(&value.to_le_bytes()); + } + } + + pub fn encode_f64(&mut self, value: f64) { + match value { + 0.0 => self.push(Category::Float.encode_with(NumberHint::Zero as u8)), + 1.0 => self.push(Category::Float.encode_with(NumberHint::One as u8)), + 2.0 => self.push(Category::Float.encode_with(NumberHint::Two as u8)), + 3.0 => self.push(Category::Float.encode_with(NumberHint::Three as u8)), + 4.0 => self.push(Category::Float.encode_with(NumberHint::Four as u8)), + 5.0 => self.push(Category::Float.encode_with(NumberHint::Five as u8)), + 6.0 => self.push(Category::Float.encode_with(NumberHint::Six as u8)), + 7.0 => self.push(Category::Float.encode_with(NumberHint::Seven as u8)), + 8.0 => self.push(Category::Float.encode_with(NumberHint::Eight as u8)), + 9.0 => self.push(Category::Float.encode_with(NumberHint::Nine as u8)), + 10.0 => self.push(Category::Float.encode_with(NumberHint::Ten as u8)), + _ => { + // should we do something with f32 here? + self.push(Category::Float.encode_with(NumberHint::Size64 as u8)); + self.extend(&value.to_le_bytes()); + } + } + } + + pub fn encode_str(&mut self, s: &str) -> EncodeResult<()> { + self.encode_length(Category::Str, s.len())?; + self.extend(s.as_bytes()); + Ok(()) + } + + pub fn encode_bytes(&mut self, b: &[u8]) -> EncodeResult<()> { + self.encode_length(Category::U8Array, b.len())?; + self.extend(b); + Ok(()) + } + + pub fn encode_object(&mut self, object: &JsonObject) -> EncodeResult<()> { + encode_object(self, object) + } + + pub fn encode_array(&mut self, array: &JsonArray) -> EncodeResult<()> { + encode_array(self, array) + } + + pub fn extend(&mut self, s: &[u8]) { + self.data.extend_from_slice(s); + } + + pub fn set_range(&mut self, start: usize, s: &[u8]) { + self.data[start..start + s.len()].as_mut().copy_from_slice(s); + } + + pub fn encode_length(&mut self, cat: Category, len: usize) -> EncodeResult<()> { + match len { + 0 => self.push(cat.encode_with(Length::Empty as u8)), + 1 => self.push(cat.encode_with(Length::One as u8)), + 2 => self.push(cat.encode_with(Length::Two as u8)), + 3 => self.push(cat.encode_with(Length::Three as u8)), + 4 => self.push(cat.encode_with(Length::Four as u8)), + 5 => self.push(cat.encode_with(Length::Five as u8)), + 6 => self.push(cat.encode_with(Length::Six as u8)), + 7 => self.push(cat.encode_with(Length::Seven as u8)), + 8 => self.push(cat.encode_with(Length::Eight as u8)), + 9 => self.push(cat.encode_with(Length::Nine as u8)), + 10 => self.push(cat.encode_with(Length::Ten as u8)), + _ => { + if let Ok(s) = u8::try_from(len) { + self.push(cat.encode_with(Length::U8 as u8)); + self.push(s); + } else if let Ok(int) = u16::try_from(len) { + self.push(cat.encode_with(Length::U16 as u8)); + self.extend(&int.to_le_bytes()); + } else if let Ok(int) = u32::try_from(len) { + self.push(cat.encode_with(Length::U32 as u8)); + self.extend(&int.to_le_bytes()); + } else { + return Err(EncodeError::StrTooLong); + } + } + } + Ok(()) + } + + pub fn push(&mut self, h: u8) { + self.data.push(h); + } +} + +impl From for Vec { + fn from(encoder: Encoder) -> Self { + encoder.data + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::decoder::Decoder; + use crate::header::Header; + + #[test] + fn encode_int() { + let mut enc = Encoder::new(); + enc.encode_i64(0); + let h = 
Decoder::new(&enc.data).take_header().unwrap(); + assert_eq!(h, Header::Int(NumberHint::Zero)); + + let mut enc = Encoder::new(); + enc.encode_i64(7); + let h = Decoder::new(&enc.data).take_header().unwrap(); + assert_eq!(h, Header::Int(NumberHint::Seven)); + } +} diff --git a/crates/batson/src/errors.rs b/crates/batson/src/errors.rs new file mode 100644 index 0000000..762534a --- /dev/null +++ b/crates/batson/src/errors.rs @@ -0,0 +1,92 @@ +use std::fmt; + +use bytemuck::PodCastError; +use serde::ser::Error; +use simdutf8::basic::Utf8Error; + +pub type EncodeResult = Result; + +#[derive(Debug, Copy, Clone)] +pub enum EncodeError { + StrTooLong, +} + +pub type DecodeResult = Result; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct DecodeError { + pub index: usize, + pub error_type: DecodeErrorType, +} + +impl fmt::Display for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Error at index {}: {}", self.index, self.error_type) + } +} + +impl From for serde_json::Error { + fn from(e: DecodeError) -> Self { + serde_json::Error::custom(e.to_string()) + } +} + +impl DecodeError { + pub fn new(index: usize, error_type: DecodeErrorType) -> Self { + Self { index, error_type } + } + + pub fn from_utf8_error(index: usize, error: Utf8Error) -> Self { + Self::new(index, DecodeErrorType::Utf8Error(error)) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum DecodeErrorType { + EOF, + ObjectBodyIndexInvalid, + HeaderInvalid { value: u8, ty: &'static str }, + Utf8Error(Utf8Error), + PodCastError(PodCastError), +} + +impl fmt::Display for DecodeErrorType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DecodeErrorType::EOF => write!(f, "Unexpected end of file"), + DecodeErrorType::ObjectBodyIndexInvalid => write!(f, "Object body index is invalid"), + DecodeErrorType::HeaderInvalid { value, ty } => { + write!(f, "Header value {value} is invalid for type {ty}") + } + DecodeErrorType::Utf8Error(e) => write!(f, "UTF-8 error: {e}"), + DecodeErrorType::PodCastError(e) => write!(f, "Pod cast error: {e}"), + } + } +} + +pub type ToJsonResult = Result; + +#[derive(Debug)] +pub enum ToJsonError { + Str(&'static str), + DecodeError(DecodeError), + JsonError(serde_json::Error), +} + +impl From<&'static str> for ToJsonError { + fn from(e: &'static str) -> Self { + Self::Str(e) + } +} + +impl From for ToJsonError { + fn from(e: DecodeError) -> Self { + Self::DecodeError(e) + } +} + +impl From for ToJsonError { + fn from(e: serde_json::Error) -> Self { + Self::JsonError(e) + } +} diff --git a/crates/batson/src/get.rs b/crates/batson/src/get.rs new file mode 100644 index 0000000..1dbfdaf --- /dev/null +++ b/crates/batson/src/get.rs @@ -0,0 +1,272 @@ +#![allow(clippy::module_name_repetitions)] +use crate::array::{header_array_get, i64_array_get, u8_array_get, HetArray}; +use crate::decoder::Decoder; +use crate::errors::{DecodeError, DecodeResult}; +use crate::header::Header; +use crate::object::Object; + +#[derive(Debug)] +pub enum BatsonPath<'s> { + Key(&'s str), + Index(usize), +} + +impl From for BatsonPath<'_> { + fn from(index: usize) -> Self { + Self::Index(index) + } +} + +impl<'s> From<&'s str> for BatsonPath<'s> { + fn from(key: &'s str) -> Self { + Self::Key(key) + } +} + +pub fn get_bool(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult> { + GetValue::get(bytes, path).map(|v| v.and_then(Into::into)) +} + +pub fn get_str<'b>(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult> { + get_try_into(bytes, path) +} + +pub 
fn get_int(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult> { + get_try_into(bytes, path) +} + +pub fn contains(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult { + GetValue::get(bytes, path).map(|v| v.is_some()) +} + +pub fn get_length(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult> { + if let Some(v) = GetValue::get(bytes, path)? { + v.into_length() + } else { + Ok(None) + } +} + +fn get_try_into<'b, T>(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult> +where + Option: TryFrom, Error = DecodeError>, +{ + if let Some(v) = GetValue::get(bytes, path)? { + v.try_into() + } else { + Ok(None) + } +} + +#[derive(Debug)] +enum GetValue<'b> { + Header(Decoder<'b>, Header), + U8(u8), + I64(i64), +} + +impl<'b> GetValue<'b> { + fn get(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult> { + let mut decoder = Decoder::new(bytes); + let mut opt_header: Option
<Header>
= Some(decoder.take_header()?); + let mut value: Option = None; + for element in path { + let Some(header) = opt_header.take() else { + return Ok(None); + }; + match element { + BatsonPath::Key(key) => { + if let Header::Object(length) = header { + let object = Object::decode_header(&mut decoder, length)?; + if object.get(&mut decoder, key)? { + opt_header = Some(decoder.take_header()?); + } + } + } + BatsonPath::Index(index) => match header { + Header::HeaderArray(length) => { + opt_header = header_array_get(&mut decoder, length, *index)?; + } + Header::U8Array(length) => { + if let Some(u8_value) = u8_array_get(&mut decoder, length, *index)? { + value = Some(GetValue::U8(u8_value)); + } + } + Header::I64Array(length) => { + if let Some(i64_value) = i64_array_get(&mut decoder, length, *index)? { + value = Some(GetValue::I64(i64_value)); + } + } + Header::HetArray(length) => { + let a = HetArray::decode_header(&mut decoder, length)?; + if a.get(&mut decoder, *index) { + opt_header = Some(decoder.take_header()?); + } + } + _ => {} + }, + } + } + if let Some(header) = opt_header { + Ok(Some(Self::Header(decoder, header))) + } else if let Some(value) = value { + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn header(self) -> Option
<Header>
{ + match self { + Self::Header(_, header) => Some(header), + _ => None, + } + } + + fn into_length(self) -> DecodeResult> { + let Self::Header(mut decoder, header) = self else { + return Ok(None); + }; + match header { + Header::Str(length) + | Header::Object(length) + | Header::HeaderArray(length) + | Header::U8Array(length) + | Header::I64Array(length) + | Header::HetArray(length) => length.decode(&mut decoder).map(Some), + _ => Ok(None), + } + } +} + +impl From> for Option { + fn from(v: GetValue) -> Self { + v.header().and_then(Header::into_bool) + } +} + +impl<'b> TryFrom> for Option<&'b str> { + type Error = DecodeError; + + fn try_from(v: GetValue<'b>) -> DecodeResult { + match v { + GetValue::Header(mut decoder, Header::Str(length)) => { + let length = length.decode(&mut decoder)?; + decoder.take_str(length).map(Some) + } + _ => Ok(None), + } + } +} + +impl TryFrom> for Option { + type Error = DecodeError; + + fn try_from(v: GetValue) -> DecodeResult { + match v { + GetValue::Header(mut decoder, Header::Int(n)) => n.decode_i64(&mut decoder).map(Some), + GetValue::I64(i64) => Ok(Some(i64)), + GetValue::U8(u8) => Ok(Some(i64::from(u8))), + GetValue::Header(..) => Ok(None), + } + } +} + +#[cfg(test)] +mod test { + use crate::encode_from_json; + use crate::header::{Header, NumberHint}; + use jiter::{JsonValue, LazyIndexMap}; + use smallvec::smallvec; + use std::sync::Arc; + + use super::*; + + #[test] + fn get_object() { + let v: JsonValue<'static> = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![ + ("null".into(), JsonValue::Null), + ("true".into(), JsonValue::Bool(true)), + ]))); + let bytes = encode_from_json(&v).unwrap(); + + let v = GetValue::get(&bytes, &["null".into()]).unwrap().unwrap(); + assert_eq!(v.header(), Some(Header::Null)); + let v = GetValue::get(&bytes, &["true".into()]).unwrap().unwrap(); + assert_eq!(v.header(), Some(Header::Bool(true))); + + assert!(GetValue::get(&bytes, &["foo".into()]).unwrap().is_none()); + assert!(GetValue::get(&bytes, &[1.into()]).unwrap().is_none()); + assert!(GetValue::get(&bytes, &["null".into(), 1.into()]).unwrap().is_none()); + } + + #[test] + fn get_header_array() { + let v: JsonValue<'static> = JsonValue::Array(Arc::new(smallvec![JsonValue::Null, JsonValue::Bool(true),])); + let bytes = encode_from_json(&v).unwrap(); + + let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap(); + assert_eq!(v.header(), Some(Header::Null)); + + let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap(); + assert_eq!(v.header(), Some(Header::Bool(true))); + + assert!(GetValue::get(&bytes, &["foo".into()]).unwrap().is_none()); + assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none()); + } + + #[test] + fn get_het_array() { + let v: JsonValue<'static> = + JsonValue::Array(Arc::new( + smallvec![JsonValue::Int(42), JsonValue::Str("foobar".into()),], + )); + let bytes = encode_from_json(&v).unwrap(); + + let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap(); + assert_eq!(v.header(), Some(Header::Int(NumberHint::Size8))); + } + + fn value_u8(v: &GetValue) -> Option { + match v { + GetValue::U8(u8) => Some(*u8), + _ => None, + } + } + + fn value_i64(v: &GetValue) -> Option { + match v { + GetValue::I64(i64) => Some(*i64), + _ => None, + } + } + + #[test] + fn get_u8_array() { + let v: JsonValue<'static> = JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(255),])); + let bytes = encode_from_json(&v).unwrap(); + + let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap(); + assert_eq!(value_u8(&v), Some(42)); + + 
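// 255 still fits in a u8, so the array is packed as a U8Array and index 1 comes back as a plain byte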
let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap(); + assert_eq!(value_u8(&v), Some(255)); + + assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none()); + } + + #[test] + fn get_i64_array() { + let v: JsonValue<'static> = + JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(i64::MAX),])); + let bytes = encode_from_json(&v).unwrap(); + + let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap(); + assert_eq!(value_i64(&v), Some(42)); + + let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap(); + assert_eq!(value_i64(&v), Some(i64::MAX)); + + assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none()); + } +} diff --git a/crates/batson/src/header.rs b/crates/batson/src/header.rs new file mode 100644 index 0000000..f08711a --- /dev/null +++ b/crates/batson/src/header.rs @@ -0,0 +1,352 @@ +use crate::decoder::Decoder; +use crate::errors::{DecodeErrorType, DecodeResult}; +use crate::json_writer::JsonWriter; +use crate::ToJsonResult; +use jiter::{JsonValue, LazyIndexMap}; +use smallvec::smallvec; +use std::sync::Arc; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub(crate) enum Header { + Null, + Bool(bool), + Int(NumberHint), + IntBig(Length), + Float(NumberHint), + Str(Length), + Object(Length), + // array types in order of their complexity + #[allow(clippy::enum_variant_names)] + HeaderArray(Length), + U8Array(Length), + I64Array(Length), + HetArray(Length), +} + +impl Header { + /// Decode the next byte from a decoder into a header value + pub fn decode(byte: u8, d: &Decoder) -> DecodeResult { + let (left, right) = split_byte(byte); + let cat = Category::from_u8(left, d)?; + match cat { + Category::Primitive => Primitive::from_u8(right, d).map(Primitive::header_value), + Category::Int => NumberHint::from_u8(right, d).map(Self::Int), + Category::BigInt => Length::from_u8(right, d).map(Self::IntBig), + Category::Float => NumberHint::from_u8(right, d).map(Self::Float), + Category::Str => Length::from_u8(right, d).map(Self::Str), + Category::Object => Length::from_u8(right, d).map(Self::Object), + Category::HeaderArray => Length::from_u8(right, d).map(Self::HeaderArray), + Category::U8Array => Length::from_u8(right, d).map(Self::U8Array), + Category::I64Array => Length::from_u8(right, d).map(Self::I64Array), + Category::HetArray => Length::from_u8(right, d).map(Self::HetArray), + } + } + + /// TODO `'static` should be okay as return lifetime, I don't know why it's not + pub fn as_value<'b>(self, _: &Decoder<'b>) -> JsonValue<'b> { + match self { + Header::Null => JsonValue::Null, + Header::Bool(b) => JsonValue::Bool(b), + Header::Int(n) => JsonValue::Int(n.decode_i64_header()), + Header::IntBig(_) => todo!(), + Header::Float(n) => JsonValue::Float(n.decode_f64_header()), + Header::Str(_) => JsonValue::Str("".into()), + Header::Object(_) => JsonValue::Object(Arc::new(LazyIndexMap::default())), + _ => JsonValue::Array(Arc::new(smallvec![])), + } + } + + pub fn write_json_header_only(self, writer: &mut JsonWriter) -> ToJsonResult<()> { + match self { + Header::Null => writer.write_null(), + Header::Bool(b) => writer.write_value(b)?, + Header::Int(n) => writer.write_value(n.decode_i64_header())?, + Header::IntBig(_) => todo!(), + Header::Float(n) => writer.write_value(n.decode_f64_header())?, + // TODO check the + Header::Str(len) => { + len.check_empty()?; + writer.write_value("")?; + } + Header::Object(len) => { + len.check_empty()?; + writer.write_empty_object(); + } + Self::HeaderArray(len) | Self::U8Array(len) | Self::I64Array(len) | 
Self::HetArray(len) => { + len.check_empty()?; + writer.write_empty_array(); + } + } + Ok(()) + } + + pub fn into_bool(self) -> Option { + match self { + Header::Bool(b) => Some(b), + _ => None, + } + } +} + +macro_rules! impl_from_u8 { + ($header_enum:ty, $max_value:literal) => { + impl $header_enum { + fn from_u8(value: u8, p: &Decoder) -> DecodeResult { + if value <= $max_value { + Ok(unsafe { std::mem::transmute::(value) }) + } else { + Err(p.error(DecodeErrorType::HeaderInvalid { + value, + ty: stringify!($header_enum), + })) + } + } + } + }; +} + +/// Left half of the first header byte determines the category of the value +/// Up to 16 categories are possible +#[derive(Debug, Copy, Clone)] +pub(crate) enum Category { + Primitive = 0, + Int = 1, + BigInt = 2, + Float = 3, + Str = 4, + Object = 5, + HeaderArray = 6, + U8Array = 7, + I64Array = 8, + HetArray = 9, +} +impl_from_u8!(Category, 9); + +impl Category { + pub fn encode_with(self, right: u8) -> u8 { + let left = self as u8; + (left << 4) | right + } +} + +#[derive(Debug, Copy, Clone)] +pub(crate) enum Primitive { + Null = 0, + True = 1, + False = 2, +} +impl_from_u8!(Primitive, 2); + +impl From for Primitive { + fn from(value: bool) -> Self { + if value { + Self::True + } else { + Self::False + } + } +} + +impl Primitive { + fn header_value(self) -> Header { + match self { + Primitive::Null => Header::Null, + Primitive::True => Header::Bool(true), + Primitive::False => Header::Bool(false), + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub(crate) enum NumberHint { + Zero = 0, + One = 1, + Two = 2, + Three = 3, + Four = 4, + Five = 5, + Six = 6, + Seven = 7, + Eight = 8, + Nine = 9, + Ten = 10, + // larger numbers + Size8 = 11, + Size32 = 12, + Size64 = 13, +} +impl_from_u8!(NumberHint, 13); + +impl NumberHint { + pub fn decode_i64(self, d: &mut Decoder) -> DecodeResult { + match self { + NumberHint::Size8 => d.take_i8().map(i64::from), + NumberHint::Size32 => d.take_i32().map(i64::from), + NumberHint::Size64 => d.take_i64(), + // TODO check this has same performance as inline match + _ => Ok(self.decode_i64_header()), + } + } + + #[inline] + pub fn decode_i64_header(self) -> i64 { + match self { + NumberHint::Zero => 0, + NumberHint::One => 1, + NumberHint::Two => 2, + NumberHint::Three => 3, + NumberHint::Four => 4, + NumberHint::Five => 5, + NumberHint::Six => 6, + NumberHint::Seven => 7, + NumberHint::Eight => 8, + NumberHint::Nine => 9, + NumberHint::Ten => 10, + _ => unreachable!("Expected concrete value, got {self:?}"), + } + } + + pub fn decode_f64(self, d: &mut Decoder) -> DecodeResult { + match self { + // f8 doesn't exist + NumberHint::Size8 => Err(d.error(DecodeErrorType::HeaderInvalid { + value: self as u8, + ty: "f64", + })), + NumberHint::Size32 => d.take_f32().map(f64::from), + NumberHint::Size64 => d.take_f64(), + // TODO check this has same performance as inline match + _ => Ok(self.decode_f64_header()), + } + } + + #[inline] + fn decode_f64_header(self) -> f64 { + match self { + NumberHint::Zero => 0.0, + NumberHint::One => 1.0, + NumberHint::Two => 2.0, + NumberHint::Three => 3.0, + NumberHint::Four => 4.0, + NumberHint::Five => 5.0, + NumberHint::Six => 6.0, + NumberHint::Seven => 7.0, + NumberHint::Eight => 8.0, + NumberHint::Nine => 9.0, + NumberHint::Ten => 10.0, + _ => unreachable!("Expected concrete value, got {self:?}"), + } + } + + pub fn header_only_i64(int: i64) -> Option { + match int { + 0 => Some(NumberHint::Zero), + 1 => Some(NumberHint::One), + 2 => Some(NumberHint::Two), + 3 => 
Some(NumberHint::Three), + 4 => Some(NumberHint::Four), + 5 => Some(NumberHint::Five), + 6 => Some(NumberHint::Six), + 7 => Some(NumberHint::Seven), + 8 => Some(NumberHint::Eight), + 9 => Some(NumberHint::Nine), + 10 => Some(NumberHint::Ten), + _ => None, + } + } + + pub fn header_only_f64(float: f64) -> Option { + match float { + 0.0 => Some(NumberHint::Zero), + 1.0 => Some(NumberHint::One), + 2.0 => Some(NumberHint::Two), + 3.0 => Some(NumberHint::Three), + 4.0 => Some(NumberHint::Four), + 5.0 => Some(NumberHint::Five), + 6.0 => Some(NumberHint::Six), + 7.0 => Some(NumberHint::Seven), + 8.0 => Some(NumberHint::Eight), + 9.0 => Some(NumberHint::Nine), + 10.0 => Some(NumberHint::Ten), + _ => None, + } + } +} + +/// String and packed array length header part +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub(crate) enum Length { + Empty = 0, + One = 1, + Two = 2, + Three = 3, + Four = 4, + Five = 5, + Six = 6, + Seven = 7, + Eight = 8, + Nine = 9, + Ten = 10, + // larger numbers + U8 = 11, + U16 = 12, + U32 = 13, +} +impl_from_u8!(Length, 13); + +impl From for Length { + fn from(len: u64) -> Self { + match len { + 0 => Self::Empty, + 1 => Self::One, + 2 => Self::Two, + 3 => Self::Three, + 4 => Self::Four, + 5 => Self::Five, + 6 => Self::Six, + 7 => Self::Seven, + 8 => Self::Eight, + 9 => Self::Nine, + 10 => Self::Ten, + len if len <= u64::from(u8::MAX) => Self::U8, + len if len <= u64::from(u16::MAX) => Self::U16, + _ => Self::U32, + } + } +} + +impl Length { + pub fn decode(self, d: &mut Decoder) -> DecodeResult { + match self { + Self::Empty => Ok(0), + Self::One => Ok(1), + Self::Two => Ok(2), + Self::Three => Ok(3), + Self::Four => Ok(4), + Self::Five => Ok(5), + Self::Six => Ok(6), + Self::Seven => Ok(7), + Self::Eight => Ok(8), + Self::Nine => Ok(9), + Self::Ten => Ok(10), + Self::U8 => d.take_u8().map(|s| s as usize), + Self::U16 => d.take_u16().map(|s| s as usize), + Self::U32 => d.take_u32().map(|s| s as usize), + } + } + + pub fn check_empty(self) -> ToJsonResult<()> { + if matches!(self, Self::Empty) { + Ok(()) + } else { + Err("Expected empty length, got non-empty".into()) + } + } +} + +/// Split a byte into two 4-bit halves - u8 numbers with a range of 0-15 +fn split_byte(byte: u8) -> (u8, u8) { + let left = byte >> 4; // Shift the byte right by 4 bits + let right = byte & 0b0000_1111; // Mask the byte with 00001111 + (left, right) +} diff --git a/crates/batson/src/json_writer.rs b/crates/batson/src/json_writer.rs new file mode 100644 index 0000000..d5f3936 --- /dev/null +++ b/crates/batson/src/json_writer.rs @@ -0,0 +1,118 @@ +use serde::ser::Serializer as _; +use serde_json::ser::Serializer; + +use crate::errors::ToJsonResult; + +pub(crate) struct JsonWriter { + vec: Vec, +} + +impl JsonWriter { + pub fn new() -> Self { + Self { + vec: Vec::with_capacity(128), + } + } + + pub fn write_null(&mut self) { + self.vec.extend_from_slice(b"null"); + } + + #[allow(clippy::needless_pass_by_value)] + pub fn write_value(&mut self, v: impl WriteJson) -> ToJsonResult<()> { + v.write_json(self) + } + + pub fn write_seq<'a>(&mut self, mut v: impl Iterator) -> ToJsonResult<()> { + self.start_array(); + + if let Some(first) = v.next() { + first.write_json(self)?; + for value in v { + self.comma(); + value.write_json(self)?; + } + } + self.end_array(); + Ok(()) + } + + pub fn write_empty_array(&mut self) { + self.vec.extend_from_slice(b"[]"); + } + + pub fn start_array(&mut self) { + self.vec.push(b'['); + } + + pub fn end_array(&mut self) { + self.vec.push(b']'); + } + + pub fn write_key(&mut 
self, key: &str) -> ToJsonResult<()> { + self.write_value(key)?; + self.vec.push(b':'); + Ok(()) + } + + pub fn write_empty_object(&mut self) { + self.vec.extend_from_slice(b"{}"); + } + + pub fn start_object(&mut self) { + self.vec.push(b'{'); + } + + pub fn end_object(&mut self) { + self.vec.push(b'}'); + } + + pub fn comma(&mut self) { + self.vec.push(b','); + } +} + +impl From for Vec { + fn from(writer: JsonWriter) -> Self { + writer.vec + } +} + +pub(crate) trait WriteJson { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()>; +} + +impl WriteJson for &str { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut ser = Serializer::new(&mut writer.vec); + ser.serialize_str(self).map_err(Into::into) + } +} + +impl WriteJson for bool { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> { + writer.vec.extend_from_slice(if *self { b"true" } else { b"false" }); + Ok(()) + } +} + +impl WriteJson for u8 { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut ser = Serializer::new(&mut writer.vec); + ser.serialize_u8(*self).map_err(Into::into) + } +} + +impl WriteJson for i64 { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut ser = Serializer::new(&mut writer.vec); + ser.serialize_i64(*self).map_err(Into::into) + } +} + +impl WriteJson for f64 { + fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut ser = Serializer::new(&mut writer.vec); + ser.serialize_f64(*self).map_err(Into::into) + } +} diff --git a/crates/batson/src/lib.rs b/crates/batson/src/lib.rs new file mode 100644 index 0000000..c014e8b --- /dev/null +++ b/crates/batson/src/lib.rs @@ -0,0 +1,76 @@ +#![allow(dead_code)] +mod array; +mod decoder; +mod encoder; +mod errors; +pub mod get; +mod header; +mod json_writer; +mod object; + +use jiter::JsonValue; + +use crate::json_writer::JsonWriter; +use decoder::Decoder; +use encoder::Encoder; +pub use errors::{DecodeErrorType, DecodeResult, EncodeError, EncodeResult, ToJsonError, ToJsonResult}; + +/// Encode binary data from a JSON value. +/// +/// # Errors +/// +/// Returns an error if the data is not valid. +pub fn encode_from_json(value: &JsonValue<'_>) -> EncodeResult> { + let mut encoder = Encoder::new(); + encoder.encode_value(value)?; + Ok(encoder.into()) +} + +/// Decode binary data to a JSON value. +/// +/// # Errors +/// +/// Returns an error if the data is not valid. 
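///
/// A rough usage sketch (illustrative only, not a doctest):
///
/// ```ignore
/// let value = jiter::JsonValue::parse(br#"{"size": 123}"#, false).unwrap();
/// let bytes = batson::encode_from_json(&value).unwrap();
/// let decoded = batson::decode_to_json_value(&bytes).unwrap();
/// assert!(batson::compare_json_values(&decoded, &value));
/// ```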
+pub fn decode_to_json_value(bytes: &[u8]) -> DecodeResult<JsonValue<'_>> {
+    Decoder::new(bytes).take_value()
+}
+
+pub fn batson_to_json_vec(batson_bytes: &[u8]) -> ToJsonResult<Vec<u8>> {
+    let mut writer = JsonWriter::new();
+    Decoder::new(batson_bytes).write_json(&mut writer)?;
+    Ok(writer.into())
+}
+
+pub fn batson_to_json_string(batson_bytes: &[u8]) -> ToJsonResult<String> {
+    let v = batson_to_json_vec(batson_bytes)?;
+    // safe since we're guaranteed to have written valid UTF-8
+    unsafe { Ok(String::from_utf8_unchecked(v)) }
+}
+
+/// Hack while waiting for
+#[must_use]
+pub fn compare_json_values(a: &JsonValue<'_>, b: &JsonValue<'_>) -> bool {
+    match (a, b) {
+        (JsonValue::Null, JsonValue::Null) => true,
+        (JsonValue::Bool(a), JsonValue::Bool(b)) => a == b,
+        (JsonValue::Int(a), JsonValue::Int(b)) => a == b,
+        (JsonValue::BigInt(a), JsonValue::BigInt(b)) => a == b,
+        (JsonValue::Float(a), JsonValue::Float(b)) => (a - b).abs() <= f64::EPSILON,
+        (JsonValue::Str(a), JsonValue::Str(b)) => a == b,
+        (JsonValue::Array(a), JsonValue::Array(b)) => {
+            if a.len() != b.len() {
+                return false;
+            }
+            a.iter().zip(b.iter()).all(|(a, b)| compare_json_values(a, b))
+        }
+        (JsonValue::Object(a), JsonValue::Object(b)) => {
+            if a.len() != b.len() {
+                return false;
+            }
+            a.iter()
+                .zip(b.iter())
+                .all(|((ak, av), (bk, bv))| ak == bk && compare_json_values(av, bv))
+        }
+        _ => false,
+    }
+}
diff --git a/crates/batson/src/object.rs b/crates/batson/src/object.rs
new file mode 100644
index 0000000..d8ed4a3
--- /dev/null
+++ b/crates/batson/src/object.rs
@@ -0,0 +1,300 @@
+use std::cmp::Ordering;
+use std::sync::Arc;
+
+use bytemuck::{Pod, Zeroable};
+use jiter::{JsonObject, LazyIndexMap};
+
+use crate::decoder::Decoder;
+use crate::encoder::Encoder;
+use crate::errors::{DecodeErrorType, DecodeResult, EncodeResult};
+use crate::header::{Category, Length};
+use crate::json_writer::JsonWriter;
+use crate::ToJsonResult;
+
+#[derive(Debug)]
+pub(crate) struct Object<'b> {
+    super_header: &'b [SuperHeaderItem],
+}
+
+impl<'b> Object<'b> {
+    pub fn decode_header(d: &mut Decoder<'b>, length: Length) -> DecodeResult<Self> {
+        if matches!(length, Length::Empty) {
+            Ok(Self { super_header: &[] })
+        } else {
+            let length = length.decode(d)?;
+            let super_header = d.take_slice_as(length)?;
+            Ok(Self { super_header })
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.super_header.len()
+    }
+
+    pub fn get(&self, d: &mut Decoder<'b>, key: &str) -> DecodeResult<bool> {
+        // key length and hash
+        let kl = key.len() as u32;
+        let kh = djb2_hash(key);
+        let Some(header_iter) = binary_search(self.super_header, |h| h.cmp_raw(kl, kh)) else {
+            return Ok(false);
+        };
+        let start_index = d.index;
+
+        for h in header_iter {
+            d.index = h.position as usize;
+            let possible_key = d.take_slice(h.key_length as usize)?;
+            if possible_key == key.as_bytes() {
+                return Ok(true);
+            }
+        }
+        // reset the index
+        d.index = start_index;
+        Ok(false)
+    }
+
+    pub fn to_json(&self, d: &mut Decoder<'b>) -> DecodeResult<JsonObject<'b>> {
+        self.super_header
+            .iter()
+            .map(|_| {
+                let key = self.take_next_key(d)?;
+                let value = d.take_value()?;
+                Ok((key.into(), value))
+            })
+            .collect::<DecodeResult<LazyIndexMap<_, _>>>()
+            .map(Arc::new)
+    }
+
+    pub fn write_json(&self, d: &mut Decoder<'b>, writer: &mut JsonWriter) -> ToJsonResult<()> {
+        let mut steps = 0..self.len();
+        writer.start_object();
+        if steps.next().is_some() {
+            let key = self.take_next_key(d)?;
+            writer.write_key(key)?;
+            d.write_json(writer)?;
+            for _ in steps {
+                writer.comma();
+                let key = self.take_next_key(d)?;
+                writer.write_key(key)?;
+                d.write_json(writer)?;
+            }
+        }
+        writer.end_object();
+        Ok(())
+    }
+
+    pub fn take_next_key(&self, d: &mut Decoder<'b>) -> DecodeResult<&'b str> {
+        let header_index = d.take_u32()?;
+        match self.super_header.get(header_index as usize) {
+            Some(h) => d.take_str(h.key_length as usize),
+            None => Err(d.error(DecodeErrorType::ObjectBodyIndexInvalid)),
+        }
+    }
+}
+
+/// Represents an item in the header
+///
+/// # Warning
+///
+/// **Member order matters here** since it decides the layout of the struct when serialized.
+#[derive(Debug, Copy, Clone, Pod, Zeroable)]
+#[repr(C)]
+struct SuperHeaderItem {
+    key_length: u32,
+    key_hash: u32,
+    position: u32,
+}
+
+impl SuperHeaderItem {
+    fn new(key: &str, position: u32) -> Self {
+        Self {
+            key_length: key.len() as u32,
+            key_hash: djb2_hash(key),
+            position,
+        }
+    }
+
+    fn cmp_raw(&self, key_len: u32, key_hash: u32) -> Ordering {
+        match self.key_length.cmp(&key_len) {
+            Ordering::Equal => self.key_hash.cmp(&key_hash),
+            x => x,
+        }
+    }
+}
+
+/// Search a sorted slice and return a sub-slice of values that match a given predicate.
+fn binary_search<'b, S>(
+    haystack: &'b [S],
+    compare: impl Fn(&S) -> Ordering + 'b,
+) -> Option<impl Iterator<Item = &'b S>> {
+    let mut low = 0;
+    let mut high = haystack.len();
+
+    // Perform binary search to find one occurrence of the value
+    loop {
+        let mid = low + (high - low) / 2;
+        match compare(&haystack[mid]) {
+            Ordering::Less => low = mid + 1,
+            Ordering::Greater => high = mid,
+            Ordering::Equal => {
+                // Finding the start of the sub-slice with the target value
+                let start = haystack[..mid]
+                    .iter()
+                    .rposition(|x| compare(x).is_ne())
+                    .map_or(0, |pos| pos + 1);
+                return Some(haystack[start..].iter().take_while(move |x| compare(x).is_eq()));
+            }
+        }
+        if low >= high {
+            return None;
+        }
+    }
+}
+
+pub(crate) fn encode_object(encoder: &mut Encoder, object: &JsonObject) -> EncodeResult<()> {
+    if object.is_empty() {
+        // shortcut but also no alignment!
+        return encoder.encode_length(Category::Object, 0);
+    }
+
+    encoder.encode_length(Category::Object, object.len())?;
+
+    let mut super_header = Vec::with_capacity(object.len());
+    encoder.align::<SuperHeaderItem>();
+    let super_header_start = encoder.ring_fence(object.len() * size_of::<SuperHeaderItem>());
+
+    for (key, value) in object.iter() {
+        let key_str = key.as_ref();
+        // add space for the header index, to be set correctly later
+        encoder.extend(&0u32.to_le_bytes());
+        // push to the super header, with the position at this stage
+        super_header.push(SuperHeaderItem::new(key_str, encoder.position() as u32));
+        // now we've recorded the position, write the key and value to the encoder
+        encoder.extend(key_str.as_bytes());
+        encoder.encode_value(value)?;
+    }
+    super_header.sort_by(|a, b| a.cmp_raw(b.key_length, b.key_hash));
+
+    for (header_index, h) in super_header.iter().enumerate() {
+        // set the header index in body
+        encoder.set_range(h.position as usize - 4, &(header_index as u32).to_le_bytes());
+    }
+    encoder.set_range(super_header_start, bytemuck::cast_slice(&super_header));
+    Ok(())
+}
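`encode_object` uses a reserve-then-backfill layout: space for the super header is ring-fenced up front, each key and value is written while its position is recorded, and the reserved range is overwritten once the entries are sorted. A simplified standalone sketch of that pattern with a plain `Vec<u8>` (a hypothetical buffer, not the batson `Encoder`):

```rust
// Sketch of "reserve now, backfill later": reserve a table, write the body
// while recording offsets, then overwrite the reserved range at the end.
fn main() {
    let mut buf: Vec<u8> = Vec::new();
    let keys = ["aa", "bat", "c"];

    // 1. reserve 4 bytes per entry for an offset table (like `ring_fence`)
    let table_start = buf.len();
    buf.resize(table_start + 4 * keys.len(), 0);

    // 2. write the body, remembering where each key starts (like `position()`)
    let mut offsets = Vec::with_capacity(keys.len());
    for key in keys {
        offsets.push(buf.len() as u32);
        buf.extend_from_slice(key.as_bytes());
    }

    // 3. backfill the reserved table (like `set_range`)
    for (i, offset) in offsets.iter().enumerate() {
        let slot = table_start + 4 * i;
        buf[slot..slot + 4].copy_from_slice(&offset.to_le_bytes());
    }

    println!("{buf:?}");
}
```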
+/// Very simple and fast hashing algorithm that nonetheless gives good distribution.
+///
+/// See and
+/// and for more information.
+fn djb2_hash(s: &str) -> u32 {
+    let mut hash_value: u32 = 5381;
+    for i in s.bytes() {
+        // hash_value * 33 + char
+        hash_value = hash_value
+            .wrapping_shl(5)
+            .wrapping_add(hash_value)
+            .wrapping_add(u32::from(i));
+    }
+    hash_value
+}
+
+#[cfg(test)]
+mod test {
+    use jiter::{JsonValue, LazyIndexMap};
+
+    use crate::encode_from_json;
+    use crate::header::Header;
+
+    use super::*;
+
+    #[test]
+    fn decode_get() {
+        let v = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![
+            ("aa".into(), JsonValue::Str("hello, world!".into())),
+            ("bat".into(), JsonValue::Int(42)),
+            ("c".into(), JsonValue::Bool(true)),
+        ])));
+        let b = encode_from_json(&v).unwrap();
+        let mut d = Decoder::new(&b);
+        let header = d.take_header().unwrap();
+        assert_eq!(header, Header::Object(3.into()));
+
+        let obj = Object::decode_header(&mut d, 3.into()).unwrap();
+
+        assert_eq!(obj.len(), 3);
+
+        let mut d2 = d.clone();
+        assert!(obj.get(&mut d2, "aa").unwrap());
+        assert_eq!(d2.take_value().unwrap(), JsonValue::Str("hello, world!".into()));
+
+        let mut d3 = d.clone();
+        assert!(obj.get(&mut d3, "bat").unwrap());
+        assert_eq!(d3.take_value().unwrap(), JsonValue::Int(42));
+
+        let mut d4 = d.clone();
+        assert!(obj.get(&mut d4, "c").unwrap());
+        assert_eq!(d4.take_value().unwrap(), JsonValue::Bool(true));
+
+        assert!(!obj.get(&mut d, "x").unwrap());
+    }
+
+    #[test]
+    fn decode_empty() {
+        let v = JsonValue::Object(Arc::new(LazyIndexMap::default()));
+        let b = encode_from_json(&v).unwrap();
+        assert_eq!(b.len(), 1);
+        let mut d = Decoder::new(&b);
+        let header = d.take_header().unwrap();
+        assert_eq!(header, Header::Object(0.into()));
+
+        let obj = Object::decode_header(&mut d, 0.into()).unwrap();
+        assert_eq!(obj.len(), 0);
+    }
+
+    #[test]
+    fn binary_search_direct() {
+        let slice = &["", "b", "ba", "fo", "spam"];
+        let mut count = 0;
+        for i in binary_search(slice, |x| x.len().cmp(&1)).unwrap() {
+            assert_eq!(*i, "b");
+            count += 1;
+        }
+        assert_eq!(count, 1);
+    }
+
+    fn binary_search_vec<S: Clone>(haystack: &[S], compare: impl Fn(&S) -> Ordering) -> Option<Vec<S>> {
+        binary_search(haystack, compare).map(|i| i.cloned().collect())
+    }
+
+    #[test]
+    fn binary_search_ints() {
+        let slice = &[1, 2, 2, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8];
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&1)), Some(vec![1]));
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&2)), Some(vec![2, 2, 2]));
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&3)), Some(vec![3]));
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&4)), Some(vec![4]));
+        assert_eq!(
+            binary_search_vec(slice, |x| x.cmp(&7)),
+            Some(vec![7, 7, 7, 7, 7, 7, 7, 7])
+        );
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&8)), Some(vec![8, 8]));
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&12)), None);
+    }
+
+    #[test]
+    fn binary_search_strings() {
+        let slice = &["", "b", "ba", "fo", "spam"];
+        assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&0)), Some(vec![""]));
+        assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&1)), Some(vec!["b"]));
+        assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&2)), Some(vec!["ba", "fo"]));
+        assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&4)), Some(vec!["spam"]));
+        assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&5)), None);
+    }
+
+    #[test]
+    fn binary_search_take_while() {
+        // invalid input, to test that take_while isn't iterating further
+        let slice = &[1, 2, 2, 1, 3];
+        assert_eq!(binary_search_vec(slice, |x| x.cmp(&1)), Some(vec![1]));
+    }
+}
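Key lookup never trusts the hash alone: super-header entries are ordered by `(key_length, djb2_hash)`, and `Object::get` confirms each candidate by comparing the raw key bytes, so collisions cannot produce a false positive. A standalone sketch of that matching rule (a linear filter here for brevity; the crate uses `binary_search` over the sorted super header):

```rust
// djb2: h = h * 33 + byte, starting from 5381 (copied from object.rs above)
fn djb2_hash(s: &str) -> u32 {
    let mut h: u32 = 5381;
    for b in s.bytes() {
        h = h.wrapping_shl(5).wrapping_add(h).wrapping_add(u32::from(b));
    }
    h
}

fn main() {
    let keys = ["aa", "bat", "c"];
    // sort key for the super header: length first, hash second
    let mut index: Vec<(u32, u32, &str)> = keys
        .iter()
        .map(|k| (k.len() as u32, djb2_hash(k), *k))
        .collect();
    index.sort();

    // lookup: find candidates with matching (length, hash), then confirm the bytes
    let needle = "bat";
    let probe = (needle.len() as u32, djb2_hash(needle));
    let found = index
        .iter()
        .filter(|(l, h, _)| (*l, *h) == probe)
        .any(|(_, _, k)| *k == needle);
    println!("{needle} found: {found}");
}
```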
diff --git a/crates/batson/tests/main.rs b/crates/batson/tests/main.rs
new file mode 100644
index 0000000..fd85839
--- /dev/null
+++ b/crates/batson/tests/main.rs
@@ -0,0 +1,238 @@
+use std::sync::Arc;
+
+use jiter::{JsonValue, LazyIndexMap};
+use smallvec::smallvec;
+
+use batson::get::{contains, get_bool, get_int, get_length, get_str};
+use batson::{batson_to_json_string, compare_json_values, decode_to_json_value, encode_from_json};
+
+#[test]
+fn round_trip_all() {
+    let v: JsonValue<'static> = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![
+        // primitives
+        ("null".into(), JsonValue::Null),
+        ("false".into(), JsonValue::Bool(false)),
+        ("true".into(), JsonValue::Bool(true)),
+        // int
+        ("int-zero".into(), JsonValue::Int(0)),
+        ("int-in-header".into(), JsonValue::Int(9)),
+        ("int-8".into(), JsonValue::Int(123)),
+        ("int-32".into(), JsonValue::Int(1_000)),
+        ("int-64".into(), JsonValue::Int(i64::from(i32::MAX) + 5)),
+        ("int-max".into(), JsonValue::Int(i64::MAX)),
+        ("int-neg-in-header".into(), JsonValue::Int(-9)),
+        ("int-neg-8".into(), JsonValue::Int(-123)),
+        ("int-neg-32".into(), JsonValue::Int(-1_000)),
+        ("int-gex-64".into(), JsonValue::Int(-(i64::from(i32::MAX) + 5))),
+        ("int-min".into(), JsonValue::Int(i64::MIN)),
+        // floats
+        ("float-zero".into(), JsonValue::Float(0.0)),
+        ("float-in-header".into(), JsonValue::Float(9.0)),
+        ("float-pos".into(), JsonValue::Float(123.45)),
+        ("float-pos2".into(), JsonValue::Float(123_456_789.0)),
+        ("float-neg".into(), JsonValue::Float(-123.45)),
+        ("float-neg2".into(), JsonValue::Float(-123_456_789.0)),
+        // strings
+        ("str-empty".into(), JsonValue::Str("".into())),
+        ("str-short".into(), JsonValue::Str("foo".into())),
+        ("str-larger".into(), JsonValue::Str("foo bat spam".into())),
+        // het array
+        (
+            "het-array".into(),
+            JsonValue::Array(Arc::new(smallvec![
+                JsonValue::Int(42),
+                JsonValue::Str("foobar".into()),
+                JsonValue::Bool(true),
+            ])),
+        ),
+        // header array
+        (
+            "header-array".into(),
+            JsonValue::Array(Arc::new(smallvec![JsonValue::Int(6), JsonValue::Bool(true),])),
+        ),
+        // i64 array
+        (
+            "i64-array".into(),
+            JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(i64::MAX),])),
+        ),
+        // u8 array
+        (
+            "u8-array".into(),
+            JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(255),])),
+        ),
+    ])));
+    let b = encode_from_json(&v).unwrap();
+
+    let v2 = decode_to_json_value(&b).unwrap();
+    assert!(compare_json_values(&v2, &v));
+}
+
+fn json_to_batson(json: &[u8]) -> Vec<u8> {
+    let json_value = JsonValue::parse(json, false).unwrap();
+    encode_from_json(&json_value).unwrap()
+}
+
+#[test]
+fn test_get_bool() {
+    let bytes = json_to_batson(br#"{"foo": true}"#);
+
+    assert!(get_bool(&bytes, &["foo".into()]).unwrap().unwrap());
+    assert!(get_bool(&bytes, &["bar".into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_contains() {
+    let bytes = json_to_batson(br#"{"foo": true, "bar": [1, 2], "ham": "foo"}"#);
+
+    assert!(contains(&bytes, &["foo".into()]).unwrap());
+    assert!(contains(&bytes, &["bar".into()]).unwrap());
+    assert!(contains(&bytes, &["ham".into()]).unwrap());
+    assert!(contains(&bytes, &["bar".into(), 0.into()]).unwrap());
+    assert!(contains(&bytes, &["bar".into(), 1.into()]).unwrap());
+
+    assert!(!contains(&bytes, &["spam".into()]).unwrap());
+    assert!(!contains(&bytes, &["bar".into(), 2.into()]).unwrap());
+    assert!(!contains(&bytes, &["ham".into(), 0.into()]).unwrap());
+}
+
+#[test]
+fn test_get_str_object() {
+    let bytes = json_to_batson(br#"{"foo": "bar", "spam": true}"#);
+
+    assert_eq!(get_str(&bytes, &["foo".into()]).unwrap().unwrap(), "bar");
+    assert!(get_str(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_str(&bytes, &["spam".into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_str_array() {
+    let bytes = json_to_batson(br#"["foo", 123, "bar"]"#);
+
+    assert_eq!(get_str(&bytes, &[0.into()]).unwrap().unwrap(), "foo");
+    assert_eq!(get_str(&bytes, &[2.into()]).unwrap().unwrap(), "bar");
+
+    assert!(get_str(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_str(&bytes, &[3.into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_str_nested() {
+    let bytes = json_to_batson(br#"{"foo": [null, {"bar": "baz"}]}"#);
+
+    assert_eq!(
+        get_str(&bytes, &["foo".into(), 1.into(), "bar".into()])
+            .unwrap()
+            .unwrap(),
+        "baz"
+    );
+
+    assert!(get_str(&bytes, &["foo".into()]).unwrap().is_none());
+    assert!(get_str(&bytes, &["spam".into(), 1.into()]).unwrap().is_none());
+    assert!(get_str(&bytes, &["spam".into(), 1.into(), "bar".into(), 6.into()])
+        .unwrap()
+        .is_none());
+}
+
+#[test]
+fn test_get_int_object() {
+    let bytes = json_to_batson(br#"{"foo": 42, "spam": true}"#);
+
+    assert_eq!(get_int(&bytes, &["foo".into()]).unwrap().unwrap(), 42);
+    assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_int(&bytes, &["spam".into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_int_het_array() {
+    let bytes = json_to_batson(br#"[-42, "foo", 922337203685477580]"#);
+
+    assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), -42);
+    assert_eq!(get_int(&bytes, &[2.into()]).unwrap().unwrap(), 922_337_203_685_477_580);
+
+    assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_int(&bytes, &[3.into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_int_u8_array() {
+    let bytes = json_to_batson(br"[42, 123]");
+
+    assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), 42);
+    assert_eq!(get_int(&bytes, &[1.into()]).unwrap().unwrap(), 123);
+
+    assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_int(&bytes, &[2.into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_int_i64_array() {
+    let bytes = json_to_batson(br"[-123, 922337203685477580]");
+
+    assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), -123);
+    assert_eq!(get_int(&bytes, &[1.into()]).unwrap().unwrap(), 922_337_203_685_477_580);
+
+    assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none());
+    assert!(get_int(&bytes, &[2.into()]).unwrap().is_none());
+}
+
+#[test]
+fn test_get_length() {
+    let bytes = json_to_batson(br#"{"foo": [null, {"a": 1, "b": 2}, 1]}"#);
+
+    assert_eq!(get_length(&bytes, &[]).unwrap().unwrap(), 1);
+    assert_eq!(get_length(&bytes, &["foo".into()]).unwrap().unwrap(), 3);
+    assert_eq!(get_length(&bytes, &["foo".into(), 1.into()]).unwrap().unwrap(), 2);
+}
+
+#[test]
+fn test_to_json() {
+    let bytes = json_to_batson(br" [true, 123] ");
+    let s = batson_to_json_string(&bytes).unwrap();
+    assert_eq!(s, r"[true,123]");
+}
+
+fn json_round_trip(input_json: &str) {
+    let bytes = json_to_batson(input_json.as_bytes());
+    let output_json = batson_to_json_string(&bytes).unwrap();
+    assert_eq!(&output_json, input_json);
+}
+
+macro_rules! json_round_trip_tests {
+    ($($name:ident => $json:literal;)*) => {
+        $(
+            paste::item! {
+                #[test]
+                fn [< json_round_trip_ $name >]() {
+                    json_round_trip($json);
+                }
+            }
+        )*
+    }
+}
+
+json_round_trip_tests!(
+    array_empty => r"[]";
+    array_bool => r"[true,false]";
+    array_bool_int => r"[true,123]";
+    array_u8 => r"[1,2,44,255]";
+    array_i64 => r"[-1,2,44,255,1234]";
+    array_header => r#"[6,true,false,null,0,[],{},""]"#;
+    array_het => r#"[true,123,"foo",null]"#;
+    string_empty => r#""""#;
+    string_hello => r#""hello""#;
+    string_escape => r#""\"he\nllo\"""#;
+    string_unicode => r#"{"£":"🤪"}"#;
+    object_empty => r#"{}"#;
+    object_bool => r#"{"foo":true}"#;
+    object_two => r#"{"foo":1,"bar":2}"#;
+    object_three => r#"{"foo":1,"bar":2,"baz":3}"#;
+    object_int => r#"{"foo":123}"#;
+    object_string => r#"{"foo":"bar"}"#;
+    object_array => r#"{"foo":[1,2]}"#;
+    object_nested => r#"{"foo":{"bar":true}}"#;
+    object_nested_array => r#"{"foo":{"bar":[1,2]}}"#;
+    object_nested_array_nested => r#"{"foo":{"bar":[{"baz":true}]}}"#;
+    float_zero => r#"0.0"#;
+    float_neg => r#"-123.45"#;
+    float_pos => r#"123.456789"#;
+);
diff --git a/crates/jiter-python/Cargo.toml b/crates/jiter-python/Cargo.toml
index 86efab4..eef8b1f 100644
--- a/crates/jiter-python/Cargo.toml
+++ b/crates/jiter-python/Cargo.toml
@@ -11,7 +11,7 @@ repository = {workspace = true}
 
 [dependencies]
 pyo3 = { workspace = true, features = ["num-bigint"] }
-jiter = { path = "../jiter", features = ["python"] }
+jiter = { workspace = true, features = ["python"] }
 
 [features]
 # must be enabled when building with `cargo build`, maturin enables this automatically
diff --git a/crates/jiter/Cargo.toml b/crates/jiter/Cargo.toml
index ddcef5b..13fa2d9 100644
--- a/crates/jiter/Cargo.toml
+++ b/crates/jiter/Cargo.toml
@@ -24,12 +24,12 @@ bitvec = "1.0.1"
 python = ["dep:pyo3", "dep:pyo3-build-config"]
 
 [dev-dependencies]
-bencher = "0.1.5"
-paste = "1.0.7"
+bencher = { workspace = true }
+paste = { workspace = true }
+codspeed-bencher-compat = { workspace = true }
 serde_json = {version = "1.0.87", features = ["preserve_order", "arbitrary_precision", "float_roundtrip"]}
 serde = "1.0.147"
 pyo3 = { workspace = true, features = ["auto-initialize"] }
-codspeed-bencher-compat = "2.7.1"
 
 [build-dependencies]
 pyo3-build-config = { workspace = true, optional = true }
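For reference, the end-to-end flow implied by the tests above: parse JSON with jiter, encode to batson, query a single path, and render back to compact JSON. A minimal sketch; error handling is collapsed to `unwrap`, and path elements are built with `.into()` exactly as in `tests/main.rs`.

```rust
use batson::get::get_str;
use batson::{batson_to_json_string, encode_from_json};
use jiter::JsonValue;

fn main() {
    let json = br#"{"foo": "bar", "spam": [1, 2, 3]}"#;
    // parse with jiter, then encode to the batson binary format
    let value = JsonValue::parse(json, false).unwrap();
    let bytes = encode_from_json(&value).unwrap();

    // query a single path without materialising the whole document
    assert_eq!(get_str(&bytes, &["foo".into()]).unwrap().unwrap(), "bar");

    // round-trip back to compact JSON
    println!("{}", batson_to_json_string(&bytes).unwrap());
}
```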