From 84a92d9af0daed077739e16abd7c5139b492f357 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 16:15:00 +0000 Subject: [PATCH] wip: constant folding refactor!: Closes Flatten `Prim(Type/Value)` in to parent enum #665 BREAKING_CHANGES: In serialization, extension and function values no longer wrapped by "pv". --- src/algorithm.rs | 1 + src/algorithm/const_fold.rs | 60 ++++++++++++++++++++++ src/std_extensions/arithmetic/int_types.rs | 11 +++- src/values.rs | 13 ++++- 4 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/algorithm/const_fold.rs diff --git a/src/algorithm.rs b/src/algorithm.rs index 0023b59167..633231504e 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,4 +1,5 @@ //! Algorithms using the Hugr. +pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs new file mode 100644 index 0000000000..eeee6bba5e --- /dev/null +++ b/src/algorithm/const_fold.rs @@ -0,0 +1,60 @@ +//! Constant folding routines. + +use crate::{ + ops::{Const, OpType}, + values::Value, + IncomingPort, OutgoingPort, +}; + +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_const( + op: &OpType, + consts: &[(IncomingPort, Const)], +) -> Option> { + consts.iter().find_map(|(_, cnst)| match cnst.value() { + Value::Extension { c: (c,) } => c.fold(op, consts), + Value::Tuple { .. } => todo!(), + Value::Sum { .. } => todo!(), + Value::Function { .. } => None, + }) +} + +#[cfg(test)] +mod test { + use crate::{ + extension::PRELUDE_REGISTRY, + ops::LeafOp, + std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}, + types::TypeArg, + }; + use rstest::rstest; + + use super::*; + + fn i2c(b: u64) -> Const { + Const::new( + ConstIntU::new(5, b).unwrap().into(), + INT_TYPES[5].to_owned(), + ) + .unwrap() + } + + fn u64_add() -> LeafOp { + crate::std_extensions::arithmetic::int_types::EXTENSION + .instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], &PRELUDE_REGISTRY) + .unwrap() + .into() + } + #[rstest] + #[case(0, 0, 0)] + #[case(0, 1, 1)] + #[case(23, 435, 458)] + // c = a && b + fn test_and(#[case] a: u64, #[case] b: u64, #[case] c: u64) { + let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))]; + let add_op: OpType = u64_add().into(); + let out = fold_const(&add_op, &consts).unwrap(); + + assert_eq!(&out[..], &[(0.into(), i2c(c))]); + } +} diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 7a67de28ae..079310f96b 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -6,12 +6,13 @@ use smol_str::SmolStr; use crate::{ extension::ExtensionId, + ops::OpType, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, }, values::CustomConst, - Extension, + Extension, IncomingPort, OutgoingPort, }; use lazy_static::lazy_static; /// The extension identifier. @@ -161,6 +162,14 @@ impl CustomConst for ConstIntU { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn fold( + &self, + _op: &OpType, + _consts: &[(IncomingPort, crate::ops::Const)], + ) -> Option> { + None + } } #[typetag::serde] diff --git a/src/values.rs b/src/values.rs index 4286540661..07391c5920 100644 --- a/src/values.rs +++ b/src/values.rs @@ -9,7 +9,8 @@ use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; use crate::macros::impl_box_clone; -use crate::{Hugr, HugrView}; +use crate::ops::OpType; +use crate::{Hugr, HugrView, IncomingPort, OutgoingPort}; use crate::types::{CustomCheckFailure, CustomType}; @@ -143,6 +144,16 @@ pub trait CustomConst: // false unless overloaded false } + + /// Attempt to evaluate an operation given some constant inputs - typically + /// involving instances of Self + fn fold( + &self, + _op: &OpType, + _consts: &[(IncomingPort, crate::ops::Const)], + ) -> Option> { + None + } } /// Const equality for types that have PartialEq