Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 13, 2024
1 parent 862419f commit 1e7a6f5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
50 changes: 49 additions & 1 deletion hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl ConstFoldError {
}

#[derive(Debug, Clone, Copy, Default)]
/// TODO
/// A configuration for the Constant Folding pass.
pub struct ConstFoldConfig {
verify: VerifyLevel,
}
Expand Down Expand Up @@ -96,6 +96,10 @@ impl ConstFoldConfig {
};
h.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = h.apply_rewrite(RemoveConst(const_node));
Expand Down Expand Up @@ -502,4 +506,48 @@ mod test {
constant_fold_pass(&mut h, &reg);
assert_fully_folded(&h, &Value::true_val())
}

#[test]
fn test_folding_pass_issue_996() {
// pseudocode:
//
// x0 := 3.0
// x1 := 4.0
// x2 := fne(x0, x1); // true
// x3 := flt(x0, x1); // true
// x4 := and(x2, x3); // true
// x5 := -10.0
// x6 := flt(x0, x5) // false
// x7 := or(x4, x6) // true
// output x7
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0)));
let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0)));
let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap();
let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap();
let x4 = build
.add_dataflow_op(
NaryLogic::And.with_n_inputs(2),
x2.outputs().chain(x3.outputs()),
)
.unwrap();
let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0)));
let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap();
let x7 = build
.add_dataflow_op(
NaryLogic::Or.with_n_inputs(2),
x4.outputs().chain(x6.outputs()),
)
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
logic::EXTENSION.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(x7.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}
}
12 changes: 10 additions & 2 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,19 @@ fn test_fold_inarrow<I: Copy, C: Into<Value>, E: std::fmt::Debug>(
#[case] val: I,
#[case] succeeds: bool,
) {
// pseudocode:
// For the first case, pseudocode:
//
// x0 := int_s<5>(-3);
// x1 := inarrow_s<5, 4>(x0);
// output x1 == int_s<4>(-3);
// output x1 == sum<tag=0,[int_s<4>(-3)]>;
//
// Other cases vary by:
// (mk_const, op_def) => create signed or unsigned constants, create
// inarrow_s or inarrow_u ops;
// (from_log_width, to_log_width) => the args to use to create the op;
// val => the value to pass to the op
// succeeds => whether to expect a int<to_log_width> variant or an error
// variant.
let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned());
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
Expand Down

0 comments on commit 1e7a6f5

Please sign in to comment.