diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs index 0915a4737b..8738e1872e 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -16,6 +16,16 @@ use crate::{ use super::IntOpDef; +use lazy_static::lazy_static; + +lazy_static! { + static ref INARROW_ERROR_VALUE: Value = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + } + .into(); +} + fn bitmask_from_width(width: u64) -> u64 { debug_assert!(width <= 64); if width == 64 { @@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { let logwidth0: u8 = get_log_width(arg0).ok()?; let logwidth1: u8 = get_log_width(arg1).ok()?; let n0: &ConstInt = get_single_input_value(consts)?; + (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?; let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); let sum_type = sum_with_error(int_out_type.clone()); - let err_value = || { - let err_val = ConstError { - signal: 0, - message: "Integer too large to narrow".to_string(), - }; - Value::sum(1, [err_val.into()], sum_type.clone()) + + let mk_out_const = |i, mb_v: Result| { + mb_v.and_then(|v| Value::sum(i, [v], sum_type)) .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) }; let n0val: u64 = n0.value_u(); let out_const: Value = if n0val >> (1 << logwidth1) != 0 { - err_value() + mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone())) } else { - Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap()) + mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into)) }; - if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { - None - } else { - Some(vec![(0.into(), out_const)]) - } + Some(vec![(0.into(), out_const)]) }, ), }, @@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { let logwidth0: u8 = get_log_width(arg0).ok()?; let logwidth1: u8 = get_log_width(arg1).ok()?; let n0: &ConstInt = get_single_input_value(consts)?; + (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?; let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); let sum_type = sum_with_error(int_out_type.clone()); - let err_value = || { - let err_val = ConstError { - signal: 0, - message: "Integer too large to narrow".to_string(), - }; - Value::sum(1, [err_val.into()], sum_type.clone()) + let mk_out_const = |i, mb_v: Result| { + mb_v.and_then(|v| Value::sum(i, [v], sum_type)) .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) }; let n0val: i64 = n0.value_s(); let ub = 1i64 << ((1 << logwidth1) - 1); let out_const: Value = if n0val >= ub || n0val < -ub { - err_value() + mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone())) } else { - Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap()) + mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into)) }; - if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { - None - } else { - Some(vec![(0.into(), out_const)]) - } + Some(vec![(0.into(), out_const)]) }, ), }, diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs index 6984783c36..af7e7e75b1 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -61,52 +61,38 @@ fn test_fold_iwiden_s() { assert_fully_folded(&h, &expected); } -#[test] -#[should_panic] -fn test_fold_inarrow_u() { - // pseudocode: - // - // x0 := int_u<5>(13); - // x1 := inarrow_u<5, 4>(x0); - // output x1 == int_u<4>(13); - let sum_type = sum_with_error(INT_TYPES[4].to_owned()); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap())); - let x1 = build - .add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0]) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_u(4, 13).unwrap()); - assert_fully_folded(&h, &expected); -} - -#[test] -#[should_panic] -fn test_fold_inarrow_s() { +#[rstest] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)] +#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)] +#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)] +fn test_fold_inarrow, E: std::fmt::Debug>( + #[case] mk_const: impl Fn(u8, I) -> Result, + #[case] op_def: IntOpDef, + #[case] from_log_width: u8, + #[case] to_log_width: u8, + #[case] val: I, + #[case] succeeds: bool, +) { // pseudocode: // // x0 := int_s<5>(-3); // x1 := inarrow_s<5, 4>(x0); // output x1 == int_s<4>(-3); - let sum_type = sum_with_error(INT_TYPES[4].to_owned()); + let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned()); let mut build = DFGBuilder::new(FunctionType::new( type_row![], vec![sum_type.clone().into()], )) .unwrap(); - let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap())); + let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into()); let x1 = build - .add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0]) + .add_dataflow_op( + op_def.with_two_log_widths(from_log_width, to_log_width), + [x0], + ) .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), @@ -115,7 +101,11 @@ fn test_fold_inarrow_s() { .unwrap(); let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); constant_fold_pass(&mut h, ®); - let expected = Value::extension(ConstInt::new_s(4, -3).unwrap()); + let expected = if succeeds { + Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap() + } else { + Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap() + }; assert_fully_folded(&h, &expected); }