diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index e19de46df..b9ea98989 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -217,15 +217,24 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { #[cfg(test)] mod test { + use super::*; + use crate::extension::prelude::sum_with_error; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::std_extensions::arithmetic; - + use crate::std_extensions::arithmetic::conversions::ConvertOpDef; use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - + use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; use rstest::rstest; - use super::*; + /// int to constant + fn i2c(b: u64) -> Const { + Const::new( + ConstIntU::new(5, b).unwrap().into(), + INT_TYPES[5].to_owned(), + ) + .unwrap() + } /// float to constant fn f2c(f: f64) -> Const { @@ -244,19 +253,19 @@ mod test { assert_eq!(&out[..], &[(0.into(), f2c(c))]); } - #[test] fn test_big() { /* - Test hugr approximately calculates - let x = (5.5, 3.25); - x.0 - x.1 == 2.25 + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); let mut build = - DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); + DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap(); let tup = build - .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) + .add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)])) .unwrap(); let unpack = build @@ -271,19 +280,31 @@ mod test { let sub = build .add_dataflow_op(FloatOps::fsub, unpack.outputs()) .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs()) + .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), arithmetic::float_types::EXTENSION.to_owned(), arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), ]) .unwrap(); - let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); - assert_eq!(h.node_count(), 7); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); constant_fold_pass(&mut h, ®); - assert_fully_folded(&h, &f2c(2.25)); + let expected = Value::Sum { + tag: 0, + value: Box::new(i2c(2).value().clone()), + }; + let expected = Const::new(expected, sum_type).unwrap(); + assert_fully_folded(&h, &expected); } fn assert_fully_folded(h: &Hugr, expected_const: &Const) { // check the hugr just loads and returns a single const diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 23b457f7c..46e187fd5 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -19,6 +19,7 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; +mod const_fold; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -63,8 +64,21 @@ impl MakeOpDef for ConvertOpDef { } .to_string() } + + fn post_opdef(&self, def: &mut OpDef) { + const_fold::set_fold(self, def) + } } +impl ConvertOpDef { + /// Initialise a conversion op with an integer log width type argument. + pub fn with_width(self, log_width: u8) -> ConvertOpType { + ConvertOpType { + def: self, + log_width: log_width as u64, + } + } +} /// Concrete convert operation with integer width set. #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { diff --git a/src/std_extensions/arithmetic/conversions/const_fold.rs b/src/std_extensions/arithmetic/conversions/const_fold.rs new file mode 100644 index 000000000..3814c0504 --- /dev/null +++ b/src/std_extensions/arithmetic/conversions/const_fold.rs @@ -0,0 +1,134 @@ +use crate::{ + extension::{ + prelude::{sum_with_error, ConstError}, + ConstFold, ConstFoldResult, OpDef, + }, + ops, + std_extensions::arithmetic::{ + float_types::ConstF64, + int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES}, + }, + types::ConstTypeError, + values::{CustomConst, Value}, + IncomingPort, +}; + +use super::ConvertOpDef; + +pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) { + use ConvertOpDef::*; + + match op { + trunc_u => def.set_constant_folder(TruncU), + trunc_s => def.set_constant_folder(TruncS), + convert_u => def.set_constant_folder(ConvertU), + convert_s => def.set_constant_folder(ConvertS), + } +} + +fn get_input(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> { + let [(_, c)] = consts else { + return None; + }; + c.get_custom_value() +} + +fn fold_trunc( + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + convert: impl Fn(f64, u8) -> Result, +) -> ConstFoldResult { + let f: &ConstF64 = get_input(consts)?; + let f = f.value(); + let [arg] = type_args else { + return None; + }; + let log_width = get_log_width(arg).ok()?; + let int_type = INT_TYPES[log_width as usize].to_owned(); + let sum_type = sum_with_error(int_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Can't truncate non-finite float".to_string(), + }; + let sum_val = Value::Sum { + tag: 1, + value: Box::new(err_val.into()), + }; + + ops::Const::new(sum_val, sum_type.clone()).unwrap() + }; + let out_const: ops::Const = if !f.is_finite() { + err_value() + } else { + let cv = convert(f, log_width); + if let Ok(cv) = cv { + let sum_val = Value::Sum { + tag: 0, + value: Box::new(cv), + }; + + ops::Const::new(sum_val, sum_type).unwrap() + } else { + err_value() + } + }; + + Some(vec![(0.into(), out_const)]) +} + +struct TruncU; + +impl ConstFold for TruncU { + fn fold( + &self, + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + fold_trunc(type_args, consts, |f, log_width| { + ConstIntU::new(log_width, f.trunc() as u64).map(Into::into) + }) + } +} + +struct TruncS; + +impl ConstFold for TruncS { + fn fold( + &self, + type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + fold_trunc(type_args, consts, |f, log_width| { + ConstIntS::new(log_width, f.trunc() as i64).map(Into::into) + }) + } +} + +struct ConvertU; + +impl ConstFold for ConvertU { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let u: &ConstIntU = get_input(consts)?; + let f = u.value() as f64; + Some(vec![(0.into(), ConstF64::new(f).into())]) + } +} + +struct ConvertS; + +impl ConstFold for ConvertS { + fn fold( + &self, + _type_args: &[crate::types::TypeArg], + consts: &[(IncomingPort, ops::Const)], + ) -> ConstFoldResult { + let u: &ConstIntS = get_input(consts)?; + let f = u.value() as f64; + Some(vec![(0.into(), ConstF64::new(f).into())]) + } +}