From d799f7f5f76c6fb3aafbccdf0f7c38d93580e86a Mon Sep 17 00:00:00 2001 From: doug-q <141026920+doug-q@users.noreply.github.com> Date: Thu, 30 May 2024 10:04:29 +0100 Subject: [PATCH] test: serialisation round trip testing for `OpDef` (#999) We also add coverage for roundtripping `CustomSerialized` which was omitted from the previous PR. --------- Co-authored-by: Alan Lawrence --- hugr-core/src/extension.rs | 5 +- hugr-core/src/extension/op_def.rs | 162 +++++++++++++++++- hugr-core/src/hugr/serialize/test.rs | 16 +- hugr-core/src/ops/constant.rs | 61 +++---- hugr-core/src/ops/constant/custom.rs | 40 +++++ hugr-core/src/ops/custom.rs | 23 +-- hugr-core/src/proptest.rs | 4 + .../std_extensions/arithmetic/int_types.rs | 58 +++++++ 8 files changed, 309 insertions(+), 60 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index c2e58b0a8..1ef9c50cd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -538,7 +538,9 @@ impl FromIterator for ExtensionSet { } #[cfg(test)] -mod test { +pub mod test { + // We re-export this here because mod op_def is private. + pub use super::op_def::test::SimpleOpDef; mod proptest { @@ -549,6 +551,7 @@ mod test { impl Arbitrary for ExtensionSet { type Parameters = (); type Strategy = BoxedStrategy; + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { ( hash_set(0..10usize, 0..3), diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d91aa6890..42edd56dc 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -474,12 +474,14 @@ impl Extension { } #[cfg(test)] -mod test { +pub(super) mod test { use std::num::NonZeroU64; + use itertools::Itertools; + use super::SignatureFromArgs; use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::op_def::LowerFunc; + use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; use crate::extension::prelude::USIZE_T; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; @@ -494,6 +496,87 @@ mod test { const EXT_ID: ExtensionId = "MyExt"; } + #[derive(serde::Serialize, serde::Deserialize, Debug)] + pub struct SimpleOpDef(OpDef); + + impl SimpleOpDef { + pub fn new(op_def: OpDef) -> Self { + assert!(op_def.constant_folder.is_none()); + assert!(matches!( + op_def.signature_func, + SignatureFunc::TypeScheme(_) + )); + assert!(op_def + .lower_funcs + .iter() + .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))); + Self(op_def) + } + } + + impl From for OpDef { + fn from(value: SimpleOpDef) -> Self { + value.0 + } + } + + impl PartialEq for SimpleOpDef { + fn eq(&self, other: &Self) -> bool { + let OpDef { + extension, + name, + description, + misc, + signature_func, + lower_funcs, + constant_folder: _, + } = &self.0; + let OpDef { + extension: other_extension, + name: other_name, + description: other_description, + misc: other_misc, + signature_func: other_signature_func, + lower_funcs: other_lower_funcs, + constant_folder: _, + } = &other.0; + + let get_sig = |sf: &_| match sf { + // if SignatureFunc or CustomValidator are changed we should get + // a compile error here. To fix: modify the fields matched on here, + // maintaining the lack of `..` and, for each part that is + // serializable, ensure we are checking it for equality below. + SignatureFunc::TypeScheme(CustomValidator { + poly_func, + validate: _, + }) => Some(poly_func.clone()), + // This is ruled out by `new()` but leave it here for later. + SignatureFunc::CustomFunc(_) => None, + }; + + let get_lower_funcs = |lfs: &Vec| { + lfs.iter() + .map(|lf| match lf { + // as with get_sig above, this should break if the hierarchy + // is changed, update similarly. + LowerFunc::FixedHugr { extensions, hugr } => { + Some((extensions.clone(), hugr.clone())) + } + // This is ruled out by `new()` but leave it here for later. + LowerFunc::CustomFunc(_) => None, + }) + .collect_vec() + }; + + extension == other_extension + && name == other_name + && description == other_description + && misc == other_misc + && get_sig(signature_func) == get_sig(other_signature_func) + && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs) + } + } + #[test] fn op_def_with_type_scheme() -> Result<(), Box> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); @@ -686,4 +769,79 @@ mod test { ); Ok(()) } + + mod proptest { + use super::SimpleOpDef; + use ::proptest::prelude::*; + + use crate::{ + builder::test::simple_dfg_hugr, + extension::{ + op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc, + }, + types::PolyFuncType, + }; + + impl Arbitrary for SignatureFunc { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + // TODO there is also SignatureFunc::CustomFunc, but for now + // this is not serialised. When it is, we should generate + // examples here . + any::() + .prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x))) + .boxed() + } + } + + impl Arbitrary for LowerFunc { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + // TODO There is also LowerFunc::CustomFunc, but for now this is + // not serialised. When it is, we should generate examples here. + any::() + .prop_map(|extensions| LowerFunc::FixedHugr { + extensions, + hugr: simple_dfg_hugr(), + }) + .boxed() + } + } + + impl Arbitrary for SimpleOpDef { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + use crate::proptest::{any_serde_yaml_value, any_smolstr, any_string}; + use proptest::collection::{hash_map, vec}; + let misc = hash_map(any_string(), any_serde_yaml_value(), 0..3); + ( + any::(), + any_smolstr(), + any_string(), + misc, + any::(), + vec(any::(), 0..2), + ) + .prop_map( + |(extension, name, description, misc, signature_func, lower_funcs)| { + Self::new(OpDef { + extension, + name, + description, + misc, + signature_func, + lower_funcs, + // TODO ``constant_folder` is not serialised, we should + // generate examples once it is. + constant_folder: None, + }) + }, + ) + .boxed() + } + } + } } diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 1a2f2b54c..f50556498 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -5,7 +5,7 @@ use crate::builder::{ }; use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T}; use crate::extension::simple_op::MakeRegisteredOp; -use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; +use crate::extension::{test::SimpleOpDef, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG}; @@ -29,7 +29,6 @@ const NAT: Type = crate::extension::prelude::USIZE_T; const QB: Type = crate::extension::prelude::QB_T; /// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`. -#[cfg(test)] #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingV1 { typ: Option, @@ -37,6 +36,7 @@ struct SerTestingV1 { poly_func_type: Option, value: Option, optype: Option, + op_def: Option, } type TestingModel = SerTestingV1; @@ -91,6 +91,7 @@ impl_sertesting_from!(crate::types::SumType, sum_type); impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); impl_sertesting_from!(NodeSer, optype); +impl_sertesting_from!(SimpleOpDef, op_def); #[test] fn empty_hugr_serialize() { @@ -471,8 +472,8 @@ fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { } mod proptest { - use super::super::NodeSer; use super::check_testing_roundtrip; + use super::{NodeSer, SimpleOpDef}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::{PolyFuncType, Type}; @@ -513,8 +514,13 @@ mod proptest { } #[test] - fn prop_roundtrip_optype(ns: NodeSer) { - check_testing_roundtrip(ns) + fn prop_roundtrip_optype(op: NodeSer ) { + check_testing_roundtrip(op) + } + + #[test] + fn prop_roundtrip_opdef(opdef: SimpleOpDef) { + check_testing_roundtrip(opdef) } } } diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 7e3a1ac37..5f059f4f0 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -635,51 +635,38 @@ mod test { mod proptest { use super::super::OpaqueValue; use crate::{ - ops::Value, - std_extensions::arithmetic::int_types::{ConstInt, LOG_WIDTH_MAX}, + ops::{constant::CustomSerialized, Value}, + std_extensions::arithmetic::int_types::ConstInt, std_extensions::collections::ListValue, types::{SumType, Type}, }; - use ::proptest::prelude::*; + use ::proptest::{collection::vec, prelude::*}; impl Arbitrary for OpaqueValue { type Parameters = (); type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - use proptest::collection::vec; - let signed_strat = (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| { - use i64; - let max_val = (2u64.pow(log_width as u32) / 2) as i64; - let min_val = -max_val - 1; - (min_val..=max_val).prop_map(move |v| { - OpaqueValue::new( - ConstInt::new_s(log_width, v).expect("guaranteed to be in bounds"), + // We intentionally do not include `ConstF64` because it does not + // roundtrip serialise + prop_oneof![ + any::().prop_map_into(), + any::().prop_map_into() + ] + .prop_recursive( + 3, // No more than 3 branch levels deep + 32, // Target around 32 total elements + 3, // Each collection is up to 3 elements long + |child_strat| { + (Type::any_non_row_var(), vec(child_strat, 0..3)).prop_map( + |(typ, children)| { + Self::new(ListValue::new( + typ, + children.into_iter().map(|e| Value::Extension { e }), + )) + }, ) - }) - }); - let unsigned_strat = (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| { - (0..2u64.pow(log_width as u32)).prop_map(move |v| { - OpaqueValue::new( - ConstInt::new_u(log_width, v).expect("guaranteed to be in bounds"), - ) - }) - }); - prop_oneof![unsigned_strat, signed_strat] - .prop_recursive( - 3, // No more than 3 branch levels deep - 32, // Target around 32 total elements - 3, // Each collection is up to 3 elements long - |element| { - (Type::any_non_row_var(), vec(element.clone(), 0..3)).prop_map( - |(typ, contents)| { - OpaqueValue::new(ListValue::new( - typ, - contents.into_iter().map(|e| Value::Extension { e }), - )) - }, - ) - }, - ) - .boxed() + }, + ) + .boxed() } } diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 54a9af514..eb46a2bb7 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -476,4 +476,44 @@ mod test { serde_yaml::from_value(serde_yaml::to_value(&ev).unwrap()).unwrap() ); } + + mod proptest { + use ::proptest::prelude::*; + + use crate::{ + extension::ExtensionSet, + ops::constant::CustomSerialized, + proptest::{any_serde_yaml_value, any_string}, + types::Type, + }; + + impl Arbitrary for CustomSerialized { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + let typ = any::(); + let extensions = any::(); + // here we manually construct a serialized `dyn CustomConst`. + // The "c" and "v" come from the `typetag::serde` annotation on + // `trait CustomConst`. + // TODO This is not ideal, if we were to accidentally + // generate a valid tag(e.g. "ConstInt") then things will + // go wrong: the serde::Deserialize impl for that type will + // interpret "v" and fail. + let value = (any_serde_yaml_value(), any_string()).prop_map(|(content, tag)| { + [("c".into(), tag.into()), ("v".into(), content)] + .into_iter() + .collect::() + .into() + }); + (typ, value, extensions) + .prop_map(|(typ, value, extensions)| CustomSerialized { + typ, + value, + extensions, + }) + .boxed() + } + } + } } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index e8e7e0e2a..d97a8202c 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -4,7 +4,9 @@ use std::sync::Arc; use thiserror::Error; #[cfg(test)] use { + crate::extension::test::SimpleOpDef, crate::proptest::{any_nonempty_smolstr, any_nonempty_string}, + ::proptest::prelude::*, ::proptest_derive::Arbitrary, }; @@ -28,6 +30,7 @@ use super::{NamedOp, OpName, OpNameRef, OpTrait, OpType}; /// [`OpaqueOp`]: crate::ops::custom::OpaqueOp /// [`ExtensionOp`]: crate::ops::custom::ExtensionOp #[derive(Clone, Debug, Eq, serde::Serialize, serde::Deserialize)] +#[cfg_attr(test, derive(Arbitrary))] #[serde(into = "OpaqueOp", from = "OpaqueOp")] pub enum CustomOp { /// When we've found (loaded) the [Extension] definition and identified the [OpDef] @@ -170,9 +173,12 @@ impl From for CustomOp { /// /// [Extension]: crate::Extension #[derive(Clone, Debug)] -// TODO when we can geneerate `OpDef`s enable this -// #[cfg_attr(test, derive(proptest_derive::Arbitrary))] +#[cfg_attr(test, derive(Arbitrary))] pub struct ExtensionOp { + #[cfg_attr( + test, + proptest(strategy = "any::().prop_map(|x| Arc::new(x.into()))") + )] def: Arc, args: Vec, signature: FunctionType, // Cache @@ -446,17 +452,4 @@ mod test { assert!(op.is_opaque()); assert!(!op.is_extension_op()); } - - mod proptest { - use ::proptest::prelude::*; - - impl Arbitrary for super::super::CustomOp { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - // TODO when we can geneerate `OpDef`s add an `ExtensionOp` case here - any::().prop_map_into().boxed() - } - } - } } diff --git a/hugr-core/src/proptest.rs b/hugr-core/src/proptest.rs index 37bf46516..344b32a70 100644 --- a/hugr-core/src/proptest.rs +++ b/hugr-core/src/proptest.rs @@ -151,6 +151,10 @@ pub fn any_string() -> SBoxedStrategy { ANY_STRING.to_owned() } +pub fn any_smolstr() -> SBoxedStrategy { + ANY_STRING.clone().prop_map_into().sboxed() +} + pub fn any_serde_yaml_value() -> impl Strategy { // use serde_yaml::value::{Tag, TaggedValue, Value}; ANY_SERDE_YAML_VALUE_LEAF diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 444daefc1..5ff2abea3 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -276,4 +276,62 @@ mod test { ConstInt::new_s(50, -2).unwrap_err(); ConstInt::new_u(50, 2).unwrap_err(); } + + mod proptest { + use super::{ConstInt, LOG_WIDTH_MAX}; + use ::proptest::prelude::*; + use i64; + impl Arbitrary for ConstInt { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + let signed_strat = any_signed_int_with_log_width().prop_map(|(log_width, v)| { + ConstInt::new_s(log_width, v).expect("guaranteed to be in bounds") + }); + let unsigned_strat = (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| { + (0..2u64.pow(log_width as u32)).prop_map(move |v| { + ConstInt::new_u(log_width, v).expect("guaranteed to be in bounds") + }) + }); + + prop_oneof![unsigned_strat, signed_strat].boxed() + } + } + + fn any_signed_int_with_log_width() -> impl Strategy { + (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| { + let width = 2u64.pow(log_width as u32); + let max_val = ((1u64 << (width - 1)) - 1u64) as i64; + let min_val = -max_val - 1; + prop_oneof![(min_val..=max_val), Just(min_val), Just(max_val)] + .prop_map(move |x| (log_width, x)) + }) + } + + proptest! { + #[test] + fn valid_signed_int((log_width, x) in any_signed_int_with_log_width()) { + let (min,max) = match log_width { + 0 => (-1, 0), + 1 => (-2, 1), + 2 => (-8, 7), + 3 => (i8::MIN as i64, i8::MAX as i64), + 4 => (i16::MIN as i64, i16::MAX as i64), + 5 => (i32::MIN as i64, i32::MAX as i64), + 6 => (i64::MIN, i64::MAX), + _ => unreachable!(), + }; + let width = 2i64.pow(log_width as u32); + // the left hand side counts the number of valid values as follows: + // - use i128 to be able to hold the number of valid i64s + // - there are exactly `max` valid positive values; + // - there are exactly `-min` valid negative values; + // - there are exactly 1 zero values. + prop_assert_eq!((max as i128) - (min as i128) + 1, 1 << width); + prop_assert!(x >= min); + prop_assert!(x <= max); + prop_assert!(ConstInt::new_s(log_width, x).is_ok()) + } + } + } }