Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement RemoveConst and RemoveConstIgnore #757

Merged
merged 12 commits into from
Jan 3, 2024
6 changes: 3 additions & 3 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;
fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -374,7 +374,7 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}
Expand Down
4 changes: 2 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1))?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -173,7 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let wire = branch_1.add_load_const(ConstUsize::new(2))?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
1 change: 1 addition & 0 deletions src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Rewrite operations on the HUGR - replacement, outlining, etc.

pub mod consts;
pub mod insert_identity;
pub mod outline_cfg;
pub mod replace;
Expand Down
231 changes: 231 additions & 0 deletions src/hugr/rewrite/consts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
//! Rewrite operations involving Const and LoadConst operations

use std::iter;

use crate::{
hugr::{HugrError, HugrMut},
HugrView, Node,
};
#[rustversion::since(1.75)] // uses impl in return position
use itertools::Itertools;
use thiserror::Error;

use super::Rewrite;

/// Remove a [`crate::ops::LoadConstant`] node with no outputs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "with no outputs" - it has an outport; the condition is that there is nothing connected i.e. consuming it (right??). "No successors" or "no consumers"?

#[derive(Debug, Clone)]
pub struct RemoveConstIgnore(pub Node);
ss2165 marked this conversation as resolved.
Show resolved Hide resolved

/// Error from an [`RemoveConstIgnore`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveConstIgnoreError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not LoadConst).")]
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Not connected to a Const.
#[error("Node: {0:?} is not connected to a Const node.")]
Copy link
Contributor

@acl-cqc acl-cqc Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this does raise the question of, if this has happened, what the heck are you supposed to do with it ;-). I guess you have to rewind to a state where you can avoid getting into that mess. Consider defining ApplyResult = Option<Node> and then return None in such cases?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we necessarily fail validation if this is the case? It is too much to ask every rewrite to handle and report every error that would be detected by validate().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I removed this error case

NoConst(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

#[rustversion::since(1.75)] // uses impl in return position
impl Rewrite for RemoveConstIgnore {
type Error = RemoveConstIgnoreError;

// The Const node the LoadConstant was connected to.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
return Err(RemoveConstIgnoreError::InvalidNode(node));
}

if h.out_value_types(node)
.next()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about h.out_value_types(node).exactly_one().unwrap()? Aren't we allowed to assume the input Hugr validates?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, just use h.output_neighbours().next().is_some() ?

Copy link
Member Author

@ss2165 ss2165 Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe things connected by order edges would show up there as well and I want to make sure I get the value edge.

.is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some())
{
return Err(RemoveConstIgnoreError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let source = h
.input_neighbours(node)
.exactly_one()
.map_err(|_| RemoveConstIgnoreError::NoConst(node))?;
h.remove_node(node)?;

Ok(source)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

/// Remove a [`crate::ops::Const`] node with no outputs.
#[derive(Debug, Clone)]
pub struct RemoveConst(pub Node);

/// Error from an [`RemoveConst`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum RemoveConstError {
/// Invalid node.
#[error("Node is invalid (either not in HUGR or not Const).")]
InvalidNode(Node),
/// Node in use.
#[error("Node: {0:?} has non-zero outgoing connections.")]
ValueUsed(Node),
/// Removal error
#[error("Removing node caused error: {0:?}.")]
RemoveFail(#[from] HugrError),
}

impl Rewrite for RemoveConst {
type Error = RemoveConstError;

// The parent of the Const node.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
let node = self.0;

if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
return Err(RemoveConstError::InvalidNode(node));
}

if h.output_neighbours(node).next().is_some() {
return Err(RemoveConstError::ValueUsed(node));
}

Ok(())
}

fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
self.verify(h)?;
let node = self.0;
let source = h
Copy link
Contributor

@acl-cqc acl-cqc Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to parent? Or something, but I don't see how it's source

.get_parent(node)
.expect("Const node without a parent shouldn't happen.");
h.remove_node(node)?;

Ok(source)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
iter::once(self.0)
}
}

#[rustversion::since(1.75)] // uses impl in return position
#[cfg(test)]
mod test {
use super::*;
use crate::{
builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
extension::{
prelude::{ConstUsize, USIZE_T},
PRELUDE_REGISTRY,
},
hugr::HugrMut,
ops::{handle::NodeHandle, LeafOp},
type_row,
types::FunctionType,
};
#[test]
fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thorough, test :)

let mut build = ModuleBuilder::new();
let con_node = build.add_constant(ConstUsize::new(2))?;

let mut dfg_build =
build.define_function("main", FunctionType::new_endo(type_row![]).into())?;
let load_1 = dfg_build.load_const(&con_node)?;
let load_2 = dfg_build.load_const(&con_node)?;
let tup = dfg_build.add_dataflow_op(
LeafOp::MakeTuple {
tys: type_row![USIZE_T, USIZE_T],
},
[load_1, load_2],
)?;
dfg_build.finish_sub_container()?;

let mut h = build.finish_prelude_hugr()?;
assert_eq!(h.node_count(), 8);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could change this to a Function-rooted Hugr to reduce counts by 1. Would be nice to comment why there are so many (Module, Function, Input, Output, Const, LoadConstant*2, Tuple)...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to test that non-local const->load edges work ok
Will add comment

let tup_node = tup.node();
// can't remove invalid node
assert_eq!(
h.apply_rewrite(RemoveConst(tup_node)),
Err(RemoveConstError::InvalidNode(tup_node))
);

assert_eq!(
h.apply_rewrite(RemoveConstIgnore(tup_node)),
Err(RemoveConstIgnoreError::InvalidNode(tup_node))
);
let load_1_node = load_1.node();
let load_2_node = load_2.node();
let con_node = con_node.node();

let remove_1 = RemoveConstIgnore(load_1_node);
assert_eq!(
remove_1.invalidation_set().exactly_one().ok(),
Some(load_1_node)
);

let remove_2 = RemoveConstIgnore(load_2_node);

let remove_con = RemoveConst(con_node);
assert_eq!(
remove_con.invalidation_set().exactly_one().ok(),
Some(con_node)
);

// can't remove nodes in use
assert_eq!(
h.apply_rewrite(remove_1.clone()),
Err(RemoveConstIgnoreError::ValueUsed(load_1_node))
);

// remove the use
h.remove_node(tup_node)?;

// remove first load
let reported_con_node = h.apply_rewrite(remove_1)?;
assert_eq!(reported_con_node, con_node);

// still can't remove const, in use by second load
assert_eq!(
h.apply_rewrite(remove_con.clone()),
Err(RemoveConstError::ValueUsed(con_node))
);

// remove second use
let reported_con_node = h.apply_rewrite(remove_2)?;
assert_eq!(reported_con_node, con_node);
// remove const
assert_eq!(h.apply_rewrite(remove_con)?, h.root());

assert_eq!(h.node_count(), 4);
assert!(h.validate(&PRELUDE_REGISTRY).is_ok());
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ fn static_targets() {
)
.unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap();
let c = dfg.add_constant(ConstUsize::new(1)).unwrap();

let load = dfg.load_const(&c).unwrap();

Expand Down