From 555f618b68f9d5d6734a3eb1d28e7df148322d90 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 12:43:48 +0000 Subject: [PATCH 01/10] Implement Hash for ExtensionSet --- src/extension.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/extension.rs b/src/extension.rs index aec4f6020..c94621891 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -3,9 +3,10 @@ //! TODO: YAML declaration and parsing. This should be similar to a plugin //! system (outside the `types` module), which also parses nested [`OpDef`]s. -use std::collections::hash_map::Entry; +use std::collections::hash_map::{DefaultHasher, Entry}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; +use std::hash::{BuildHasher, BuildHasherDefault}; use std::sync::Arc; use smol_str::SmolStr; @@ -304,6 +305,20 @@ pub enum ExtensionBuildError { #[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct ExtensionSet(HashSet); +impl std::hash::Hash for ExtensionSet { + fn hash(&self, state: &mut H) { + // Hash each item individually and combine with a *weak* hash combiner + // (i.e. that is associative and commutative, hence item iteration order doesn't matter). + // Here we just use xor. + let item_h = BuildHasherDefault::::default(); + self.0 + .iter() + .map(|e_id| item_h.hash_one(e_id)) + .fold(0, |a, b| a ^ b) + .hash(state); + } +} + impl ExtensionSet { /// Creates a new empty extension set. pub fn new() -> Self { From aed077f124cce643bce41ef93bb3676902102776 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 12:45:24 +0000 Subject: [PATCH 02/10] Make Constraint::Plus take an ExtensionSet, but generate Equal if set empty --- src/extension/infer.rs | 75 ++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index db0b66694..fa503addd 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -10,7 +10,7 @@ //! depend on these open variables, then the validation check for extensions //! will succeed regardless of what the variable is instantiated to. -use super::{ExtensionId, ExtensionSet}; +use super::ExtensionSet; use crate::{ hugr::views::HugrView, ops::{OpTag, OpTrait}, @@ -65,8 +65,8 @@ impl Meta { enum Constraint { /// A variable has the same value as another variable Equal(Meta), - /// Variable extends the value of another by one extension - Plus(ExtensionId, Meta), + /// Variable extends the value of another by a set of extensions + Plus(ExtensionSet, Meta), } #[derive(Debug, Clone, PartialEq, Error)] @@ -230,24 +230,15 @@ impl UnificationContext { self.solved.get(&self.resolve(*m)) } - /// Convert an extension *set* difference in terms of a sequence of fresh - /// metas with `Plus` constraints which each add only one extension req. fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) { - let mut last_meta = input; - // Create fresh metavariables with `Plus` constraints for - // each extension that should be added by the node - // Hence a extension delta [A, B] would lead to - // > ma = fresh_meta() - // > add_constraint(ma, Plus(a, input) - // > mb = fresh_meta() - // > add_constraint(mb, Plus(b, ma) - // > add_constraint(output, Equal(mb)) - for r in delta.0.into_iter() { - let curr_meta = self.fresh_meta(); - self.add_constraint(curr_meta, Constraint::Plus(r, last_meta)); - last_meta = curr_meta; - } - self.add_constraint(output, Constraint::Equal(last_meta)); + self.add_constraint( + output, + if delta.is_subset(&ExtensionSet::new()) { + Constraint::Equal(input) + } else { + Constraint::Plus(delta, input) + }, + ); } /// Return the metavariable corresponding to the given location on the @@ -510,8 +501,7 @@ impl UnificationContext { // to a set which already contained it. Constraint::Plus(r, other_meta) => { if let Some(rs) = self.get_solution(other_meta) { - let mut rrs = rs.clone(); - rrs.insert(r); + let rrs = rs.clone().union(r); match self.get_solution(&meta) { // Let's check that this is right? Some(rs) => { @@ -664,19 +654,19 @@ impl UnificationContext { // Handle the case where the constraints for `m` contain a self // reference, i.e. "m = Plus(E, m)", in which case the variable // should be instantiated to E rather than the empty set. - let solution = - ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map( - |c| match c { - // If `m` has been merged, [`self.variables`] entry - // will have already been updated to the merged - // value by [`self.merge_equal_metas`] so we don't - // need to worry about resolving it. - Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => { - Some(x.clone()) - } - _ => None, - }, - )); + let solution = self + .get_constraints(&m) + .unwrap() + .iter() + .filter_map(|c| match c { + // If `m` has been merged, [`self.variables`] entry + // will have already been updated to the merged + // value by [`self.merge_equal_metas`] so we don't + // need to worry about resolving it. + Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x), + _ => None, + }) + .fold(ExtensionSet::new(), ExtensionSet::union); self.add_solution(m, solution); } } @@ -690,6 +680,7 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; + use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; @@ -807,8 +798,14 @@ mod test { ctx.solved.insert(metas[2], ExtensionSet::singleton(&A)); ctx.add_constraint(metas[1], Constraint::Equal(metas[2])); - ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2])); - ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0])); + ctx.add_constraint( + metas[0], + Constraint::Plus(ExtensionSet::singleton(&B), metas[2]), + ); + ctx.add_constraint( + metas[4], + Constraint::Plus(ExtensionSet::singleton(&C), metas[0]), + ); ctx.add_constraint(metas[3], Constraint::Equal(metas[4])); ctx.add_constraint(metas[5], Constraint::Equal(metas[0])); ctx.main_loop()?; @@ -881,8 +878,8 @@ mod test { .insert((NodeIndex::new(4).into(), Direction::Incoming), ab); ctx.variables.insert(a); ctx.variables.insert(b); - ctx.add_constraint(ab, Constraint::Plus(A, b)); - ctx.add_constraint(ab, Constraint::Plus(B, a)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b)); + ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a)); let solution = ctx.main_loop()?; // We'll only find concrete solutions for the Incoming extension reqs of // the main node created by `Hugr::default` From a0471593208ac5d2a0305ca7cc380e29570a2e77 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 12:48:06 +0000 Subject: [PATCH 03/10] Add ExtensionSet::is_empty --- src/extension.rs | 5 +++++ src/extension/infer.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/extension.rs b/src/extension.rs index c94621891..5f7052277 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -372,6 +372,11 @@ impl ExtensionSet { pub fn iter(&self) -> impl Iterator { self.0.iter() } + + /// True if this set contains no [ExtensionId]s + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } impl Display for ExtensionSet { diff --git a/src/extension/infer.rs b/src/extension/infer.rs index fa503addd..c0f4db047 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -233,7 +233,7 @@ impl UnificationContext { fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) { self.add_constraint( output, - if delta.is_subset(&ExtensionSet::new()) { + if delta.is_empty() { Constraint::Equal(input) } else { Constraint::Plus(delta, input) From 7adc30b566c83b8609d752050d14d67b2a5d7155 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 12:51:19 +0000 Subject: [PATCH 04/10] Inline gen_union_constraint --- src/extension/infer.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index c0f4db047..d67a4d74c 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -230,17 +230,6 @@ impl UnificationContext { self.solved.get(&self.resolve(*m)) } - fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) { - self.add_constraint( - output, - if delta.is_empty() { - Constraint::Equal(input) - } else { - Constraint::Plus(delta, input) - }, - ); - } - /// Return the metavariable corresponding to the given location on the /// graph, either by making a new meta, or looking it up fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta { @@ -302,11 +291,13 @@ impl UnificationContext { match node_type.signature() { // Input extensions are open None => { - self.gen_union_constraint( - m_input, - m_output, - node_type.op_signature().extension_reqs, - ); + let delta = node_type.op_signature().extension_reqs; + let c = if delta.is_empty() { + Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) + }; + self.add_constraint(m_output, c); if matches!( node_type.tag(), OpTag::Alias | OpTag::Function | OpTag::FuncDefn From a50b9629bda46d6800fa91bbe7fe5dcafa528891 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:04:11 +0000 Subject: [PATCH 05/10] Ach, try to work out feature being unstable in 1.70 --- src/extension.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index 5f7052277..db7ad25fe 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -6,7 +6,7 @@ use std::collections::hash_map::{DefaultHasher, Entry}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; -use std::hash::{BuildHasher, BuildHasherDefault}; +use std::hash::{BuildHasher, BuildHasherDefault, Hasher}; use std::sync::Arc; use smol_str::SmolStr; @@ -313,7 +313,11 @@ impl std::hash::Hash for ExtensionSet { let item_h = BuildHasherDefault::::default(); self.0 .iter() - .map(|e_id| item_h.hash_one(e_id)) + .map(|e_id| { + let mut h = item_h.build_hasher(); + e_id.hash(&mut h); + h.finish() + }) .fold(0, |a, b| a ^ b) .hash(state); } From d1cad80fbb875f5eba96ecd174c369dab236a7ee Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:36:43 +0000 Subject: [PATCH 06/10] Given we can't use hash_one (prev. commit), no point in BuildHasherDefault etc. --- src/extension.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index db7ad25fe..ddd2e8a34 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -6,7 +6,7 @@ use std::collections::hash_map::{DefaultHasher, Entry}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; -use std::hash::{BuildHasher, BuildHasherDefault, Hasher}; +use std::hash::Hasher; use std::sync::Arc; use smol_str::SmolStr; @@ -310,11 +310,10 @@ impl std::hash::Hash for ExtensionSet { // Hash each item individually and combine with a *weak* hash combiner // (i.e. that is associative and commutative, hence item iteration order doesn't matter). // Here we just use xor. - let item_h = BuildHasherDefault::::default(); self.0 .iter() .map(|e_id| { - let mut h = item_h.build_hasher(); + let mut h = DefaultHasher::new(); e_id.hash(&mut h); h.finish() }) From 828f70a4ba42d6a00e1acffd99c7f260e510cfdc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:46:29 +0000 Subject: [PATCH 07/10] Drop ExtensionSet::new_from_extensions - use from_iter --- src/extension.rs | 5 ----- src/std_extensions/arithmetic/conversions.rs | 6 ++---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index ddd2e8a34..780dcc501 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -328,11 +328,6 @@ impl ExtensionSet { Self(HashSet::new()) } - /// Creates a new extension set from some extensions. - pub fn new_from_extensions(extensions: impl Into>) -> Self { - Self(extensions.into()) - } - /// Adds a extension to the set. pub fn insert(&mut self, extension: &ExtensionId) { self.0.insert(extension.clone()); diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 3918fd8ae..d07b97f62 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,7 +1,5 @@ //! Conversions between integer and floating-point values. -use std::collections::HashSet; - use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, type_row, @@ -39,10 +37,10 @@ fn itof_sig(arg_values: &[TypeArg]) -> Result { pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( EXTENSION_ID, - ExtensionSet::new_from_extensions(HashSet::from_iter(vec![ + ExtensionSet::from_iter(vec![ super::int_types::EXTENSION_ID, super::float_types::EXTENSION_ID, - ])), + ]), ); extension From 04d8a7056f1f3726893863e8e9c68fbf03e34cbe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:48:15 +0000 Subject: [PATCH 08/10] ExtensionSet::missing_from use from_iter --- src/extension.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension.rs b/src/extension.rs index 780dcc501..ec12ec572 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -363,7 +363,7 @@ impl ExtensionSet { /// The things in other which are in not in self pub fn missing_from(&self, other: &Self) -> Self { - ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned())) + ExtensionSet::from_iter(other.0.difference(&self.0).cloned()) } /// Iterate over the contained ExtensionIds From bbde64e38e324ad2121b861a0d111587cdbd05c9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:49:22 +0000 Subject: [PATCH 09/10] ExtensionSet: use BTreeSet not HashSet, derive Hash --- src/extension.rs | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/src/extension.rs b/src/extension.rs index ec12ec572..84bb447df 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -3,10 +3,9 @@ //! TODO: YAML declaration and parsing. This should be similar to a plugin //! system (outside the `types` module), which also parses nested [`OpDef`]s. -use std::collections::hash_map::{DefaultHasher, Entry}; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::hash_map::Entry; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::{Debug, Display, Formatter}; -use std::hash::Hasher; use std::sync::Arc; use smol_str::SmolStr; @@ -302,30 +301,13 @@ pub enum ExtensionBuildError { } /// A set of extensions identified by their unique [`ExtensionId`]. -#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -pub struct ExtensionSet(HashSet); - -impl std::hash::Hash for ExtensionSet { - fn hash(&self, state: &mut H) { - // Hash each item individually and combine with a *weak* hash combiner - // (i.e. that is associative and commutative, hence item iteration order doesn't matter). - // Here we just use xor. - self.0 - .iter() - .map(|e_id| { - let mut h = DefaultHasher::new(); - e_id.hash(&mut h); - h.finish() - }) - .fold(0, |a, b| a ^ b) - .hash(state); - } -} +#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct ExtensionSet(BTreeSet); impl ExtensionSet { /// Creates a new empty extension set. pub fn new() -> Self { - Self(HashSet::new()) + Self(BTreeSet::new()) } /// Adds a extension to the set. @@ -385,6 +367,6 @@ impl Display for ExtensionSet { impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { - Self(HashSet::from_iter(iter)) + Self(BTreeSet::from_iter(iter)) } } From 0e077ac5126c2204cbedceb6b1314a52f71e916f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 1 Nov 2023 13:51:05 +0000 Subject: [PATCH 10/10] Driveby: make new() return const --- src/extension.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension.rs b/src/extension.rs index 84bb447df..dd7401cfe 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -306,7 +306,7 @@ pub struct ExtensionSet(BTreeSet); impl ExtensionSet { /// Creates a new empty extension set. - pub fn new() -> Self { + pub const fn new() -> Self { Self(BTreeSet::new()) }