From 3876a73c54ddc46cba735d5cfe81e5c9eeb135c9 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 8 Jan 2024 15:57:59 +0000 Subject: [PATCH] [feat] constant folding for logic extension --- src/algorithm/const_fold.rs | 28 ++++++++++++++++++++++- src/std_extensions/logic.rs | 45 ++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 531181d58..50ea430c4 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -217,7 +217,7 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { mod test { use super::*; - use crate::extension::prelude::sum_with_error; + use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::ops::OpType; use crate::std_extensions::arithmetic; @@ -225,6 +225,7 @@ mod test { use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; + use crate::std_extensions::logic::{self, const_from_bool, NaryLogic}; use rstest::rstest; /// int to constant @@ -306,6 +307,31 @@ mod test { let expected = Const::new(expected, sum_type).unwrap(); assert_fully_folded(&h, &expected); } + + #[rstest] + #[case(NaryLogic::And, [true, true, true], true)] + #[case(NaryLogic::And, [true, false, true], false)] + #[case(NaryLogic::Or, [false, false, true], true)] + #[case(NaryLogic::Or, [false, false, false], false)] + fn test_logic_and( + #[case] op: NaryLogic, + #[case] ins: [bool; 3], + #[case] out: bool, + ) -> Result<(), Box> { + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + + let ins = ins.map(|b| build.add_load_const(const_from_bool(b)).unwrap()); + let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?; + + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(logic_op.outputs(), ®)?; + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &const_from_bool(out)); + Ok(()) + } + fn assert_fully_folded(h: &Hugr, expected_const: &Const) { // check the hugr just loads and returns a single const let mut node_count = 0; diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index cdcae57cc..d3abc9d4c 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -3,6 +3,7 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::{ + algorithm::const_fold::sorted_consts, extension::{ prelude::BOOL_T, simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, @@ -14,7 +15,7 @@ use crate::{ type_param::{TypeArg, TypeParam}, FunctionType, }, - Extension, + Extension, IncomingPort, }; use lazy_static::lazy_static; /// Name of extension false value. @@ -46,6 +47,21 @@ impl MakeOpDef for NaryLogic { fn from_def(op_def: &OpDef) -> Result { try_from_name(op_def.name()) } + + fn post_opdef(&self, def: &mut OpDef) { + def.set_constant_folder(match self { + NaryLogic::And => |consts: &_| { + let inps = read_inputs(consts)?; + let res = inps.into_iter().all(|x| x); + Some(vec![(0.into(), const_from_bool(res))]) + }, + NaryLogic::Or => |consts: &_| { + let inps = read_inputs(consts)?; + let res = inps.into_iter().any(|x| x); + Some(vec![(0.into(), const_from_bool(res))]) + }, + }) + } } /// Make a [NaryLogic] operation concrete by setting the type argument. @@ -171,6 +187,33 @@ impl MakeRegisteredOp for NotOp { } } +fn read_inputs(consts: &[(IncomingPort, ops::Const)]) -> Option> { + let true_val = ops::Const::true_val(); + let false_val = ops::Const::false_val(); + let inps: Option> = sorted_consts(consts) + .into_iter() + .map(|c| { + if c == &true_val { + Some(true) + } else if c == &false_val { + Some(false) + } else { + None + } + }) + .collect(); + let inps = inps?; + Some(inps) +} + +pub(crate) fn const_from_bool(res: bool) -> ops::Const { + if res { + ops::Const::true_val() + } else { + ops::Const::false_val() + } +} + #[cfg(test)] pub(crate) mod test { use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME};