Skip to content

Commit

Permalink
wip: constant folding
Browse files Browse the repository at this point in the history
refactor!: Closes Flatten `Prim(Type/Value)` in to parent enum #665

BREAKING_CHANGES: In serialization, extension and function values no longer
wrapped by "pv".
  • Loading branch information
ss2165 committed Nov 24, 2023
1 parent 592dd28 commit bffed99
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Algorithms using the Hugr.
pub mod const_fold;
mod half_node;
pub mod nest_cfgs;
60 changes: 60 additions & 0 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//! Constant folding routines.
use crate::{
ops::{Const, OpType},
values::Value,
IncomingPort, OutgoingPort,
};

/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(
op: &OpType,
consts: &[(IncomingPort, Const)],
) -> Option<Vec<(OutgoingPort, Const)>> {
consts.iter().find_map(|(_, cnst)| match cnst.value() {
Value::Extension { c: (c,) } => c.fold(op, consts),
Value::Tuple { .. } => todo!(),
Value::Sum { .. } => todo!(),
Value::Function { .. } => None,
})
}

#[cfg(test)]
mod test {
use crate::{
extension::PRELUDE_REGISTRY,
ops::LeafOp,
std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES},
types::TypeArg,
};
use rstest::rstest;

use super::*;

fn i2c(b: u64) -> Const {
Const::new(
ConstIntU::new(5, b).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

fn u64_add() -> LeafOp {
crate::std_extensions::arithmetic::int_types::EXTENSION
.instantiate_extension_op("iadd", [TypeArg::BoundedNat { n: 5 }], &PRELUDE_REGISTRY)
.unwrap()
.into()
}
#[rstest]
#[case(0, 0, 0)]
#[case(0, 1, 1)]
#[case(23, 435, 458)]
// c = a && b
fn test_and(#[case] a: u64, #[case] b: u64, #[case] c: u64) {
let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))];
let add_op: OpType = u64_add().into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), i2c(c))]);
}
}
11 changes: 10 additions & 1 deletion src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use smol_str::SmolStr;

use crate::{
extension::ExtensionId,
ops::OpType,
types::{
type_param::{TypeArg, TypeArgError, TypeParam},
ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound,
},
values::CustomConst,
Extension,
Extension, IncomingPort, OutgoingPort,
};
use lazy_static::lazy_static;
/// The extension identifier.
Expand Down Expand Up @@ -161,6 +162,14 @@ impl CustomConst for ConstIntU {
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::values::downcast_equal_consts(self, other)
}

fn fold(
&self,
_op: &OpType,
_consts: &[(IncomingPort, crate::ops::Const)],
) -> Option<Vec<(OutgoingPort, crate::ops::Const)>> {
None
}
}

#[typetag::serde]
Expand Down
13 changes: 12 additions & 1 deletion src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use crate::macros::impl_box_clone;
use crate::{Hugr, HugrView};
use crate::ops::OpType;
use crate::{Hugr, HugrView, IncomingPort, OutgoingPort};

use crate::types::{CustomCheckFailure, CustomType};

Expand Down Expand Up @@ -143,6 +144,16 @@ pub trait CustomConst:
// false unless overloaded
false
}

/// Attempt to evaluate an operation given some constant inputs - typically
/// involving instances of Self
fn fold(
&self,
_op: &OpType,
_consts: &[(IncomingPort, crate::ops::Const)],
) -> Option<Vec<(OutgoingPort, crate::ops::Const)>> {
None
}
}

/// Const equality for types that have PartialEq
Expand Down

0 comments on commit bffed99

Please sign in to comment.