From 550df7fc98950e035c7191f39aec0ed8335973e2 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 29 Aug 2023 15:30:29 +0100 Subject: [PATCH] fix: Use internal tag for SumType enum serialisation (#462) This makes it easier to handle the serialised format with pydantic. BREAKING CHANGE: Turn `SumType.General` and `SumType.Simple` enum variants into struct variants --- src/types.rs | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/types.rs b/src/types.rs index d8500cb2a..24b8085f3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -94,31 +94,36 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty } #[derive(Clone, PartialEq, Debug, Eq, derive_more::Display, Serialize, Deserialize)] +#[serde(tag = "s")] /// Representation of a Sum type. /// Either store the types of the variants, or in the special (but common) case /// of a "simple predicate" (sum over empty tuples), store only the size of the predicate. enum SumType { - #[display(fmt = "SimplePredicate({})", "_0")] - Simple(u8), - General(TypeRow), + #[display(fmt = "SimplePredicate({})", "size")] + Simple { + size: u8, + }, + General { + row: TypeRow, + }, } impl SumType { fn new(types: impl Into) -> Self { let row: TypeRow = types.into(); - let len = row.len(); + let len: usize = row.len(); if len <= (u8::MAX as usize) && row.iter().all(|t| *t == Type::UNIT) { - Self::Simple(len as u8) + Self::Simple { size: len as u8 } } else { - Self::General(row) + Self::General { row } } } fn get_variant(&self, tag: usize) -> Option<&Type> { match self { - SumType::Simple(size) if tag < (*size as usize) => Some(Type::UNIT_REF), - SumType::General(row) => row.get(tag), + SumType::Simple { size } if tag < (*size as usize) => Some(Type::UNIT_REF), + SumType::General { row } => row.get(tag), _ => None, } } @@ -127,8 +132,8 @@ impl SumType { impl From for Type { fn from(sum: SumType) -> Type { match sum { - SumType::Simple(size) => Type::new_simple_predicate(size), - SumType::General(types) => Type::new_sum(types), + SumType::Simple { size } => Type::new_simple_predicate(size), + SumType::General { row } => Type::new_sum(row), } } } @@ -147,9 +152,9 @@ impl TypeEnum { fn least_upper_bound(&self) -> TypeBound { match self { TypeEnum::Prim(p) => p.bound(), - TypeEnum::Sum(SumType::Simple(_)) => TypeBound::Eq, - TypeEnum::Sum(SumType::General(ts)) => { - least_upper_bound(ts.iter().map(Type::least_upper_bound)) + TypeEnum::Sum(SumType::Simple { size: _ }) => TypeBound::Eq, + TypeEnum::Sum(SumType::General { row }) => { + least_upper_bound(row.iter().map(Type::least_upper_bound)) } TypeEnum::Tuple(ts) => least_upper_bound(ts.iter().map(Type::least_upper_bound)), } @@ -237,7 +242,7 @@ impl Type { /// New simple predicate with empty Tuple variants pub const fn new_simple_predicate(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::Simple(size)), TypeBound::Eq) + Self(TypeEnum::Sum(SumType::Simple { size }), TypeBound::Eq) } /// Report the least upper TypeBound, if there is one. @@ -305,7 +310,7 @@ pub(crate) mod test { assert_eq!(pred1, pred2); - let pred_direct = SumType::Simple(2); + let pred_direct = SumType::Simple { size: 2 }; assert_eq!(pred1, pred_direct.into()) } }