From 6951a2e0553a55bc795af7c9cd1546ab81b34cbe Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 13 May 2024 07:36:11 +0100 Subject: [PATCH 1/6] Add `validate` to `HugrView` --- hugr/src/hugr/views.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/hugr/src/hugr/views.rs b/hugr/src/hugr/views.rs index 5b3354dd9..323ba12c5 100644 --- a/hugr/src/hugr/views.rs +++ b/hugr/src/hugr/views.rs @@ -24,7 +24,8 @@ use itertools::{Itertools, MapInto}; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; -use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE}; +use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE}; +use crate::extension::ExtensionRegistry; use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; @@ -460,6 +461,18 @@ pub trait HugrView: sealed::HugrInternals { self.value_types(node, Direction::Outgoing) .map(|(p, t)| (p.as_outgoing().unwrap(), t)) } + + /// Check the validity of the underlying HUGR. + fn validate(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> { + self.base_hugr().validate(reg) + } + + /// Check the validity of the underlying HUGR, but don't check consistency + /// of extension requirements between connected nodes or between parents and + /// children. + fn validate_no_extensions(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> { + self.base_hugr().validate_no_extensions(reg) + } } /// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s From 02a1a27320df326e7ac6c646573f989a97c3cf0c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 13 May 2024 09:43:01 +0100 Subject: [PATCH 2/6] Add verification to `constant_fold_pass`. Add a failing test --- hugr/src/algorithm.rs | 31 ++++ hugr/src/algorithm/const_fold.rs | 158 +++++++++++++++--- hugr/src/hugr/views.rs | 4 +- .../arithmetic/int_ops/const_fold/test.rs | 2 + 4 files changed, 171 insertions(+), 24 deletions(-) diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs index 585e25a01..403632ecd 100644 --- a/hugr/src/algorithm.rs +++ b/hugr/src/algorithm.rs @@ -4,3 +4,34 @@ pub mod const_fold; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; + +#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] +/// A type for algorithms to take as configuration, specifying how much +/// verification they should do. Algorithms that accept this configuration +/// should at least verify that input HUGRs are valid, and that output HUGRs are +/// valid. +/// +/// The default level is `None` because verification can be expensive. +pub enum VerifyLevel { + /// Do no verification. + None, + /// Verify using [HugrView::validate_no_extensions]. This is useful when you + /// do not expect valid Extension annotations on Nodes. + /// + /// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions + WithoutExtensions, + /// Verify using [HugrView::validate]. + /// + /// [HugrView::validate]: crate::HugrView::validate + WithExtensions, +} + +impl Default for VerifyLevel { + fn default() -> Self { + if cfg!(test) { + Self::WithoutExtensions + } else { + Self::None + } + } +} diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index b32d54a8c..a50a9d164 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -3,8 +3,11 @@ use std::collections::{BTreeSet, HashMap}; use itertools::Itertools; +use thiserror::Error; +use crate::hugr::{SimpleReplacementError, ValidationError}; use crate::types::SumType; +use crate::Direction; use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::{ConstFoldResult, ExtensionRegistry}, @@ -19,6 +22,89 @@ use crate::{ Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; +use super::VerifyLevel; + +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum ConstFoldError { + #[error("Failed to verify {label} HUGR: {err}")] + VerifyError { + label: String, + #[source] + err: ValidationError, + }, + #[error(transparent)] + SimpleReplaceError(#[from] SimpleReplacementError), +} + +impl ConstFoldError { + fn verify_err(label: impl Into, err: ValidationError) -> Self { + Self::VerifyError { + label: label.into(), + err, + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +/// TODO +pub struct ConstFoldConfig { + verify: VerifyLevel, +} + +impl ConstFoldConfig { + /// TODO + pub fn new() -> Self { + Self::default() + } + + /// TODO + pub fn with_verify(mut self, verify: VerifyLevel) -> Self { + self.verify = verify; + self + } + + fn verify_impl( + &self, + label: &str, + h: &impl HugrView, + reg: &ExtensionRegistry, + ) -> Result<(), ConstFoldError> { + match self.verify { + VerifyLevel::None => Ok(()), + VerifyLevel::WithoutExtensions => h.validate_no_extensions(reg), + VerifyLevel::WithExtensions => h.validate(reg), + } + .map_err(|err| ConstFoldError::verify_err(label, err)) + } + + /// TODO + pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> { + self.verify_impl("input", h, reg)?; + loop { + // would be preferable if the candidates were updated to be just the + // neighbouring nodes of those added. + let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); + if rewrites.is_empty() { + break; + } + for (replace, removes) in rewrites { + h.apply_rewrite(replace)?; + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + if h.apply_rewrite(RemoveConst(const_node)).is_err() { + // const cannot be removed - no problem + continue; + } + } + } + } + } + self.verify_impl("output", h, reg) + } +} + /// Tag some output constants with [`OutgoingPort`] inferred from the ordering. fn out_row(consts: impl IntoIterator) -> ConstFoldResult { let vec = consts @@ -43,9 +129,10 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { .map(|(_, c)| c) .collect() } + /// For a given op and consts, attempt to evaluate the op. pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - match op { + let fold_result = match op { OpType::Noop { .. } => out_row([consts.first()?.1.clone()]), OpType::MakeTuple { .. } => { out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())]) @@ -69,7 +156,10 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR ext_op.constant_fold(consts) } _ => None, - } + }; + assert!(fold_result.as_ref().map_or(true, |x| x.len() + == op.value_port_count(Direction::Outgoing))); + fold_result } /// Generate a graph that loads and outputs `consts` in order, validating @@ -184,27 +274,8 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< } /// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { - loop { - // would be preferable if the candidates were updated to be just the - // neighbouring nodes of those added. - let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); - if rewrites.is_empty() { - break; - } - for (replace, removes) in rewrites { - h.apply_rewrite(replace).unwrap(); - for rem in removes { - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - if h.apply_rewrite(RemoveConst(const_node)).is_err() { - // const cannot be removed - no problem - continue; - } - } - } - } - } +pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { + ConstFoldConfig::default().run(h, reg).unwrap() } #[cfg(test)] @@ -395,4 +466,45 @@ mod test { let expected = Value::false_val(); assert_fully_folded(&h, &expected); } + + #[test] + #[should_panic] + fn orphan_output() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // x2 := or(x0,x1) + // output x2 == true; + // + // We arange things so that the `or` folds away first, leaving the not + // with no outputs. + use crate::hugr::NodeType; + use crate::ops::handle::NodeHandle; + + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let true_wire = build.add_load_value(Value::true_val()); + // this Not will be manually replaced + let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); + let r = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + [true_wire, orig_not.out_wire(0)], + ) + .unwrap(); + let or_node = r.node(); + let parent = build.dfg_node; + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + + // we delete the original Not and create a new One. This means it will be + // traversed by `constant_fold_pass` after the Or. + let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); + h.connect(true_wire.node(), true_wire.source(), new_not, 0); + h.disconnect(or_node, IncomingPort::from(1)); + h.connect(new_not, 0, or_node, 1); + h.remove_node(orig_not.node()); + constant_fold_pass(&mut h, ®); + assert_fully_folded(&h, &Value::true_val()) + } } diff --git a/hugr/src/hugr/views.rs b/hugr/src/hugr/views.rs index 323ba12c5..a22cd309c 100644 --- a/hugr/src/hugr/views.rs +++ b/hugr/src/hugr/views.rs @@ -24,7 +24,9 @@ use itertools::{Itertools, MapInto}; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView}; -use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE}; +use super::{ + Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE, +}; use crate::extension::ExtensionRegistry; use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs index 5241bf230..6984783c3 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -62,6 +62,7 @@ fn test_fold_iwiden_s() { } #[test] +#[should_panic] fn test_fold_inarrow_u() { // pseudocode: // @@ -90,6 +91,7 @@ fn test_fold_inarrow_u() { } #[test] +#[should_panic] fn test_fold_inarrow_s() { // pseudocode: // From b2c815b6cf10bcf4a5958ed8fdbf1fadfa19957c Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 13 May 2024 10:11:20 +0100 Subject: [PATCH 3/6] fix constant folding using invalid rewrites --- hugr/src/algorithm/const_fold.rs | 59 +++++++++++++++----------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index a50a9d164..ce4466b55 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -53,12 +53,12 @@ pub struct ConstFoldConfig { } impl ConstFoldConfig { - /// TODO + /// Create a new `ConstFoldConfig` with default configuration. pub fn new() -> Self { Self::default() } - /// TODO + /// Build a `ConstFoldConfig` with the given [VerifyLevel]. pub fn with_verify(mut self, verify: VerifyLevel) -> Self { self.verify = verify; self @@ -78,26 +78,27 @@ impl ConstFoldConfig { .map_err(|err| ConstFoldError::verify_err(label, err)) } - /// TODO + /// Run the Constant Folding pass. pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> { self.verify_impl("input", h, reg)?; loop { - // would be preferable if the candidates were updated to be just the - // neighbouring nodes of those added. - let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); - if rewrites.is_empty() { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { break; - } - for (replace, removes) in rewrites { - h.apply_rewrite(replace)?; - for rem in removes { - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - if h.apply_rewrite(RemoveConst(const_node)).is_err() { - // const cannot be removed - no problem - continue; - } - } + }; + h.apply_rewrite(replace)?; + for rem in removes { + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = h.apply_rewrite(RemoveConst(const_node)); } } } @@ -230,18 +231,16 @@ fn fold_op( }) .unzip(); // attempt to evaluate op - let folded = fold_leaf_op(neighbour_op, &in_consts)?; - let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); - let nu_out = op_outs + let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? .into_iter() .enumerate() - .filter_map(|(i, out)| { - // map from the ports the op was linked to, to the output ports of - // the replacement. - hugr.single_linked_input(op_node, out) - .map(|np| (np, i.into())) + .filter_map(|(i, (op_out, konst))| { + // for each used port of the op give the nu_out entry and the + // corresponding Value + hugr.single_linked_input(op_node, op_out) + .map(|np| ((np, i.into()), konst)) }) - .collect(); + .unzip(); let replacement = const_graph(consts, reg); let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) .expect("Operation should form valid subgraph."); @@ -262,11 +261,8 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; let load_op = hugr.get_optype(load_n).as_load_constant()?; let const_node = hugr - .linked_outputs(load_n, load_op.constant_port()) - .exactly_one() - .ok()? + .single_linked_output(load_n, load_op.constant_port())? .0; - let const_op = hugr.get_optype(const_node).as_const()?; // TODO avoid const clone here @@ -468,7 +464,6 @@ mod test { } #[test] - #[should_panic] fn orphan_output() { // pseudocode: // x0 := bool(true) From 862419f8456876f71d529702fa6b16d23a72a468 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 13 May 2024 10:22:31 +0100 Subject: [PATCH 4/6] fix inarrow const folding tests --- .../arithmetic/int_ops/const_fold.rs | 49 +++++++-------- .../arithmetic/int_ops/const_fold/test.rs | 62 ++++++++----------- 2 files changed, 49 insertions(+), 62 deletions(-) diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs index 0915a4737..8738e1872 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -16,6 +16,16 @@ use crate::{ use super::IntOpDef; +use lazy_static::lazy_static; + +lazy_static! { + static ref INARROW_ERROR_VALUE: Value = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + } + .into(); +} + fn bitmask_from_width(width: u64) -> u64 { debug_assert!(width <= 64); if width == 64 { @@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { let logwidth0: u8 = get_log_width(arg0).ok()?; let logwidth1: u8 = get_log_width(arg1).ok()?; let n0: &ConstInt = get_single_input_value(consts)?; + (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?; let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); let sum_type = sum_with_error(int_out_type.clone()); - let err_value = || { - let err_val = ConstError { - signal: 0, - message: "Integer too large to narrow".to_string(), - }; - Value::sum(1, [err_val.into()], sum_type.clone()) + + let mk_out_const = |i, mb_v: Result| { + mb_v.and_then(|v| Value::sum(i, [v], sum_type)) .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) }; let n0val: u64 = n0.value_u(); let out_const: Value = if n0val >> (1 << logwidth1) != 0 { - err_value() + mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone())) } else { - Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap()) + mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into)) }; - if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { - None - } else { - Some(vec![(0.into(), out_const)]) - } + Some(vec![(0.into(), out_const)]) }, ), }, @@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { let logwidth0: u8 = get_log_width(arg0).ok()?; let logwidth1: u8 = get_log_width(arg1).ok()?; let n0: &ConstInt = get_single_input_value(consts)?; + (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?; let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); let sum_type = sum_with_error(int_out_type.clone()); - let err_value = || { - let err_val = ConstError { - signal: 0, - message: "Integer too large to narrow".to_string(), - }; - Value::sum(1, [err_val.into()], sum_type.clone()) + let mk_out_const = |i, mb_v: Result| { + mb_v.and_then(|v| Value::sum(i, [v], sum_type)) .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) }; let n0val: i64 = n0.value_s(); let ub = 1i64 << ((1 << logwidth1) - 1); let out_const: Value = if n0val >= ub || n0val < -ub { - err_value() + mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone())) } else { - Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap()) + mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into)) }; - if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { - None - } else { - Some(vec![(0.into(), out_const)]) - } + Some(vec![(0.into(), out_const)]) }, ), }, diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs index 6984783c3..af7e7e75b 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -61,52 +61,38 @@ fn test_fold_iwiden_s() { assert_fully_folded(&h, &expected); } -#[test] -#[should_panic] -fn test_fold_inarrow_u() { - // pseudocode: - // - // x0 := int_u<5>(13); - // x1 := inarrow_u<5, 4>(x0); - // output x1 == int_u<4>(13); - let sum_type = sum_with_error(INT_TYPES[4].to_owned()); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap())); - let x1 = build - .add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(4, 13).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -#[should_panic] -fn test_fold_inarrow_s() { +#[rstest] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)] +fn test_fold_inarrow, E: std::fmt::Debug>( + #[case] mk_const: impl Fn(u8, I) -> Result, + #[case] op_def: IntOpDef, + #[case] from_log_width: u8, + #[case] to_log_width: u8, + #[case] val: I, + #[case] succeeds: bool, +) { // pseudocode: // // x0 := int_s<5>(-3); // x1 := inarrow_s<5, 4>(x0); // output x1 == int_s<4>(-3); - let sum_type = sum_with_error(INT_TYPES[4].to_owned()); + let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned()); let mut build = DFGBuilder::new(FunctionType::new( type_row![], vec![sum_type.clone().into()], )) .unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap())); + let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into()); let x1 = build - .add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0]) + .add_dataflow_op( + op_def.with_two_log_widths(from_log_width, to_log_width), + [x0], + ) .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), @@ -115,7 +101,11 @@ fn test_fold_inarrow_s() { .unwrap(); let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(4, -3).unwrap()); + let expected = if succeeds { + Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap() + } else { + Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap() + }; assert_fully_folded(&h, &expected); } From 1e7a6f564221639dae900175d3eec8bf75546af2 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 13 May 2024 12:08:59 +0100 Subject: [PATCH 5/6] address review --- hugr/src/algorithm/const_fold.rs | 50 ++++++++++++++++++- .../arithmetic/int_ops/const_fold/test.rs | 12 ++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index ce4466b55..aafdb81f7 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -47,7 +47,7 @@ impl ConstFoldError { } #[derive(Debug, Clone, Copy, Default)] -/// TODO +/// A configuration for the Constant Folding pass. pub struct ConstFoldConfig { verify: VerifyLevel, } @@ -96,6 +96,10 @@ impl ConstFoldConfig { }; h.apply_rewrite(replace)?; for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. if let Ok(const_node) = h.apply_rewrite(rem) { // if the LoadConst was removed, try removing the Const too. let _ = h.apply_rewrite(RemoveConst(const_node)); @@ -502,4 +506,48 @@ mod test { constant_fold_pass(&mut h, ®); assert_fully_folded(&h, &Value::true_val()) } + + #[test] + fn test_folding_pass_issue_996() { + // pseudocode: + // + // x0 := 3.0 + // x1 := 4.0 + // x2 := fne(x0, x1); // true + // x3 := flt(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := -10.0 + // x6 := flt(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); + let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); + let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); + let x4 = build + .add_dataflow_op( + NaryLogic::And.with_n_inputs(2), + x2.outputs().chain(x3.outputs()), + ) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); + let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); + let x7 = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + x4.outputs().chain(x6.outputs()), + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } } diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs index af7e7e75b..959240a51 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -76,11 +76,19 @@ fn test_fold_inarrow, E: std::fmt::Debug>( #[case] val: I, #[case] succeeds: bool, ) { - // pseudocode: + // For the first case, pseudocode: // // x0 := int_s<5>(-3); // x1 := inarrow_s<5, 4>(x0); - // output x1 == int_s<4>(-3); + // output x1 == sum(-3)]>; + // + // Other cases vary by: + // (mk_const, op_def) => create signed or unsigned constants, create + // inarrow_s or inarrow_u ops; + // (from_log_width, to_log_width) => the args to use to create the op; + // val => the value to pass to the op + // succeeds => whether to expect a int variant or an error + // variant. let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned()); let mut build = DFGBuilder::new(FunctionType::new( type_row![], From 20b35d67cbf3c778abd67757c0c4d8cf1954a9e0 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Tue, 14 May 2024 10:18:53 +0100 Subject: [PATCH 6/6] remove `VerifyLevel` etc, assert! -> debug_assert! --- hugr/src/algorithm.rs | 31 --------- hugr/src/algorithm/const_fold.rs | 116 +++++++++++-------------------- 2 files changed, 39 insertions(+), 108 deletions(-) diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs index 403632ecd..585e25a01 100644 --- a/hugr/src/algorithm.rs +++ b/hugr/src/algorithm.rs @@ -4,34 +4,3 @@ pub mod const_fold; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; - -#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] -/// A type for algorithms to take as configuration, specifying how much -/// verification they should do. Algorithms that accept this configuration -/// should at least verify that input HUGRs are valid, and that output HUGRs are -/// valid. -/// -/// The default level is `None` because verification can be expensive. -pub enum VerifyLevel { - /// Do no verification. - None, - /// Verify using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - /// - /// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions - WithoutExtensions, - /// Verify using [HugrView::validate]. - /// - /// [HugrView::validate]: crate::HugrView::validate - WithExtensions, -} - -impl Default for VerifyLevel { - fn default() -> Self { - if cfg!(test) { - Self::WithoutExtensions - } else { - Self::None - } - } -} diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index aafdb81f7..f2b0dab24 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -22,8 +22,6 @@ use crate::{ Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; -use super::VerifyLevel; - #[derive(Error, Debug)] #[allow(missing_docs)] pub enum ConstFoldError { @@ -37,79 +35,6 @@ pub enum ConstFoldError { SimpleReplaceError(#[from] SimpleReplacementError), } -impl ConstFoldError { - fn verify_err(label: impl Into, err: ValidationError) -> Self { - Self::VerifyError { - label: label.into(), - err, - } - } -} - -#[derive(Debug, Clone, Copy, Default)] -/// A configuration for the Constant Folding pass. -pub struct ConstFoldConfig { - verify: VerifyLevel, -} - -impl ConstFoldConfig { - /// Create a new `ConstFoldConfig` with default configuration. - pub fn new() -> Self { - Self::default() - } - - /// Build a `ConstFoldConfig` with the given [VerifyLevel]. - pub fn with_verify(mut self, verify: VerifyLevel) -> Self { - self.verify = verify; - self - } - - fn verify_impl( - &self, - label: &str, - h: &impl HugrView, - reg: &ExtensionRegistry, - ) -> Result<(), ConstFoldError> { - match self.verify { - VerifyLevel::None => Ok(()), - VerifyLevel::WithoutExtensions => h.validate_no_extensions(reg), - VerifyLevel::WithExtensions => h.validate(reg), - } - .map_err(|err| ConstFoldError::verify_err(label, err)) - } - - /// Run the Constant Folding pass. - pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> { - self.verify_impl("input", h, reg)?; - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { - break; - }; - h.apply_rewrite(replace)?; - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = h.apply_rewrite(RemoveConst(const_node)); - } - } - } - self.verify_impl("output", h, reg) - } -} - /// Tag some output constants with [`OutgoingPort`] inferred from the ordering. fn out_row(consts: impl IntoIterator) -> ConstFoldResult { let vec = consts @@ -162,7 +87,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR } _ => None, }; - assert!(fold_result.as_ref().map_or(true, |x| x.len() + debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() == op.value_port_count(Direction::Outgoing))); fold_result } @@ -275,7 +200,44 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< /// Exhaustively apply constant folding to a HUGR. pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstFoldConfig::default().run(h, reg).unwrap() + #[cfg(test)] + let verify = |label, h: &H| { + h.validate_no_extensions(reg).unwrap_or_else(|err| { + panic!( + "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", + h.mermaid_string() + ) + }) + }; + #[cfg(test)] + verify("input", h); + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { + break; + }; + h.apply_rewrite(replace).unwrap(); + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = h.apply_rewrite(RemoveConst(const_node)); + } + } + } + #[cfg(test)] + verify("output", h); } #[cfg(test)]