Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TryIntoStruct trait for ergonomic result parsing into a struct. #275

Merged
merged 2 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions trustfall/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ pub use trustfall_core::ir::{FieldValue, TransparentValue};
/// Trustfall query schema.
pub use trustfall_core::schema::Schema;

// Trait for converting query results into structs.
pub use trustfall_core::TryIntoStruct;

/// Run a Trustfall query over the data provider specified by the given schema and adapter.
pub fn execute_query<'vertex>(
schema: &Schema,
Expand Down
3 changes: 3 additions & 0 deletions trustfall_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ pub mod graphql_query;
pub mod interpreter;
pub mod ir;
pub mod schema;
mod serialization;
mod util;

pub use serialization::TryIntoStruct;

#[cfg(test)]
mod numbers_interpreter;

Expand Down
1 change: 1 addition & 0 deletions trustfall_core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod ir;
mod nullables_interpreter;
mod numbers_interpreter;
mod schema;
mod serialization;
mod util;

use std::{
Expand Down
267 changes: 267 additions & 0 deletions trustfall_core/src/serialization/deserializers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
use std::{collections::BTreeMap, sync::Arc};

use serde::de::{self, IntoDeserializer};

use crate::ir::FieldValue;

#[derive(Debug, Clone)]
pub(super) struct QueryResultDeserializer {
query_result: BTreeMap<Arc<str>, FieldValue>,
}

impl QueryResultDeserializer {
pub(super) fn new(query_result: BTreeMap<Arc<str>, FieldValue>) -> Self {
Self { query_result }
}
}

#[derive(Debug, Clone)]
struct QueryResultMapDeserializer<I: Iterator<Item = (Arc<str>, FieldValue)>> {
iter: I,
next_value: Option<FieldValue>,
}

impl<I: Iterator<Item = (Arc<str>, FieldValue)>> QueryResultMapDeserializer<I> {
fn new(iter: I) -> Self {
Self {
iter,
next_value: Default::default(),
}
}
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
#[error("error from deserialize: {0}")]
Custom(String),
}

impl de::Error for Error {
fn custom<T>(msg: T) -> Self
where
T: std::fmt::Display,
{
Self::Custom(msg.to_string())
}
}

impl<'de> de::Deserializer<'de> for QueryResultDeserializer {
type Error = Error;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_map(QueryResultMapDeserializer::new(
self.query_result.into_iter(),
))
}

serde::forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map struct enum identifier ignored_any
}
}

impl<'de, I: Iterator<Item = (Arc<str>, FieldValue)>> de::MapAccess<'de>
for QueryResultMapDeserializer<I>
{
type Error = Error;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: de::DeserializeSeed<'de>,
{
self.iter
.next()
.map(|(key, value)| {
self.next_value = Some(value);
seed.deserialize(key.into_deserializer())
})
.transpose()
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: de::DeserializeSeed<'de>,
{
seed.deserialize(
self.next_value
.take()
.expect("called next_value_seed out of order")
.into_deserializer(),
)
}
}

pub struct FieldValueDeserializer {
value: FieldValue,
}

impl<'de> de::IntoDeserializer<'de, Error> for FieldValue {
type Deserializer = FieldValueDeserializer;

fn into_deserializer(self) -> Self::Deserializer {
FieldValueDeserializer { value: self }
}
}

impl<'de> de::Deserializer<'de> for FieldValueDeserializer {
type Error = Error;

fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_i32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_i32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u8(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u16(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Int64(v) => {
visitor.visit_u32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
FieldValue::Uint64(v) => {
visitor.visit_u32(v.try_into().map_err(<Self::Error as de::Error>::custom)?)
}
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Float64(v) => visitor.visit_f32(v as f32),
_ => self.deserialize_any(visitor), // we'll let `deserialize_any()` raise the error
}
}

fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
if let FieldValue::List(v) = &self.value {
if len != v.len() {
return Err(Self::Error::Custom(format!(
"cannot deserialize {} length list into {len} sized tuple",
v.len()
)));
}
}
self.deserialize_any(visitor)
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match &self.value {
&FieldValue::Null => visitor.visit_none(),
_ => visitor.visit_some(self),
}
}

fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_none()
}

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
match self.value {
FieldValue::Null => visitor.visit_none(),
FieldValue::Int64(v) => visitor.visit_i64(v),
FieldValue::Uint64(v) => visitor.visit_u64(v),
FieldValue::Float64(v) => visitor.visit_f64(v),
FieldValue::String(v) => visitor.visit_string(v),
FieldValue::Boolean(v) => visitor.visit_bool(v),
FieldValue::DateTimeUtc(_) => todo!(),
FieldValue::Enum(_) => todo!(),
FieldValue::List(v) => visitor.visit_seq(v.into_deserializer()),
}
}

serde::forward_to_deserialize_any! {
bool i64 i128 u64 u128 f64 char str string seq
bytes byte_buf unit unit_struct newtype_struct
tuple_struct map enum struct identifier
}
}
64 changes: 64 additions & 0 deletions trustfall_core/src/serialization/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::{collections::BTreeMap, sync::Arc};

use serde::de;

use crate::ir::FieldValue;

mod deserializers;

#[cfg(test)]
mod tests;

/// Deserialize Trustfall query results into a Rust struct.
///
/// ```rust
/// # use std::{collections::BTreeMap, sync::Arc};
/// # use maplit::btreemap;
/// # use trustfall_core::ir::FieldValue;
/// #
/// # fn run_query() -> Result<Box<dyn Iterator<Item = BTreeMap<Arc<str>, FieldValue>>>, ()> {
/// # Ok(Box::new(vec![
/// # btreemap! {
/// # Arc::from("number") => FieldValue::Int64(42),
/// # Arc::from("text") => FieldValue::String("the answer to everything".to_string()),
/// # }
/// # ].into_iter()))
/// # }
///
/// use trustfall_core::TryIntoStruct;
///
/// #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
/// struct Output {
/// number: i64,
/// text: String,
/// }
///
/// let results: Vec<_> = run_query()
/// .expect("bad query arguments")
/// .map(|v| v.try_into_struct().expect("struct definition did not match query result shape"))
/// .collect();
///
/// assert_eq!(
/// vec![
/// Output {
/// number: 42,
/// text: "the answer to everything".to_string(),
/// },
/// ],
/// results,
/// );
/// ```
pub trait TryIntoStruct {
type Error;

fn try_into_struct<S: for<'de> de::Deserialize<'de>>(self) -> Result<S, Self::Error>;
}

impl TryIntoStruct for BTreeMap<Arc<str>, FieldValue> {
type Error = deserializers::Error;

fn try_into_struct<S: for<'de> de::Deserialize<'de>>(self) -> Result<S, deserializers::Error> {
let deserializer = deserializers::QueryResultDeserializer::new(self);
S::deserialize(deserializer)
}
}
Loading