Skip to content

Commit

Permalink
Make our sats<->serde translation compatible with RON
Browse files Browse the repository at this point in the history
  • Loading branch information
coolreader18 committed Sep 24, 2024
1 parent 9f62399 commit 19bbfcc
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 104 deletions.
176 changes: 77 additions & 99 deletions crates/sats/src/de/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use serde::de as serde;

/// Converts any [`serde::Deserializer`] to a SATS [`Deserializer`]
/// so that Serde's data formats can be reused.
///
/// In order for successful round-trip deserialization, the `serde::Deserializer`
/// that this type wraps must support `deserialize_any()`.
pub struct SerdeDeserializer<D> {
/// A deserialization data format in Serde.
de: D,
Expand Down Expand Up @@ -46,19 +49,11 @@ impl<'de, D: serde::Deserializer<'de>> Deserializer<'de> for SerdeDeserializer<D
type Error = SerdeError<D::Error>;

fn deserialize_product<V: super::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
self.de
.deserialize_struct("", &[], TupleVisitor { visitor })
.map_err(SerdeError)
self.de.deserialize_any(TupleVisitor { visitor }).map_err(SerdeError)
}

fn deserialize_sum<V: super::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
if visitor.is_option() && self.de.is_human_readable() {
self.de.deserialize_any(OptionVisitor { visitor }).map_err(SerdeError)
} else {
self.de
.deserialize_enum("", &[], EnumVisitor { visitor })
.map_err(SerdeError)
}
self.de.deserialize_any(EnumVisitor { visitor }).map_err(SerdeError)
}

fn deserialize_bool(self) -> Result<bool, Self::Error> {
Expand Down Expand Up @@ -267,71 +262,6 @@ impl<'de, A: serde::SeqAccess<'de>> super::SeqProductAccess<'de> for SeqTupleAcc
}
}

/// Converts a `SumVisitor` into a `serde::Visitor` for deserializing option.
struct OptionVisitor<V> {
/// The visitor to convert.
visitor: V,
}

impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for OptionVisitor<V> {
type Value = V::Output;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("option")
}

fn visit_map<A: serde::MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
self.visitor.visit_sum(SomeAccess(map)).map_err(unwrap_error)
}

fn visit_unit<E: serde::Error>(self) -> Result<Self::Value, E> {
self.visitor.visit_sum(NoneAccess(PhantomData)).map_err(unwrap_error)
}
}

/// Deserializes `some` variant of an optional value.
/// Converts Serde's map deserialization to SATS.
struct SomeAccess<A>(A);

impl<'de, A: serde::MapAccess<'de>> super::SumAccess<'de> for SomeAccess<A> {
type Error = SerdeError<A::Error>;
type Variant = Self;

fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
self.0
.next_key_seed(VariantVisitor { visitor })
.and_then(|x| match x {
Some(x) => Ok((x, self)),
None => Err(serde::Error::custom("expected variant name")),
})
.map_err(SerdeError)
}
}
impl<'de, A: serde::MapAccess<'de>> super::VariantAccess<'de> for SomeAccess<A> {
type Error = SerdeError<A::Error>;

fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
let ret = self.0.next_value_seed(SeedWrapper(seed)).map_err(SerdeError)?;
self.0.next_key_seed(NothingVisitor).map_err(SerdeError)?;
Ok(ret)
}
}

/// Deserializes nothing, producing `!` effectively.
struct NothingVisitor;
impl<'de> serde::DeserializeSeed<'de> for NothingVisitor {
type Value = std::convert::Infallible;
fn deserialize<D: serde::Deserializer<'de>>(self, deserializer: D) -> Result<Self::Value, D::Error> {
deserializer.deserialize_identifier(self)
}
}
impl serde::Visitor<'_> for NothingVisitor {
type Value = std::convert::Infallible;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("nothing")
}
}

/// Deserializes `none` variant of an optional value.
struct NoneAccess<E>(PhantomData<E>);
impl<E: super::Error> super::SumAccess<'_> for NoneAccess<E> {
Expand Down Expand Up @@ -364,29 +294,32 @@ impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for EnumVisitor<V> {
type Value = V::Output;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("enum")
match self.visitor.sum_name() {
Some(name) => write!(f, "sum type {name}"),
None => f.write_str("sum type"),
}
}

fn visit_enum<A: serde::EnumAccess<'de>>(self, access: A) -> Result<Self::Value, A::Error> {
fn visit_map<A>(self, access: A) -> Result<Self::Value, A::Error>
where
A: serde::MapAccess<'de>,
{
self.visitor.visit_sum(EnumAccess { access }).map_err(unwrap_error)
}
}

/// Converts Serde's `EnumAccess` to SATS `SumAccess`.
struct EnumAccess<A> {
/// The Serde `EnumAccess`.
access: A,
}

impl<'de, A: serde::EnumAccess<'de>> super::SumAccess<'de> for EnumAccess<A> {
type Error = SerdeError<A::Error>;
type Variant = VariantAccess<A::Variant>;
fn visit_seq<A>(self, access: A) -> Result<Self::Value, A::Error>
where
A: serde::SeqAccess<'de>,
{
self.visitor.visit_sum(SeqEnumAccess { access }).map_err(unwrap_error)
}

fn variant<V: super::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
self.access
.variant_seed(VariantVisitor { visitor })
.map(|(variant, access)| (variant, VariantAccess { access }))
.map_err(SerdeError)
fn visit_unit<E: serde::Error>(self) -> Result<Self::Value, E> {
if self.visitor.is_option() {
self.visitor.visit_sum(NoneAccess(PhantomData)).map_err(unwrap_error)
} else {
Err(E::invalid_type(serde::Unexpected::Unit, &self))
}
}
}

Expand All @@ -400,7 +333,7 @@ impl<'de, V: super::VariantVisitor> serde::DeserializeSeed<'de> for VariantVisit
type Value = V::Output;

fn deserialize<D: serde::Deserializer<'de>>(self, deserializer: D) -> Result<Self::Value, D::Error> {
deserializer.deserialize_identifier(self)
deserializer.deserialize_any(self)
}
}

Expand Down Expand Up @@ -430,17 +363,62 @@ impl<V: super::VariantVisitor> serde::Visitor<'_> for VariantVisitor<V> {
}
}

/// Deserializes the data of a variant using Serde's `serde::VariantAccess` translating this to SATS.
struct VariantAccess<A> {
// Implements `serde::VariantAccess`.
/// Converts Serde's `EnumAccess` to SATS `SumAccess`.
struct EnumAccess<A> {
/// The Serde `EnumAccess`.
access: A,
}

impl<'de, A: serde::VariantAccess<'de>> super::VariantAccess<'de> for VariantAccess<A> {
impl<'de, A: serde::MapAccess<'de>> super::SumAccess<'de> for EnumAccess<A> {
type Error = SerdeError<A::Error>;
type Variant = Self;

fn deserialize_seed<T: super::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
self.access.newtype_variant_seed(SeedWrapper(seed)).map_err(SerdeError)
fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
let errmsg = "expected map representing sum type to have exactly one field";
let key = self
.access
.next_key_seed(VariantVisitor { visitor })
.map_err(SerdeError)?
.ok_or_else(|| SerdeError(serde::Error::custom(errmsg)))?;
Ok((key, self))
}
}

impl<'de, A: serde::MapAccess<'de>> super::VariantAccess<'de> for EnumAccess<A> {
type Error = SerdeError<A::Error>;

fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
self.access.next_value_seed(SeedWrapper(seed)).map_err(SerdeError)
}
}

struct SeqEnumAccess<A> {
access: A,
}

const SEQ_ENUM_ERR: &str = "expected seq representing sum type to have exactly two fields";
impl<'de, A: serde::SeqAccess<'de>> super::SumAccess<'de> for SeqEnumAccess<A> {
type Error = SerdeError<A::Error>;
type Variant = Self;

fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
let key = self
.access
.next_element_seed(VariantVisitor { visitor })
.map_err(SerdeError)?
.ok_or_else(|| SerdeError(serde::Error::custom(SEQ_ENUM_ERR)))?;
Ok((key, self))
}
}

impl<'de, A: serde::SeqAccess<'de>> super::VariantAccess<'de> for SeqEnumAccess<A> {
type Error = SerdeError<A::Error>;

fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
self.access
.next_element_seed(SeedWrapper(seed))
.map_err(SerdeError)?
.ok_or_else(|| SerdeError(serde::Error::custom(SEQ_ENUM_ERR)))
}
}

Expand Down
12 changes: 7 additions & 5 deletions crates/sats/src/ser/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,18 @@ impl<S: serde::Serializer> Serializer for SerdeSerializer<S> {
value: &T,
) -> Result<Self::Ok, Self::Error> {
// can't use serialize_variant cause we're too dynamic :(
use serde::SerializeMap;
let mut map = self.ser.serialize_map(Some(1)).map_err(SerdeError)?;
use serde::{SerializeMap, SerializeTuple};
let value = SerializeWrapper::from_ref(value);
if let Some(name) = name {
let mut map = self.ser.serialize_map(Some(1)).map_err(SerdeError)?;
map.serialize_entry(name, value).map_err(SerdeError)?;
map.end().map_err(SerdeError)
} else {
// FIXME: this probably wouldn't decode if you ran it back through
map.serialize_entry(&tag, value).map_err(SerdeError)?;
let mut seq = self.ser.serialize_tuple(2).map_err(SerdeError)?;
seq.serialize_element(&tag).map_err(SerdeError)?;
seq.serialize_element(value).map_err(SerdeError)?;
seq.end().map_err(SerdeError)
}
map.end().map_err(SerdeError)
}

unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
Expand Down

0 comments on commit 19bbfcc

Please sign in to comment.