From 080eda0a91ddc670b4655076556f8be98cc006f5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 6 Nov 2023 16:19:23 +0000 Subject: [PATCH 1/3] refactor!: remove `SignatureDescription` (#644) BREAKING CHANGES: `SignatureDescription` no longer exists. --- src/extension/op_def.rs | 24 ---------- src/ops.rs | 8 +--- src/ops/custom.rs | 9 +--- src/ops/leaf.rs | 11 +---- src/types.rs | 2 +- src/types/signature.rs | 98 +---------------------------------------- 6 files changed, 5 insertions(+), 147 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 817adcec1..8fe92e4fc 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -8,8 +8,6 @@ use super::{ TypeParametrised, }; -use crate::types::SignatureDescription; - use crate::types::FunctionType; use crate::types::type_param::TypeArg; @@ -34,18 +32,6 @@ pub trait CustomSignatureFunc: Send + Sync { misc: &HashMap, extension_registry: &ExtensionRegistry, ) -> Result; - - /// Describe the signature of a node, given the operation name, - /// values for the type parameters, - /// and 'misc' data from the extension definition YAML. - fn describe_signature( - &self, - _name: &SmolStr, - _arg_values: &[TypeArg], - _misc: &HashMap, - ) -> SignatureDescription { - SignatureDescription::default() - } } // Note this is very much a utility, rather than definitive; @@ -208,16 +194,6 @@ impl OpDef { Ok(res) } - /// Optional description of the ports in the signature. - pub fn signature_desc(&self, args: &[TypeArg]) -> SignatureDescription { - match &self.signature_func { - SignatureFunc::FromDecl { .. } => { - todo!() - } - SignatureFunc::CustomFunc(bf) => bf.describe_signature(&self.name, args, &self.misc), - } - } - pub(crate) fn should_serialize_signature(&self) -> bool { match self.signature_func { SignatureFunc::CustomFunc(_) => true, diff --git a/src/ops.rs b/src/ops.rs index 4926e60b6..f6ef25004 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,7 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; -use crate::types::{EdgeKind, FunctionType, SignatureDescription, Type}; +use crate::types::{EdgeKind, FunctionType, Type}; use crate::PortIndex; use crate::{Direction, Port}; @@ -189,12 +189,6 @@ pub trait OpTrait { fn signature(&self) -> FunctionType { Default::default() } - /// Optional description of the ports in the signature. - /// - /// Only dataflow operations have a non-empty signature. - fn signature_desc(&self) -> SignatureDescription { - Default::default() - } /// Get the static input type of this operation if it has one (only Some for /// [`LoadConstant`] and [`Call`]) diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 77ef60399..b1c5a39b3 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; -use crate::types::{type_param::TypeArg, FunctionType, SignatureDescription}; +use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; use super::tag::OpTag; @@ -76,13 +76,6 @@ impl OpTrait for ExternalOp { } } - fn signature_desc(&self) -> SignatureDescription { - match self { - Self::Opaque(_) => SignatureDescription::default(), - Self::Extension(ExtensionOp { def, args, .. }) => def.signature_desc(args), - } - } - fn tag(&self) -> OpTag { OpTag::Leaf } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index f09cd2328..768d35fcb 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -7,7 +7,7 @@ use super::{OpName, OpTag, OpTrait, StaticTag}; use crate::{ extension::{ExtensionId, ExtensionSet}, - types::{EdgeKind, FunctionType, SignatureDescription, Type, TypeRow}, + types::{EdgeKind, FunctionType, Type, TypeRow}, }; /// Dataflow operations with no children. @@ -118,15 +118,6 @@ impl OpTrait for LeafOp { } } - /// Optional description of the ports in the signature. - fn signature_desc(&self) -> SignatureDescription { - match self { - LeafOp::CustomOp(ext) => ext.signature_desc(), - // TODO: More port descriptions - _ => Default::default(), - } - } - fn other_input(&self) -> Option { Some(EdgeKind::StateOrder) } diff --git a/src/types.rs b/src/types.rs index cabdcf39d..8ea81504e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -10,7 +10,7 @@ pub mod type_row; pub use check::{ConstTypeError, CustomCheckFailure}; pub use custom::CustomType; -pub use signature::{FunctionType, Signature, SignatureDescription}; +pub use signature::{FunctionType, Signature}; pub use type_param::TypeArg; pub use type_row::TypeRow; diff --git a/src/types/signature.rs b/src/types/signature.rs index ce13b5e6e..b2995dec0 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -4,13 +4,11 @@ use pyo3::{pyclass, pymethods}; use delegate::delegate; -use smol_str::SmolStr; use std::fmt::{self, Display, Write}; -use std::ops::Index; use crate::extension::ExtensionSet; use crate::types::{Type, TypeRow}; -use crate::{Direction, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -249,97 +247,3 @@ impl Display for Signature { } } } - -/// Descriptive names for the ports in a [`Signature`]. -/// -/// This is a separate type from [`Signature`] as it is not normally used during the compiler operations. -#[cfg_attr(feature = "pyo3", pyclass)] -#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct SignatureDescription { - /// Input of the function. - pub input: Vec, - /// Output of the function. - pub output: Vec, -} - -#[cfg_attr(feature = "pyo3", pymethods)] -impl SignatureDescription { - /// The number of wires in the signature. - #[inline(always)] - pub fn is_empty(&self) -> bool { - self.input.is_empty() && self.output.is_empty() - } -} - -impl SignatureDescription { - /// Create a new signature. - pub fn new(input: impl Into>, output: impl Into>) -> Self { - Self { - input: input.into(), - output: output.into(), - } - } - - /// Create a new signature with only linear inputs and outputs. - pub fn new_linear(linear: impl Into>) -> Self { - let linear = linear.into(); - SignatureDescription::new(linear.clone(), linear) - } - - pub(crate) fn row_zip<'a>( - type_row: &'a TypeRow, - name_row: &'a [SmolStr], - ) -> impl Iterator { - name_row - .iter() - .chain(&EmptyStringIterator) - .zip(type_row.iter()) - } - - /// Iterate over the input wires of the signature and their names. - /// - /// Unnamed wires are given an empty string name. - /// - /// TODO: Return Option<&String> instead of &String for the description. - pub fn input_zip<'a>( - &'a self, - signature: &'a Signature, - ) -> impl Iterator { - Self::row_zip(signature.input(), &self.input) - } - - /// Iterate over the output wires of the signature and their names. - /// - /// Unnamed wires are given an empty string name. - pub fn output_zip<'a>( - &'a self, - signature: &'a Signature, - ) -> impl Iterator { - Self::row_zip(signature.output(), &self.output) - } -} - -impl Index for SignatureDescription { - type Output = SmolStr; - - fn index(&self, index: Port) -> &Self::Output { - match index.direction() { - Direction::Incoming => self.input.get(index.index()).unwrap_or(EMPTY_STRING_REF), - Direction::Outgoing => self.output.get(index.index()).unwrap_or(EMPTY_STRING_REF), - } - } -} - -/// An iterator that always returns the an empty string. -pub(crate) struct EmptyStringIterator; - -/// A reference to an empty string. Used by [`EmptyStringIterator`]. -pub(crate) const EMPTY_STRING_REF: &SmolStr = &SmolStr::new_inline(""); - -impl<'a> Iterator for &'a EmptyStringIterator { - type Item = &'a SmolStr; - - fn next(&mut self) -> Option { - Some(EMPTY_STRING_REF) - } -} From 57270d508f11211e5357782ea7dcf9ec07115174 Mon Sep 17 00:00:00 2001 From: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Date: Mon, 6 Nov 2023 16:56:07 +0000 Subject: [PATCH 2/3] feat!: Remove "rotations" extension. (#645) Quaternion (and angle) type definitions will live in tket2. --- src/std_extensions.rs | 1 - src/std_extensions/rotation.rs | 400 --------------------------------- 2 files changed, 401 deletions(-) delete mode 100644 src/std_extensions/rotation.rs diff --git a/src/std_extensions.rs b/src/std_extensions.rs index 2ee80378f..bc4af73cb 100644 --- a/src/std_extensions.rs +++ b/src/std_extensions.rs @@ -6,4 +6,3 @@ pub mod arithmetic; pub mod collections; pub mod logic; pub mod quantum; -pub mod rotation; diff --git a/src/std_extensions/rotation.rs b/src/std_extensions/rotation.rs deleted file mode 100644 index 35b2102b9..000000000 --- a/src/std_extensions/rotation.rs +++ /dev/null @@ -1,400 +0,0 @@ -#![allow(missing_docs)] -//! This is an experiment, it is probably already outdated. - -use std::ops::{Add, Div, Mul, Neg, Sub}; - -use cgmath::num_traits::ToPrimitive; -use num_rational::Rational64; -use smol_str::SmolStr; - -#[cfg(feature = "pyo3")] -use pyo3::{pyclass, FromPyObject}; - -use crate::extension::ExtensionId; -use crate::types::type_param::TypeArg; -use crate::types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow}; -use crate::values::CustomConst; -use crate::{ops, Extension}; - -pub const PI_NAME: &str = "PI"; -pub const ANGLE_T_NAME: &str = "angle"; -pub const QUAT_T_NAME: &str = "quat"; -pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("rotations"); - -pub const ANGLE_T: Type = Type::new_extension(CustomType::new_simple( - SmolStr::new_inline(ANGLE_T_NAME), - EXTENSION_ID, - TypeBound::Copyable, -)); - -pub const QUAT_T: Type = Type::new_extension(CustomType::new_simple( - SmolStr::new_inline(QUAT_T_NAME), - EXTENSION_ID, - TypeBound::Copyable, -)); -/// The extension with all the operations and types defined in this extension. -pub fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - RotationType::Angle.add_to_extension(&mut extension); - RotationType::Quaternion.add_to_extension(&mut extension); - - extension - .add_op_custom_sig_simple( - "AngleAdd".into(), - "".into(), - vec![], - |_arg_values: &[TypeArg]| { - let t: TypeRow = - vec![Type::new_extension(RotationType::Angle.custom_type())].into(); - Ok(FunctionType::new(t.clone(), t)) - }, - ) - .unwrap(); - - let pi_val = RotationValue::PI; - - extension - .add_value(PI_NAME, ops::Const::new(pi_val.into(), ANGLE_T).unwrap()) - .unwrap(); - extension -} - -/// Custom types defined by this extension. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum RotationType { - Angle, - Quaternion, -} - -impl RotationType { - pub const fn name(&self) -> SmolStr { - match self { - RotationType::Angle => SmolStr::new_inline(ANGLE_T_NAME), - RotationType::Quaternion => SmolStr::new_inline(QUAT_T_NAME), - } - } - - pub const fn description(&self) -> &str { - match self { - RotationType::Angle => "Floating point angle", - RotationType::Quaternion => "Quaternion specifying rotation.", - } - } - - pub fn custom_type(self) -> CustomType { - CustomType::new(self.name(), [], EXTENSION_ID, TypeBound::Copyable) - } - - fn add_to_extension(self, extension: &mut Extension) { - extension - .add_type( - self.name(), - vec![], - self.description().to_string(), - TypeBound::Copyable.into(), - ) - .unwrap(); - } -} - -impl From for CustomType { - fn from(ty: RotationType) -> Self { - ty.custom_type() - } -} - -/// Constant values for [`RotationType`]. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -pub enum RotationValue { - Angle(AngleValue), - Quaternion(cgmath::Quaternion), -} - -impl RotationValue { - const PI: Self = Self::Angle(AngleValue::PI); - fn rotation_type(&self) -> RotationType { - match self { - RotationValue::Angle(_) => RotationType::Angle, - RotationValue::Quaternion(_) => RotationType::Quaternion, - } - } -} - -#[typetag::serde] -impl CustomConst for RotationValue { - fn name(&self) -> SmolStr { - match self { - RotationValue::Angle(val) => format!("AngleConstant({})", val.radians()), - RotationValue::Quaternion(val) => format!("QuatConstant({:?})", val), - } - .into() - } - - fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - let self_typ = self.rotation_type(); - - if &self_typ.custom_type() == typ { - Ok(()) - } else { - Err(CustomCheckFailure::Message( - "Rotation constant type mismatch.".into(), - )) - } - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::values::downcast_equal_consts(self, other) - } -} - -// -// TODO: -// -// operations: -// -// AngleAdd, -// AngleMul, -// AngleNeg, -// QuatMul, -// RxF64, -// RzF64, -// TK1, -// Rotation, -// ToRotation, -// -// -// -// signatures: -// -// LeafOp::AngleAdd | LeafOp::AngleMul => FunctionType::new_linear([Type::Angle]), -// LeafOp::QuatMul => FunctionType::new_linear([Type::Quat64]), -// LeafOp::AngleNeg => FunctionType::new_linear([Type::Angle]), -// LeafOp::RxF64 | LeafOp::RzF64 => { -// FunctionType::new_df([Type::Qubit], [Type::Angle]) -// } -// LeafOp::TK1 => FunctionType::new_df(vec![Type::Qubit], vec![Type::Angle; 3]), -// LeafOp::Rotation => FunctionType::new_df([Type::Qubit], [Type::Quat64]), -// LeafOp::ToRotation => FunctionType::new_df( -// [ -// Type::Angle, -// Type::F64, -// Type::F64, -// Type::F64, -// ], -// [Type::Quat64], -// ), - -#[derive(Clone, Copy, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "pyo3", pyclass(name = "Rational"))] -pub struct Rational(pub Rational64); - -impl From for Rational { - fn from(r: Rational64) -> Self { - Self(r) - } -} - -// angle is contained value * pi in radians -#[derive(Clone, PartialEq, Debug, Copy, serde::Serialize, serde::Deserialize)] -#[cfg_attr(feature = "pyo3", derive(FromPyObject))] -pub enum AngleValue { - F64(f64), - Rational(Rational), -} - -impl AngleValue { - const PI: Self = AngleValue::Rational(Rational(Rational64::new_raw(1, 1))); - fn binary_op f64, G: FnOnce(Rational64, Rational64) -> Rational64>( - self, - rhs: Self, - opf: F, - opr: G, - ) -> Self { - match (self, rhs) { - (AngleValue::F64(x), AngleValue::F64(y)) => AngleValue::F64(opf(x, y)), - (AngleValue::F64(x), AngleValue::Rational(y)) - | (AngleValue::Rational(y), AngleValue::F64(x)) => { - AngleValue::F64(opf(x, y.0.to_f64().unwrap())) - } - (AngleValue::Rational(x), AngleValue::Rational(y)) => { - AngleValue::Rational(Rational(opr(x.0, y.0))) - } - } - } - - fn unary_op f64, G: FnOnce(Rational64) -> Rational64>( - self, - opf: F, - opr: G, - ) -> Self { - match self { - AngleValue::F64(x) => AngleValue::F64(opf(x)), - AngleValue::Rational(x) => AngleValue::Rational(Rational(opr(x.0))), - } - } - - pub fn to_f64(&self) -> f64 { - match self { - AngleValue::F64(x) => *x, - AngleValue::Rational(x) => x.0.to_f64().expect("Floating point conversion error."), - } - } - - pub fn radians(&self) -> f64 { - self.to_f64() * std::f64::consts::PI - } -} - -impl Add for AngleValue { - type Output = AngleValue; - - fn add(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x + y, |x, y| x + y) - } -} - -impl Sub for AngleValue { - type Output = AngleValue; - - fn sub(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x - y, |x, y| x - y) - } -} - -impl Mul for AngleValue { - type Output = AngleValue; - - fn mul(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x * y, |x, y| x * y) - } -} - -impl Div for AngleValue { - type Output = AngleValue; - - fn div(self, rhs: Self) -> Self::Output { - self.binary_op(rhs, |x, y| x / y, |x, y| x / y) - } -} - -impl Neg for AngleValue { - type Output = AngleValue; - - fn neg(self) -> Self::Output { - self.unary_op(|x| -x, |x| -x) - } -} - -impl Add for &AngleValue { - type Output = AngleValue; - - fn add(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x + y, |x, y| x + y) - } -} - -impl Sub for &AngleValue { - type Output = AngleValue; - - fn sub(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x - y, |x, y| x - y) - } -} - -impl Mul for &AngleValue { - type Output = AngleValue; - - fn mul(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x * y, |x, y| x * y) - } -} - -impl Div for &AngleValue { - type Output = AngleValue; - - fn div(self, rhs: Self) -> Self::Output { - self.binary_op(*rhs, |x, y| x / y, |x, y| x / y) - } -} - -impl Neg for &AngleValue { - type Output = AngleValue; - - fn neg(self) -> Self::Output { - self.unary_op(|x| -x, |x| -x) - } -} - -#[cfg(test)] -mod test { - - use rstest::{fixture, rstest}; - - use super::{AngleValue, RotationValue, ANGLE_T, ANGLE_T_NAME, EXTENSION_ID, PI_NAME}; - use crate::{ - extension::ExtensionId, - extension::SignatureError, - types::{CustomType, Type, TypeBound}, - values::CustomConst, - Extension, - }; - - #[fixture] - fn extension() -> Extension { - super::extension() - } - - #[rstest] - fn test_types(extension: Extension) { - let angle = extension.get_type(ANGLE_T_NAME).unwrap(); - - let custom = angle.instantiate_concrete([]).unwrap(); - - angle.check_custom(&custom).unwrap(); - - let wrong_ext = ExtensionId::new("wrong_extensions").unwrap(); - - let false_custom = CustomType::new( - custom.name().clone(), - vec![], - wrong_ext.clone(), - TypeBound::Copyable, - ); - assert_eq!( - angle.check_custom(&false_custom), - Err(SignatureError::ExtensionMismatch(EXTENSION_ID, wrong_ext,)) - ); - - assert_eq!(Type::new_extension(custom), ANGLE_T); - } - - #[rstest] - fn test_type_check(extension: Extension) { - let custom_type = extension - .get_type(ANGLE_T_NAME) - .unwrap() - .instantiate_concrete([]) - .unwrap(); - - let custom_value = RotationValue::Angle(AngleValue::F64(0.0)); - - // correct type - custom_value.check_custom_type(&custom_type).unwrap(); - - let wrong_custom_type = extension - .get_type("quat") - .unwrap() - .instantiate_concrete([]) - .unwrap(); - let res = custom_value.check_custom_type(&wrong_custom_type); - assert!(res.is_err()); - } - - #[rstest] - fn test_constant(extension: Extension) { - let pi_val = extension.get_value(PI_NAME).unwrap(); - - ANGLE_T.check_type(pi_val.typed_value().value()).unwrap(); - } -} From b7ebb0de32dc04a32a39d8f8429caac97dc38a76 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 7 Nov 2023 09:42:26 +0000 Subject: [PATCH 3/3] fix: erratic stack overflow in infer.rs (live_var) (#638) closes #633 * Add test case. Reduced this to a 2-node DFG (plus Input and Output). That this hasn't been showing up repeatedly in other tests - I put down to #388: most custom ops have empty `extension_reqs` so do not generate Plus constraints. * Run (50%-failing) test 10 times -> failure rate ~~ 1 in 2^10 * Fix calculation of `live_vars` and `live_metas` by one-off traversal of constraints+solutions in `fn results()`. --------- Co-authored-by: Craig Roy --- src/extension/infer.rs | 130 +++++++++++++++++++++++++++-------------- 1 file changed, 86 insertions(+), 44 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index dd29d71c4..47eb88e67 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -22,7 +22,7 @@ use super::validate::ExtensionError; use petgraph::graph as pg; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, VecDeque}; use thiserror::Error; @@ -521,57 +521,46 @@ impl UnificationContext { pub fn results(&self) -> Result { // Check that all of the metavariables associated with nodes of the // graph are solved + let depended_upon = { + let mut h: HashMap> = HashMap::new(); + for (m, m2) in self.constraints.iter().flat_map(|(m, cs)| { + cs.iter().flat_map(|c| match c { + Constraint::Plus(_, m2) => Some((*m, self.resolve(*m2))), + _ => None, + }) + }) { + h.entry(m2).or_default().push(m); + } + h + }; + // Calculate everything dependent upon a variable. + // Note it would be better to find metas ALL of whose dependencies were (transitively) + // on variables, but this is more complex, and hard to define if there are cycles + // of PLUS constraints, so leaving that as a TODO until we've handled such cycles. + let mut depends_on_var = HashSet::new(); + let mut queue = VecDeque::from_iter(self.variables.iter()); + while let Some(m) = queue.pop_front() { + if depends_on_var.insert(m) { + if let Some(d) = depended_upon.get(m) { + queue.extend(d.iter()) + } + } + } + let mut results: ExtensionSolution = HashMap::new(); for (loc, meta) in self.extensions.iter() { if let Some(rs) = self.get_solution(meta) { if loc.1 == Direction::Incoming { results.insert(loc.0, rs.clone()); } - } else if self.live_var(meta).is_some() { - // If it depends on some other live meta, that's bad news. - return Err(InferExtensionError::Unsolved { location: *loc }); - } - // If it only depends on graph variables, then we don't have - // a *solution*, but it's fine - } - debug_assert!(self.live_metas().is_empty()); - Ok(results) - } - - // Get the live var associated with a meta. - // TODO: This should really be a list - fn live_var(&self, m: &Meta) -> Option { - if self.variables.contains(m) || self.variables.contains(&self.resolve(*m)) { - return None; - } - - // TODO: We should be doing something to ensure that these are the same check... - if self.get_solution(m).is_none() { - if let Some(cs) = self.get_constraints(m) { - for c in cs { - match c { - Constraint::Plus(_, m) => return self.live_var(m), - _ => panic!("we shouldn't be here!"), - } + } else { + // Unsolved nodes must be unsolved because they depend on graph variables. + if !depends_on_var.contains(&self.resolve(*meta)) { + return Err(InferExtensionError::Unsolved { location: *loc }); } } - Some(*m) - } else { - None } - } - - /// Return the set of "live" metavariables in the context. - /// "Live" here means a metavariable: - /// - Is associated to a location in the graph in `UnifyContext.extensions` - /// - Is still unsolved - /// - Isn't a variable - fn live_metas(&self) -> HashSet { - self.extensions - .values() - .filter_map(|m| self.live_var(m)) - .filter(|m| !self.variables.contains(m)) - .collect() + Ok(results) } /// Iterates over a set of metas (the argument) and tries to solve @@ -665,12 +654,16 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; + use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use crate::extension::prelude::QB_T; use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; - use crate::ops::OpType; + use crate::ops::custom::{ExternalOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; + use crate::ops::{LeafOp, OpType}; + use crate::type_row; use crate::types::{FunctionType, Type, TypeRow}; @@ -1539,4 +1532,53 @@ mod test { Ok(()) } + + /// This was stack-overflowing approx 50% of the time, + /// see https://github.com/CQCL/hugr/issues/633 + #[test] + fn plus_on_self() -> Result<(), Box> { + let ext = ExtensionId::new("unknown1").unwrap(); + let delta = ExtensionSet::singleton(&ext); + let ft = FunctionType::new_linear(type_row![QB_T, QB_T]).with_extension_delta(&delta); + let mut dfg = DFGBuilder::new(ft.clone())?; + + // While https://github.com/CQCL-DEV/hugr/issues/388 is unsolved, + // most operations have empty extension_reqs (not including their own extension). + // Define some that do. + let binop: LeafOp = ExternalOp::Opaque(OpaqueOp::new( + ext.clone(), + "2qb_op", + String::new(), + vec![], + Some(ft), + )) + .into(); + let unary_sig = FunctionType::new_linear(type_row![QB_T]) + .with_extension_delta(&ExtensionSet::singleton(&ext)); + let unop: LeafOp = ExternalOp::Opaque(OpaqueOp::new( + ext, + "1qb_op", + String::new(), + vec![], + Some(unary_sig), + )) + .into(); + // Constrain q1,q2 as PLUS(ext1, inputs): + let [q1, q2] = dfg + .add_dataflow_op(binop.clone(), dfg.input_wires())? + .outputs_arr(); + // Constrain q1 as PLUS(ext2, q2): + let [q1] = dfg.add_dataflow_op(unop, [q1])?.outputs_arr(); + // Constrain q1 as EQUALS(q2) by using both together + dfg.finish_hugr_with_outputs([q1, q2], &PRELUDE_REGISTRY)?; + // The combined q1+q2 variable now has two PLUS constraints - on itself and the inputs. + Ok(()) + } + + /// [plus_on_self] had about a 50% rate of failing with stack overflow. + /// So if we run 10 times, that would succeed about 1 run in 2^10, i.e. <0.1% + #[test] + fn plus_on_self_10_times() { + [0; 10].iter().for_each(|_| plus_on_self().unwrap()) + } }