From 568df826f9ddb4d309f60028e73488605ed33ec2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 14:55:16 +0000 Subject: [PATCH] cherry-pick #636 --- src/extension.rs | 24 +++--- src/extension/infer.rs | 82 +++++++++----------- src/std_extensions/arithmetic/conversions.rs | 6 +- 3 files changed, 49 insertions(+), 63 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index aec4f6020..dd7401cfe 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -4,7 +4,7 @@ //! system (outside the `types` module), which also parses nested [`OpDef`]s. use std::collections::hash_map::Entry; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; @@ -301,18 +301,13 @@ pub enum ExtensionBuildError { } /// A set of extensions identified by their unique [`ExtensionId`]. -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct ExtensionSet(HashSet); +#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct ExtensionSet(BTreeSet); impl ExtensionSet { /// Creates a new empty extension set. - pub fn new() -> Self { - Self(HashSet::new()) - } - - /// Creates a new extension set from some extensions. - pub fn new_from_extensions(extensions: impl Into>) -> Self { - Self(extensions.into()) + pub const fn new() -> Self { + Self(BTreeSet::new()) } /// Adds a extension to the set. @@ -350,13 +345,18 @@ impl ExtensionSet { /// The things in other which are in not in self pub fn missing_from(&self, other: &Self) -> Self { - ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned())) + ExtensionSet::from_iter(other.0.difference(&self.0).cloned()) } /// Iterate over the contained ExtensionIds pub fn iter(&self) -> impl Iterator { self.0.iter() } + + /// True if this set contains no [ExtensionId]s + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl Display for ExtensionSet { @@ -367,6 +367,6 @@ impl Display for ExtensionSet { impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { - Self(HashSet::from_iter(iter)) + Self(BTreeSet::from_iter(iter)) } } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index a34929e7a..d07b8b70d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -10,7 +10,7 @@ //! depend on these open variables, then the validation check for extensions //! will succeed regardless of what the variable is instantiated to. -use super::{ExtensionId, ExtensionSet}; +use super::ExtensionSet; use crate::{ hugr::views::HugrView, ops::{OpTag, OpTrait}, @@ -65,8 +65,8 @@ impl Meta { enum Constraint { /// A variable has the same value as another variable Equal(Meta), - /// Variable extends the value of another by one extension - Plus(ExtensionId, Meta), + /// Variable extends the value of another by a set of extensions + Plus(ExtensionSet, Meta), } #[derive(Debug, Clone, PartialEq, Error)] @@ -235,26 +235,6 @@ impl UnificationContext { self.solved.get(&self.resolve(*m)) } - /// Convert an extension *set* difference in terms of a sequence of fresh - /// metas with `Plus` constraints which each add only one extension req. - fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) { - let mut last_meta = input; - // Create fresh metavariables with `Plus` constraints for - // each extension that should be added by the node - // Hence a extension delta [A, B] would lead to - // > ma = fresh_meta() - // > add_constraint(ma, Plus(a, input) - // > mb = fresh_meta() - // > add_constraint(mb, Plus(b, ma) - // > add_constraint(output, Equal(mb)) - for r in delta.0.into_iter() { - let curr_meta = self.fresh_meta(); - self.add_constraint(curr_meta, Constraint::Plus(r, last_meta)); - last_meta = curr_meta; - } - self.add_constraint(output, Constraint::Equal(last_meta)); - } - /// Return the metavariable corresponding to the given location on the /// graph, either by making a new meta, or looking it up fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta { @@ -316,11 +296,13 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - self.gen_union_constraint( - m_input, - m_output, - node_type.op_signature().extension_reqs, - ); + let delta = node_type.op_signature().extension_reqs; + let c = if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + }; + self.add_constraint(m_output, c); if matches!( node_type.tag(), OpTag::Alias | OpTag::Function | OpTag::FuncDefn @@ -530,8 +512,7 @@ impl UnificationContext { // to a set which already contained it. Constraint::Plus(r, other_meta) => { if let Some(rs) = self.get_solution(other_meta) { - let mut rrs = rs.clone(); - rrs.insert(r); + let rrs = rs.clone().union(r); match self.get_solution(&meta) { // Let's check that this is right? Some(rs) => { @@ -693,19 +674,19 @@ impl UnificationContext { // Handle the case where the constraints for `m` contain a self // reference, i.e. "m = Plus(E, m)", in which case the variable // should be instantiated to E rather than the empty set. - let solution = - ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map( - |c| match c { - // If `m` has been merged, [`self.variables`] entry - // will have already been updated to the merged - // value by [`self.merge_equal_metas`] so we don't - // need to worry about resolving it. - Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => { - Some(x.clone()) - } - _ => None, - }, - )); + let solution = self + .get_constraints(&m) + .unwrap() + .iter() + .filter_map(|c| match c { + // If `m` has been merged, [`self.variables`] entry + // will have already been updated to the merged + // value by [`self.merge_equal_metas`] so we don't + // need to worry about resolving it. + Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x), + _ => None, + }) + .fold(ExtensionSet::new(), ExtensionSet::union); self.add_solution(m, solution); } } @@ -719,6 +700,7 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; + use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::HugrError; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; @@ -837,8 +819,14 @@ mod test { ctx.solved.insert(metas[2], ExtensionSet::singleton(&A)); ctx.add_constraint(metas[1], Constraint::Equal(metas[2])); - ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2])); - ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0])); + ctx.add_constraint( + metas[0], + Constraint::Plus(ExtensionSet::singleton(&B), metas[2]), + ); + ctx.add_constraint( + metas[4], + Constraint::Plus(ExtensionSet::singleton(&C), metas[0]), + ); ctx.add_constraint(metas[3], Constraint::Equal(metas[4])); ctx.add_constraint(metas[5], Constraint::Equal(metas[0])); ctx.main_loop()?; @@ -911,8 +899,8 @@ mod test { .insert((NodeIndex::new(4).into(), Direction::Incoming), ab); ctx.variables.insert(a); ctx.variables.insert(b); - ctx.add_constraint(ab, Constraint::Plus(A, b)); - ctx.add_constraint(ab, Constraint::Plus(B, a)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a)); let solution = ctx.main_loop()?; // We'll only find concrete solutions for the Incoming extension reqs of // the main node created by `Hugr::default` diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 3918fd8ae..d07b97f62 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,7 +1,5 @@ //! Conversions between integer and floating-point values. -use std::collections::HashSet; - use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, type_row, @@ -39,10 +37,10 @@ fn itof_sig(arg_values: &[TypeArg]) -> Result { pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( EXTENSION_ID, - ExtensionSet::new_from_extensions(HashSet::from_iter(vec![ + ExtensionSet::from_iter(vec![ super::int_types::EXTENSION_ID, super::float_types::EXTENSION_ID, - ])), + ]), ); extension