diff --git a/trustfall_core/src/lib.rs b/trustfall_core/src/lib.rs index 0d3107fe..d18c3fe2 100644 --- a/trustfall_core/src/lib.rs +++ b/trustfall_core/src/lib.rs @@ -14,8 +14,11 @@ pub mod graphql_query; pub mod interpreter; pub mod ir; pub mod schema; +mod serialization; mod util; +pub use serialization::TryIntoStruct; + #[cfg(test)] mod numbers_interpreter; diff --git a/trustfall_core/src/main.rs b/trustfall_core/src/main.rs index 370af1c8..f366f6a5 100644 --- a/trustfall_core/src/main.rs +++ b/trustfall_core/src/main.rs @@ -16,6 +16,7 @@ mod ir; mod nullables_interpreter; mod numbers_interpreter; mod schema; +mod serialization; mod util; use std::{ diff --git a/trustfall_core/src/serialization/deserializers.rs b/trustfall_core/src/serialization/deserializers.rs new file mode 100644 index 00000000..54ccdfb7 --- /dev/null +++ b/trustfall_core/src/serialization/deserializers.rs @@ -0,0 +1,267 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use serde::de::{self, IntoDeserializer}; + +use crate::ir::FieldValue; + +#[derive(Debug, Clone)] +pub(super) struct QueryResultDeserializer { + query_result: BTreeMap, FieldValue>, +} + +impl QueryResultDeserializer { + pub(super) fn new(query_result: BTreeMap, FieldValue>) -> Self { + Self { query_result } + } +} + +#[derive(Debug, Clone)] +struct QueryResultMapDeserializer, FieldValue)>> { + iter: I, + next_value: Option, +} + +impl, FieldValue)>> QueryResultMapDeserializer { + fn new(iter: I) -> Self { + Self { + iter, + next_value: Default::default(), + } + } +} + +#[derive(Debug, Clone, thiserror::Error)] +pub enum Error { + #[error("error from deserialize: {0}")] + Custom(String), +} + +impl de::Error for Error { + fn custom(msg: T) -> Self + where + T: std::fmt::Display, + { + Self::Custom(msg.to_string()) + } +} + +impl<'de> de::Deserializer<'de> for QueryResultDeserializer { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_map(QueryResultMapDeserializer::new( + self.query_result.into_iter(), + )) + } + + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } +} + +impl<'de, I: Iterator, FieldValue)>> de::MapAccess<'de> + for QueryResultMapDeserializer +{ + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: de::DeserializeSeed<'de>, + { + self.iter + .next() + .map(|(key, value)| { + self.next_value = Some(value); + seed.deserialize(key.into_deserializer()) + }) + .transpose() + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + seed.deserialize( + self.next_value + .take() + .expect("called next_value_seed out of order") + .into_deserializer(), + ) + } +} + +pub struct FieldValueDeserializer { + value: FieldValue, +} + +impl<'de> de::IntoDeserializer<'de, Error> for FieldValue { + type Deserializer = FieldValueDeserializer; + + fn into_deserializer(self) -> Self::Deserializer { + FieldValueDeserializer { value: self } + } +} + +impl<'de> de::Deserializer<'de> for FieldValueDeserializer { + type Error = Error; + + fn deserialize_i8(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_i8(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_i8(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_i16(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_i16(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_i32(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_i32(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_u8(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_u8(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_u16(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_u16(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Int64(v) => { + visitor.visit_u32(v.try_into().map_err(::custom)?) + } + FieldValue::Uint64(v) => { + visitor.visit_u32(v.try_into().map_err(::custom)?) + } + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Float64(v) => visitor.visit_f32(v as f32), + _ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error + } + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + if let FieldValue::List(v) = &self.value { + if len != v.len() { + return Err(Self::Error::Custom(format!( + "cannot deserialize {} length list into {len} sized tuple", + v.len() + ))); + } + } + self.deserialize_any(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match &self.value { + &FieldValue::Null => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_none() + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value { + FieldValue::Null => visitor.visit_none(), + FieldValue::Int64(v) => visitor.visit_i64(v), + FieldValue::Uint64(v) => visitor.visit_u64(v), + FieldValue::Float64(v) => visitor.visit_f64(v), + FieldValue::String(v) => visitor.visit_string(v), + FieldValue::Boolean(v) => visitor.visit_bool(v), + FieldValue::DateTimeUtc(_) => todo!(), + FieldValue::Enum(_) => todo!(), + FieldValue::List(v) => visitor.visit_seq(v.into_deserializer()), + } + } + + serde::forward_to_deserialize_any! { + bool i64 i128 u64 u128 f64 char str string seq + bytes byte_buf unit unit_struct newtype_struct + tuple_struct map enum struct identifier + } +} diff --git a/trustfall_core/src/serialization/mod.rs b/trustfall_core/src/serialization/mod.rs new file mode 100644 index 00000000..db3c3f8d --- /dev/null +++ b/trustfall_core/src/serialization/mod.rs @@ -0,0 +1,64 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use serde::de; + +use crate::ir::FieldValue; + +mod deserializers; + +#[cfg(test)] +mod tests; + +/// Deserialize Trustfall query results into a Rust struct. +/// +/// ```rust +/// # use std::{collections::BTreeMap, sync::Arc}; +/// # use maplit::btreemap; +/// # use trustfall_core::ir::FieldValue; +/// # +/// # fn run_query() -> Result, FieldValue>>>, ()> { +/// # Ok(Box::new(vec![ +/// # btreemap! { +/// # Arc::from("number") => FieldValue::Int64(42), +/// # Arc::from("text") => FieldValue::String("the answer to everything".to_string()), +/// # } +/// # ].into_iter())) +/// # } +/// +/// use trustfall_core::TryIntoStruct; +/// +/// #[derive(Debug, PartialEq, Eq, serde::Deserialize)] +/// struct Output { +/// number: i64, +/// text: String, +/// } +/// +/// let results: Vec<_> = run_query() +/// .expect("bad query arguments") +/// .map(|v| v.try_into_struct().expect("struct definition did not match query result shape")) +/// .collect(); +/// +/// assert_eq!( +/// vec![ +/// Output { +/// number: 42, +/// text: "the answer to everything".to_string(), +/// }, +/// ], +/// results, +/// ); +/// ``` +pub trait TryIntoStruct { + type Error; + + fn try_into_struct de::Deserialize<'de>>(self) -> Result; +} + +impl TryIntoStruct for BTreeMap, FieldValue> { + type Error = deserializers::Error; + + fn try_into_struct de::Deserialize<'de>>(self) -> Result { + let deserializer = deserializers::QueryResultDeserializer::new(self); + S::deserialize(deserializer) + } +} diff --git a/trustfall_core/src/serialization/tests.rs b/trustfall_core/src/serialization/tests.rs new file mode 100644 index 00000000..337b6056 --- /dev/null +++ b/trustfall_core/src/serialization/tests.rs @@ -0,0 +1,236 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use serde::Deserialize; + +use super::TryIntoStruct; +use crate::ir::FieldValue; + +#[test] +fn deserialize_simple() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + foo: i64, + bar: String, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("foo") => FieldValue::Int64(42), + Arc::from("bar") => FieldValue::String("the answer to everything".to_string()), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + foo: 42, + bar: "the answer to everything".to_string(), + }, + output_value + ); +} + +#[test] +fn deserialize_list() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + foo: Vec>, + bar: Vec, + } + + let vec_int = vec![vec![1, 2], vec![], vec![3, 4]]; + let vec_str: Vec = vec!["one".into(), "".into(), "two".into(), "three".into()]; + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("foo") => FieldValue::List(vec_int.clone().into_iter().map(|x| FieldValue::List(x.into_iter().map(Into::into).collect())).collect()), + Arc::from("bar") => FieldValue::List(vec_str.clone().into_iter().map(Into::into).collect()), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + foo: vec_int, + bar: vec_str, + }, + output_value + ); +} + +#[test] +fn deserialize_option() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + foo: Vec>, + bar: Option, + } + + let vec_int = vec![Some(1), None, Some(2), Some(3)]; + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("foo") => FieldValue::List(vec_int.clone().into_iter().map(Into::into).collect()), + Arc::from("bar") => FieldValue::Null, + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + foo: vec_int, + bar: None, + }, + output_value + ); +} + +#[test] +fn deserialize_extra_keys_in_query_result() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + foo: i64, + bar: String, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("foo") => FieldValue::Int64(42), + Arc::from("bar") => FieldValue::String("the answer to everything".to_string()), + Arc::from("extra") => FieldValue::Null, + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + foo: 42, + bar: "the answer to everything".to_string(), + }, + output_value + ); +} + +#[test] +fn deserialize_serde_rename() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + #[serde(rename = "renamed_foo")] + foo: i64, + + #[serde(alias = "renamed_bar")] + bar: String, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("renamed_foo") => FieldValue::Int64(42), + Arc::from("renamed_bar") => FieldValue::String("the answer to everything".to_string()), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + foo: 42, + bar: "the answer to everything".to_string(), + }, + output_value + ); +} + +#[test] +fn deserialize_narrower_types() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + i32: i32, + i16: i16, + i8: i8, + u32: u32, + u16: u16, + u8: u8, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("i32") => FieldValue::Int64(-32), + Arc::from("i16") => FieldValue::Int64(-16), + Arc::from("i8") => FieldValue::Int64(8), + Arc::from("u32") => FieldValue::Uint64(32), + Arc::from("u16") => FieldValue::Uint64(16), + Arc::from("u8") => FieldValue::Uint64(8), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + i32: -32, + i16: -16, + i8: 8, + u32: 32, + u16: 16, + u8: 8, + }, + output_value + ); +} + +#[test] +fn deserialize_narrower_type_f32() { + #[derive(Debug, Deserialize, PartialEq)] + struct Output { + f32: f32, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("f32") => FieldValue::Float64(1.234), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!(Output { f32: 1.234f32 }, output_value); +} + +#[test] +fn deserialize_adjust_numeric_type_signedness() { + #[derive(Debug, Deserialize, PartialEq, Eq)] + struct Output { + i64: i64, + i32: i32, + i16: i16, + i8: i8, + u64: u64, + u32: u32, + u16: u16, + u8: u8, + } + + let value: BTreeMap, FieldValue> = btreemap! { + Arc::from("i64") => FieldValue::Uint64(i64::MAX as u64), + Arc::from("i32") => FieldValue::Uint64(i32::MAX as u64), + Arc::from("i16") => FieldValue::Uint64(i16::MAX as u64), + Arc::from("i8") => FieldValue::Uint64(i8::MAX as u64), + Arc::from("u64") => FieldValue::Int64(i64::MAX), + Arc::from("u32") => FieldValue::Int64(i32::MAX.into()), + Arc::from("u16") => FieldValue::Int64(i16::MAX.into()), + Arc::from("u8") => FieldValue::Int64(i8::MAX.into()), + }; + + let output_value = value + .try_into_struct::() + .expect("failed to create struct"); + assert_eq!( + Output { + i64: i64::MAX, + i32: i32::MAX, + i16: i16::MAX, + i8: i8::MAX, + u64: i64::MAX as u64, + u32: i32::MAX as u32, + u16: i16::MAX as u16, + u8: i8::MAX as u8, + }, + output_value + ); +}