Skip to content

Commit

Permalink
cherry-pick #636
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Nov 1, 2023
1 parent 5950723 commit 568df82
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 63 deletions.
24 changes: 12 additions & 12 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! system (outside the `types` module), which also parses nested [`OpDef`]s.
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

Expand Down Expand Up @@ -301,18 +301,13 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(HashSet<ExtensionId>);
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

impl ExtensionSet {
/// Creates a new empty extension set.
pub fn new() -> Self {
Self(HashSet::new())
}

/// Creates a new extension set from some extensions.
pub fn new_from_extensions(extensions: impl Into<HashSet<ExtensionId>>) -> Self {
Self(extensions.into())
pub const fn new() -> Self {
Self(BTreeSet::new())
}

/// Adds a extension to the set.
Expand Down Expand Up @@ -350,13 +345,18 @@ impl ExtensionSet {

/// The things in other which are in not in self
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned()))
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
}

/// Iterate over the contained ExtensionIds
pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.iter()
}

/// True if this set contains no [ExtensionId]s
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

impl Display for ExtensionSet {
Expand All @@ -367,6 +367,6 @@ impl Display for ExtensionSet {

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(HashSet::from_iter(iter))
Self(BTreeSet::from_iter(iter))
}
}
82 changes: 35 additions & 47 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! depend on these open variables, then the validation check for extensions
//! will succeed regardless of what the variable is instantiated to.
use super::{ExtensionId, ExtensionSet};
use super::ExtensionSet;
use crate::{
hugr::views::HugrView,
ops::{OpTag, OpTrait},
Expand Down Expand Up @@ -65,8 +65,8 @@ impl Meta {
enum Constraint {
/// A variable has the same value as another variable
Equal(Meta),
/// Variable extends the value of another by one extension
Plus(ExtensionId, Meta),
/// Variable extends the value of another by a set of extensions
Plus(ExtensionSet, Meta),
}

#[derive(Debug, Clone, PartialEq, Error)]
Expand Down Expand Up @@ -235,26 +235,6 @@ impl UnificationContext {
self.solved.get(&self.resolve(*m))
}

/// Convert an extension *set* difference in terms of a sequence of fresh
/// metas with `Plus` constraints which each add only one extension req.
fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) {
let mut last_meta = input;
// Create fresh metavariables with `Plus` constraints for
// each extension that should be added by the node
// Hence a extension delta [A, B] would lead to
// > ma = fresh_meta()
// > add_constraint(ma, Plus(a, input)
// > mb = fresh_meta()
// > add_constraint(mb, Plus(b, ma)
// > add_constraint(output, Equal(mb))
for r in delta.0.into_iter() {
let curr_meta = self.fresh_meta();
self.add_constraint(curr_meta, Constraint::Plus(r, last_meta));
last_meta = curr_meta;
}
self.add_constraint(output, Constraint::Equal(last_meta));
}

/// Return the metavariable corresponding to the given location on the
/// graph, either by making a new meta, or looking it up
fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta {
Expand Down Expand Up @@ -316,11 +296,13 @@ impl UnificationContext {
match node_type.signature() {
// Input extensions are open
None => {
self.gen_union_constraint(
m_input,
m_output,
node_type.op_signature().extension_reqs,
);
let delta = node_type.op_signature().extension_reqs;
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
Expand Down Expand Up @@ -530,8 +512,7 @@ impl UnificationContext {
// to a set which already contained it.
Constraint::Plus(r, other_meta) => {
if let Some(rs) = self.get_solution(other_meta) {
let mut rrs = rs.clone();
rrs.insert(r);
let rrs = rs.clone().union(r);
match self.get_solution(&meta) {
// Let's check that this is right?
Some(rs) => {
Expand Down Expand Up @@ -693,19 +674,19 @@ impl UnificationContext {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution =
ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map(
|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => {
Some(x.clone())
}
_ => None,
},
));
let solution = self
.get_constraints(&m)
.unwrap()
.iter()
.filter_map(|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
_ => None,
})
.fold(ExtensionSet::new(), ExtensionSet::union);
self.add_solution(m, solution);
}
}
Expand All @@ -719,6 +700,7 @@ mod test {

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::HugrError;
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
Expand Down Expand Up @@ -837,8 +819,14 @@ mod test {

ctx.solved.insert(metas[2], ExtensionSet::singleton(&A));
ctx.add_constraint(metas[1], Constraint::Equal(metas[2]));
ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0]));
ctx.add_constraint(
metas[0],
Constraint::Plus(ExtensionSet::singleton(&B), metas[2]),
);
ctx.add_constraint(
metas[4],
Constraint::Plus(ExtensionSet::singleton(&C), metas[0]),
);
ctx.add_constraint(metas[3], Constraint::Equal(metas[4]));
ctx.add_constraint(metas[5], Constraint::Equal(metas[0]));
ctx.main_loop()?;
Expand Down Expand Up @@ -911,8 +899,8 @@ mod test {
.insert((NodeIndex::new(4).into(), Direction::Incoming), ab);
ctx.variables.insert(a);
ctx.variables.insert(b);
ctx.add_constraint(ab, Constraint::Plus(A, b));
ctx.add_constraint(ab, Constraint::Plus(B, a));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
Expand Down
6 changes: 2 additions & 4 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! Conversions between integer and floating-point values.
use std::collections::HashSet;

use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
type_row,
Expand Down Expand Up @@ -39,10 +37,10 @@ fn itof_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::new_from_extensions(HashSet::from_iter(vec![
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
])),
]),
);

extension
Expand Down

0 comments on commit 568df82

Please sign in to comment.