Skip to content

Commit

Permalink
test: serialisation round trip testing for OpDef (#999)
Browse files Browse the repository at this point in the history
We also add coverage for roundtripping `CustomSerialized` which was
omitted from the previous PR.

---------

Co-authored-by: Alan Lawrence <alan.lawrence@quantinuum.com>
  • Loading branch information
doug-q and acl-cqc authored May 30, 2024
1 parent d0cd023 commit d799f7f
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 60 deletions.
5 changes: 4 additions & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,9 @@ impl FromIterator<ExtensionId> for ExtensionSet {
}

#[cfg(test)]
mod test {
pub mod test {
// We re-export this here because mod op_def is private.
pub use super::op_def::test::SimpleOpDef;

mod proptest {

Expand All @@ -549,6 +551,7 @@ mod test {
impl Arbitrary for ExtensionSet {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;

fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(
hash_set(0..10usize, 0..3),
Expand Down
162 changes: 160 additions & 2 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,14 @@ impl Extension {
}

#[cfg(test)]
mod test {
pub(super) mod test {
use std::num::NonZeroU64;

use itertools::Itertools;

use super::SignatureFromArgs;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::op_def::LowerFunc;
use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
use crate::extension::prelude::USIZE_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
Expand All @@ -494,6 +496,87 @@ mod test {
const EXT_ID: ExtensionId = "MyExt";
}

#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SimpleOpDef(OpDef);

impl SimpleOpDef {
pub fn new(op_def: OpDef) -> Self {
assert!(op_def.constant_folder.is_none());
assert!(matches!(
op_def.signature_func,
SignatureFunc::TypeScheme(_)
));
assert!(op_def
.lower_funcs
.iter()
.all(|lf| matches!(lf, LowerFunc::FixedHugr { .. })));
Self(op_def)
}
}

impl From<SimpleOpDef> for OpDef {
fn from(value: SimpleOpDef) -> Self {
value.0
}
}

impl PartialEq for SimpleOpDef {
fn eq(&self, other: &Self) -> bool {
let OpDef {
extension,
name,
description,
misc,
signature_func,
lower_funcs,
constant_folder: _,
} = &self.0;
let OpDef {
extension: other_extension,
name: other_name,
description: other_description,
misc: other_misc,
signature_func: other_signature_func,
lower_funcs: other_lower_funcs,
constant_folder: _,
} = &other.0;

let get_sig = |sf: &_| match sf {
// if SignatureFunc or CustomValidator are changed we should get
// a compile error here. To fix: modify the fields matched on here,
// maintaining the lack of `..` and, for each part that is
// serializable, ensure we are checking it for equality below.
SignatureFunc::TypeScheme(CustomValidator {
poly_func,
validate: _,
}) => Some(poly_func.clone()),
// This is ruled out by `new()` but leave it here for later.
SignatureFunc::CustomFunc(_) => None,
};

let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
lfs.iter()
.map(|lf| match lf {
// as with get_sig above, this should break if the hierarchy
// is changed, update similarly.
LowerFunc::FixedHugr { extensions, hugr } => {
Some((extensions.clone(), hugr.clone()))
}
// This is ruled out by `new()` but leave it here for later.
LowerFunc::CustomFunc(_) => None,
})
.collect_vec()
};

extension == other_extension
&& name == other_name
&& description == other_description
&& misc == other_misc
&& get_sig(signature_func) == get_sig(other_signature_func)
&& get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
}
}

#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
Expand Down Expand Up @@ -686,4 +769,79 @@ mod test {
);
Ok(())
}

mod proptest {
use super::SimpleOpDef;
use ::proptest::prelude::*;

use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
types::PolyFuncType,
};

impl Arbitrary for SignatureFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// TODO there is also SignatureFunc::CustomFunc, but for now
// this is not serialised. When it is, we should generate
// examples here .
any::<PolyFuncType>()
.prop_map(|x| SignatureFunc::TypeScheme(CustomValidator::from_polyfunc(x)))
.boxed()
}
}

impl Arbitrary for LowerFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
// TODO There is also LowerFunc::CustomFunc, but for now this is
// not serialised. When it is, we should generate examples here.
any::<ExtensionSet>()
.prop_map(|extensions| LowerFunc::FixedHugr {
extensions,
hugr: simple_dfg_hugr(),
})
.boxed()
}
}

impl Arbitrary for SimpleOpDef {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use crate::proptest::{any_serde_yaml_value, any_smolstr, any_string};
use proptest::collection::{hash_map, vec};
let misc = hash_map(any_string(), any_serde_yaml_value(), 0..3);
(
any::<ExtensionId>(),
any_smolstr(),
any_string(),
misc,
any::<SignatureFunc>(),
vec(any::<LowerFunc>(), 0..2),
)
.prop_map(
|(extension, name, description, misc, signature_func, lower_funcs)| {
Self::new(OpDef {
extension,
name,
description,
misc,
signature_func,
lower_funcs,
// TODO ``constant_folder` is not serialised, we should
// generate examples once it is.
constant_folder: None,
})
},
)
.boxed()
}
}
}
}
16 changes: 11 additions & 5 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::builder::{
};
use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::extension::{test::SimpleOpDef, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::internal::HugrMutInternals;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG};
Expand All @@ -29,14 +29,14 @@ const NAT: Type = crate::extension::prelude::USIZE_T;
const QB: Type = crate::extension::prelude::QB_T;

/// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`.
#[cfg(test)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
struct SerTestingV1 {
typ: Option<crate::types::Type>,
sum_type: Option<crate::types::SumType>,
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
optype: Option<NodeSer>,
op_def: Option<SimpleOpDef>,
}

type TestingModel = SerTestingV1;
Expand Down Expand Up @@ -91,6 +91,7 @@ impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);
impl_sertesting_from!(NodeSer, optype);
impl_sertesting_from!(SimpleOpDef, op_def);

#[test]
fn empty_hugr_serialize() {
Expand Down Expand Up @@ -471,8 +472,8 @@ fn roundtrip_optype(#[case] optype: impl Into<OpType> + std::fmt::Debug) {
}

mod proptest {
use super::super::NodeSer;
use super::check_testing_roundtrip;
use super::{NodeSer, SimpleOpDef};
use crate::extension::ExtensionSet;
use crate::ops::{OpType, Value};
use crate::types::{PolyFuncType, Type};
Expand Down Expand Up @@ -513,8 +514,13 @@ mod proptest {
}

#[test]
fn prop_roundtrip_optype(ns: NodeSer) {
check_testing_roundtrip(ns)
fn prop_roundtrip_optype(op: NodeSer ) {
check_testing_roundtrip(op)
}

#[test]
fn prop_roundtrip_opdef(opdef: SimpleOpDef) {
check_testing_roundtrip(opdef)
}
}
}
61 changes: 24 additions & 37 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,51 +635,38 @@ mod test {
mod proptest {
use super::super::OpaqueValue;
use crate::{
ops::Value,
std_extensions::arithmetic::int_types::{ConstInt, LOG_WIDTH_MAX},
ops::{constant::CustomSerialized, Value},
std_extensions::arithmetic::int_types::ConstInt,
std_extensions::collections::ListValue,
types::{SumType, Type},
};
use ::proptest::prelude::*;
use ::proptest::{collection::vec, prelude::*};
impl Arbitrary for OpaqueValue {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use proptest::collection::vec;
let signed_strat = (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| {
use i64;
let max_val = (2u64.pow(log_width as u32) / 2) as i64;
let min_val = -max_val - 1;
(min_val..=max_val).prop_map(move |v| {
OpaqueValue::new(
ConstInt::new_s(log_width, v).expect("guaranteed to be in bounds"),
// We intentionally do not include `ConstF64` because it does not
// roundtrip serialise
prop_oneof![
any::<ConstInt>().prop_map_into(),
any::<CustomSerialized>().prop_map_into()
]
.prop_recursive(
3, // No more than 3 branch levels deep
32, // Target around 32 total elements
3, // Each collection is up to 3 elements long
|child_strat| {
(Type::any_non_row_var(), vec(child_strat, 0..3)).prop_map(
|(typ, children)| {
Self::new(ListValue::new(
typ,
children.into_iter().map(|e| Value::Extension { e }),
))
},
)
})
});
let unsigned_strat = (..=LOG_WIDTH_MAX).prop_flat_map(|log_width| {
(0..2u64.pow(log_width as u32)).prop_map(move |v| {
OpaqueValue::new(
ConstInt::new_u(log_width, v).expect("guaranteed to be in bounds"),
)
})
});
prop_oneof![unsigned_strat, signed_strat]
.prop_recursive(
3, // No more than 3 branch levels deep
32, // Target around 32 total elements
3, // Each collection is up to 3 elements long
|element| {
(Type::any_non_row_var(), vec(element.clone(), 0..3)).prop_map(
|(typ, contents)| {
OpaqueValue::new(ListValue::new(
typ,
contents.into_iter().map(|e| Value::Extension { e }),
))
},
)
},
)
.boxed()
},
)
.boxed()
}
}

Expand Down
40 changes: 40 additions & 0 deletions hugr-core/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,4 +476,44 @@ mod test {
serde_yaml::from_value(serde_yaml::to_value(&ev).unwrap()).unwrap()
);
}

mod proptest {
use ::proptest::prelude::*;

use crate::{
extension::ExtensionSet,
ops::constant::CustomSerialized,
proptest::{any_serde_yaml_value, any_string},
types::Type,
};

impl Arbitrary for CustomSerialized {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
let typ = any::<Type>();
let extensions = any::<ExtensionSet>();
// here we manually construct a serialized `dyn CustomConst`.
// The "c" and "v" come from the `typetag::serde` annotation on
// `trait CustomConst`.
// TODO This is not ideal, if we were to accidentally
// generate a valid tag(e.g. "ConstInt") then things will
// go wrong: the serde::Deserialize impl for that type will
// interpret "v" and fail.
let value = (any_serde_yaml_value(), any_string()).prop_map(|(content, tag)| {
[("c".into(), tag.into()), ("v".into(), content)]
.into_iter()
.collect::<serde_yaml::Mapping>()
.into()
});
(typ, value, extensions)
.prop_map(|(typ, value, extensions)| CustomSerialized {
typ,
value,
extensions,
})
.boxed()
}
}
}
}
Loading

0 comments on commit d799f7f

Please sign in to comment.