Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add verification to constant folding #1030

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ use crate::{

use super::IntOpDef;

use lazy_static::lazy_static;

lazy_static! {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to assume Alec has sufficiently reviewed the changes in this commit

static ref INARROW_ERROR_VALUE: Value = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
}
.into();
}

fn bitmask_from_width(width: u64) -> u64 {
debug_assert!(width <= 64);
if width == 64 {
Expand Down Expand Up @@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())

let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: u64 = n0.value_u();
let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand All @@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: i64 = n0.value_s();
let ub = 1i64 << ((1 << logwidth1) - 1);
let out_const: Value = if n0val >= ub || n0val < -ub {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand Down
62 changes: 26 additions & 36 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,52 +61,38 @@ fn test_fold_iwiden_s() {
assert_fully_folded(&h, &expected);
}

#[test]
#[should_panic]
fn test_fold_inarrow_u() {
// pseudocode:
//
// x0 := int_u<5>(13);
// x1 := inarrow_u<5, 4>(x0);
// output x1 == int_u<4>(13);
let sum_type = sum_with_error(INT_TYPES[4].to_owned());
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![sum_type.clone().into()],
))
.unwrap();
let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap()));
let x1 = build
.add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0])
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::extension(ConstInt::new_u(4, 13).unwrap());
assert_fully_folded(&h, &expected);
}

#[test]
#[should_panic]
fn test_fold_inarrow_s() {
#[rstest]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)]
fn test_fold_inarrow<I: Copy, C: Into<Value>, E: std::fmt::Debug>(
#[case] mk_const: impl Fn(u8, I) -> Result<C, E>,
#[case] op_def: IntOpDef,
#[case] from_log_width: u8,
#[case] to_log_width: u8,
#[case] val: I,
#[case] succeeds: bool,
) {
// pseudocode:
//
// x0 := int_s<5>(-3);
// x1 := inarrow_s<5, 4>(x0);
// output x1 == int_s<4>(-3);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment should be removed or rewritten in more general terms.

let sum_type = sum_with_error(INT_TYPES[4].to_owned());
let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned());
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![sum_type.clone().into()],
))
.unwrap();
let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap()));
let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into());
let x1 = build
.add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0])
.add_dataflow_op(
op_def.with_two_log_widths(from_log_width, to_log_width),
[x0],
)
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
Expand All @@ -115,7 +101,11 @@ fn test_fold_inarrow_s() {
.unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::extension(ConstInt::new_s(4, -3).unwrap());
let expected = if succeeds {
Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap()
} else {
Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap()
};
assert_fully_folded(&h, &expected);
}

Expand Down