Skip to content

Commit

Permalink
fix inarrow const folding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 13, 2024
1 parent f63072a commit 83ce472
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 62 deletions.
49 changes: 23 additions & 26 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ use crate::{

use super::IntOpDef;

use lazy_static::lazy_static;

lazy_static! {
static ref INARROW_ERROR_VALUE: Value = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
}
.into();
}

fn bitmask_from_width(width: u64) -> u64 {
debug_assert!(width <= 64);
if width == 64 {
Expand Down Expand Up @@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())

let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: u64 = n0.value_u();
let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand All @@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: i64 = n0.value_s();
let ub = 1i64 << ((1 << logwidth1) - 1);
let out_const: Value = if n0val >= ub || n0val < -ub {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand Down
62 changes: 26 additions & 36 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,52 +61,38 @@ fn test_fold_iwiden_s() {
assert_fully_folded(&h, &expected);
}

#[test]
#[should_panic]
fn test_fold_inarrow_u() {
// pseudocode:
//
// x0 := int_u<5>(13);
// x1 := inarrow_u<5, 4>(x0);
// output x1 == int_u<4>(13);
let sum_type = sum_with_error(INT_TYPES[4].to_owned());
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![sum_type.clone().into()],
))
.unwrap();
let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap()));
let x1 = build
.add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0])
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::extension(ConstInt::new_u(4, 13).unwrap());
assert_fully_folded(&h, &expected);
}

#[test]
#[should_panic]
fn test_fold_inarrow_s() {
#[rstest]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)]
fn test_fold_inarrow<I: Copy, C: Into<Value>, E: std::fmt::Debug>(
#[case] mk_const: impl Fn(u8, I) -> Result<C, E>,
#[case] op_def: IntOpDef,
#[case] from_log_width: u8,
#[case] to_log_width: u8,
#[case] val: I,
#[case] succeeds: bool,
) {
// pseudocode:
//
// x0 := int_s<5>(-3);
// x1 := inarrow_s<5, 4>(x0);
// output x1 == int_s<4>(-3);
let sum_type = sum_with_error(INT_TYPES[4].to_owned());
let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned());
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![sum_type.clone().into()],
))
.unwrap();
let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap()));
let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into());
let x1 = build
.add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0])
.add_dataflow_op(
op_def.with_two_log_widths(from_log_width, to_log_width),
[x0],
)
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
Expand All @@ -115,7 +101,11 @@ fn test_fold_inarrow_s() {
.unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::extension(ConstInt::new_s(4, -3).unwrap());
let expected = if succeeds {
Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap()
} else {
Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap()
};
assert_fully_folded(&h, &expected);
}

Expand Down

0 comments on commit 83ce472

Please sign in to comment.