Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Values (and hence Consts) know their extensions #733

Merged
merged 12 commits into from
Dec 12, 2023
12 changes: 6 additions & 6 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down Expand Up @@ -887,8 +887,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
Expand Down Expand Up @@ -929,8 +929,8 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down
28 changes: 8 additions & 20 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(
&mut self,
constant: ops::Const,
extensions: impl Into<Option<ExtensionSet>>,
) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, extensions.into()))?;
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -356,20 +352,16 @@ pub trait Dataflow: Container {
fn load_const(&mut self, cid: &ConstID) -> Result<Wire, BuildError> {
let const_node = cid.node();
let nodetype = self.hugr().get_nodetype(const_node);
let input_extensions = nodetype.input_extensions().cloned();
let op: ops::Const = nodetype
.op()
.clone()
.try_into()
.expect("ConstID does not refer to Const op.");

let load_n = self.add_dataflow_node(
NodeType::new(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
input_extensions,
),
let load_n = self.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;
Expand All @@ -382,12 +374,8 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(
&mut self,
constant: ops::Const,
extensions: ExtensionSet,
) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant, extensions)?;
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}

Expand Down
8 changes: 3 additions & 5 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ mod test {
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?;
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand All @@ -398,8 +398,7 @@ mod test {
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];

let mut entry_b =
Expand Down Expand Up @@ -427,8 +426,7 @@ mod test {
#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ mod test {
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?;
let tru_const = fbuild.add_constant(Const::true_val())?;
let _fdef = {
let const_wire = fbuild.load_const(&tru_const)?;
let [int] = fbuild.input_wires_arr();
Expand Down
13 changes: 3 additions & 10 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(
ConstUsize::new(1).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -148,8 +145,7 @@ mod test {
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let signature = loop_b.loop_signature()?.clone();
let const_val = Const::true_val();
let const_wire =
loop_b.add_load_const(Const::true_val(), ExtensionSet::new())?;
let const_wire = loop_b.add_load_const(Const::true_val())?;
let lift_node = loop_b.add_dataflow_op(
ops::LeafOp::Lift {
type_row: vec![const_val.const_type().clone()].into(),
Expand Down Expand Up @@ -177,10 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(
ConstUsize::new(2).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
10 changes: 10 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ impl ExtensionSet {
self
}

/// Returns the union of an arbitrary collection of [ExtensionSet]s
pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
// `union` clones the receiver, which we do not need to do here
let mut res = ExtensionSet::new();
for s in sets {
res.0.extend(s.0)
}
res
}

/// The things in other which are in not in self
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
Expand Down
16 changes: 6 additions & 10 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,11 @@ impl UnificationContext {
match node_type.io_extensions() {
// Input extensions are open
None => {
let c = if let Some(sig) = node_type.op_signature() {
let delta = sig.extension_reqs;
if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
}
} else {
let delta = node_type.op().extension_delta();
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
}
Expand Down Expand Up @@ -703,7 +699,7 @@ impl UnificationContext {
});

let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
let solution = rs.iter().fold(ExtensionSet::new(), ExtensionSet::union);
let solution = ExtensionSet::union_over(rs);
let unresolved_metas = other_ms
.into_iter()
.filter(|other_m| m != *other_m)
Expand Down Expand Up @@ -731,7 +727,7 @@ impl UnificationContext {
Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)),
Constraint::Equal(_) => None,
})
.fold(ExtensionSet::new(), |a, b| a.union(b));
.fold(ExtensionSet::new(), ExtensionSet::union);

for m in cc.iter() {
self.add_solution(*m, combined_solution.clone());
Expand Down
6 changes: 5 additions & 1 deletion src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
Extension,
};

use super::{ExtensionRegistry, SignatureError, SignatureFromArgs};
use super::{ExtensionRegistry, ExtensionSet, SignatureError, SignatureFromArgs};
struct ArrayOpCustom;

const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()];
Expand Down Expand Up @@ -181,6 +181,10 @@ impl CustomConst for ConstUsize {
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::values::downcast_equal_consts(self, other)
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&PRELUDE_ID)
}
}

impl KnownTypeConst for ConstUsize {
Expand Down
13 changes: 3 additions & 10 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,9 @@ impl NodeType {
/// `None`` if the [Self::input_extensions] is `None`.
/// Otherwise, will return Some, with the output extensions computed from the node's delta
pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> {
self.input_extensions.as_ref().map(|e| {
(
e,
self.op
.dataflow_signature()
.map(|ft| ft.extension_reqs)
.unwrap_or_default()
.union(e),
)
})
self.input_extensions
.as_ref()
.map(|e| (e, self.op.extension_delta().union(e)))
}

/// Gets the underlying [OpType] i.e. without any [input_extensions]
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Rewrite for OutlineCfg {
.unwrap();
let cfg = cfg.finish_sub_container().unwrap();
let unit_sum = new_block_bldr
.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())
.add_constant(ops::Const::unary_unit_sum())
.unwrap();
let pred_wire = new_block_bldr.load_const(&unit_sum).unwrap();
new_block_bldr
Expand Down
7 changes: 4 additions & 3 deletions src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,15 @@ mod test {
.unwrap()
.into();
let just_list = TypeRow::from(vec![listy.clone()]);
let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME);
let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]);

let mut cfg = CFGBuilder::new(
FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset),
// One might expect an extension_delta of "collections" here, but push/pop
// have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388
FunctionType::new_endo(just_list.clone()),
)?;

let pred_const = cfg.add_constant(ops::Const::unary_unit_sum(), None)?;
let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?;

let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {
let empty_list = Value::Extension {
c: (Box::new(collections::ListValue::new(vec![])),),
};
let cst = def.add_load_const(Const::new(empty_list, list_of_var)?, just_colns)?;
let cst = def.add_load_const(Const::new(empty_list, list_of_var)?)?;
let res = def.finish_hugr_with_outputs([cst], &reg);
assert_matches!(
res.unwrap_err(),
Expand Down
18 changes: 11 additions & 7 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,22 @@ fn value_types() {
#[rustversion::since(1.75)] // uses impl in return position
#[test]
fn static_targets() {
use crate::extension::prelude::{ConstUsize, USIZE_T};
use crate::extension::{
prelude::{ConstUsize, PRELUDE_ID, USIZE_T},
ExtensionSet,
};
use itertools::Itertools;
let mut dfg = DFGBuilder::new(
FunctionType::new(type_row![], type_row![USIZE_T])
.with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)),
)
.unwrap();

let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap();
let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap();

let load = dfg.load_const(&c).unwrap();

let h = dfg
.finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY)
.unwrap();
let h = dfg.finish_prelude_hugr_with_outputs([load]).unwrap();

assert_eq!(h.static_source(load.node()), Some(c.node()));

Expand Down
8 changes: 8 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod leaf;
pub mod module;
pub mod tag;
pub mod validate;
use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type};
use crate::{Direction, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};
Expand Down Expand Up @@ -278,6 +279,13 @@ pub trait OpTrait {
fn dataflow_signature(&self) -> Option<FunctionType> {
None
}

/// The delta between the input extensions specified for a node,
/// and the output extensions calculated for that node
fn extension_delta(&self) -> ExtensionSet {
ExtensionSet::new()
}

/// The edge kind for the non-dataflow or constant inputs of the operation,
/// not described by the signature.
///
Expand Down
Loading