Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Combine ExtensionSolutions (no separate closure) #884

Merged
merged 20 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions quantinuum-hugr/src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,17 @@ use thiserror::Error;
/// been inferred for their inputs.
pub type ExtensionSolution = HashMap<Node, ExtensionSet>;

/// Infer extensions for a hugr. This is the main API exposed by this module
/// Infer extensions for a hugr. This is the main API exposed by this module.
///
/// Return a tuple of the solutions found for locations on the graph, and a
/// closure: a solution which would be valid if all of the variables in the graph
/// were instantiated to an empty extension set. This is used (by validation) to
/// concretise the extension requirements of the whole hugr.
pub fn infer_extensions(
hugr: &impl HugrView,
) -> Result<(ExtensionSolution, ExtensionSolution), InferExtensionError> {
/// Return all the solutions found for locations on the graph, these can be
/// passed to [`validate_with_extension_closure`]
///
/// [`validate_with_extension_closure`]: crate::Hugr::validate_with_extension_closure
pub fn infer_extensions(hugr: &impl HugrView) -> Result<ExtensionSolution, InferExtensionError> {
let mut ctx = UnificationContext::new(hugr);
let solution = ctx.main_loop()?;
ctx.main_loop()?;
ctx.instantiate_variables();
let closed_solution = ctx.main_loop()?;
let closure: ExtensionSolution = closed_solution
.into_iter()
.filter(|(node, _)| !solution.contains_key(node))
.collect();
Ok((solution, closure))
ctx.main_loop()
}

/// Metavariables don't need much
Expand Down
35 changes: 28 additions & 7 deletions quantinuum-hugr/src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ fn from_graph() -> Result<(), Box<dyn Error>> {

hugr.connect(mult_c, 0, output, 0);

let (_, closure) = infer_extensions(&hugr)?;
let solution = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter([A, B]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(*closure.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
assert_eq!(*solution.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*solution.get(&(mult_c)).unwrap(), ab);
assert_eq!(*solution.get(&(add_ab)).unwrap(), empty);
assert_eq!(*solution.get(&add_b).unwrap(), ExtensionSet::singleton(&A));
Ok(())
}

Expand Down Expand Up @@ -249,8 +249,7 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {
hugr.connect(src, 0, mult, 1);
hugr.connect(mult, 0, output, 0);

let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
hugr.infer_extensions()?;
assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs);
assert_eq!(
hugr.get_nodetype(mult.node()).io_extensions().unwrap(),
Expand Down Expand Up @@ -795,6 +794,28 @@ fn test_cfg_loops() -> Result<(), Box<dyn Error>> {
Ok(())
}

#[test]
#[cfg(feature = "extension_inference")]
fn test_validate_with_closure() -> Result<(), Box<dyn Error>> {
let mut hugr = make_looping_cfg(
ExtensionSet::new(),
ExtensionSet::singleton(&A),
ExtensionSet::singleton(&B),
)?;
assert_matches!(
hugr.validate(&PRELUDE_REGISTRY),
Err(ValidationError::ExtensionError(_))
);

let immut = hugr.clone();
let soln = infer_extensions(&immut)?;
immut.validate_with_extension_closure(soln, &PRELUDE_REGISTRY)?;
croyzor marked this conversation as resolved.
Show resolved Hide resolved

hugr.update_validate(&PRELUDE_REGISTRY)?; // Solution written in, hence:
hugr.validate(&PRELUDE_REGISTRY)?;
Ok(())
}

#[test]
/// A control flow graph consisting of an entry node and a single block
/// which adds a resource and links to both itself and the exit node.
Expand Down
26 changes: 11 additions & 15 deletions quantinuum-hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(not(feature = "extension_inference"))]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand Down Expand Up @@ -198,29 +196,27 @@ impl Hugr {
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
resolve_extension_ops(self, extension_registry)?;
let closure = self.infer_extensions()?;
self.validate_with_extension_closure(closure, extension_registry)?;
self.infer_extensions()?;
self.validate(extension_registry)?;
Ok(())
}

/// Infer extension requirements and add new information to `op_types` field
/// (if the "extension_inference" feature is on; otherwise, do nothing)
///
/// See [`infer_extensions`] for details on the "closure" value
#[cfg(feature = "extension_inference")]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
let (solution, extension_closure) = infer_extensions(self)?;
self.instantiate_extensions(solution);
Ok(extension_closure)
}
/// Do nothing - this functionality is gated by the feature "extension_inference"
#[cfg(not(feature = "extension_inference"))]
pub fn infer_extensions(&mut self) -> Result<ExtensionSolution, InferExtensionError> {
Ok(HashMap::new())
pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> {
#[cfg(feature = "extension_inference")]
{
let solution = infer_extensions(self)?;
self.instantiate_extensions(&solution);
}
Ok(())
}

#[allow(dead_code)]
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
fn instantiate_extensions(&mut self, solution: &ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for (node, input_extensions) in solution.iter() {
Expand Down
7 changes: 3 additions & 4 deletions quantinuum-hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ fn children_restrictions() {
b.update_validate(&EMPTY_REG),
Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy)
);
let closure = b.infer_extensions().unwrap();
b.infer_extensions().unwrap();
b.set_parent(new_def, root);

// After moving the previous definition to a valid place,
// add an input node to the module subgraph
let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![]));
assert_matches!(
b.validate_with_extension_closure(closure, &EMPTY_REG),
b.validate(&EMPTY_REG),
Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)}
);
}
Expand Down Expand Up @@ -591,8 +591,7 @@ mod extension_tests {
.unwrap();
// Write Extension annotations into the Hugr while it's still well-formed
// enough for us to compute them
let closure = b.infer_extensions().unwrap();
b.instantiate_extensions(closure);
b.infer_extensions().unwrap();
b.validate(&EMPTY_REG).unwrap();
b.replace_op(
copy,
Expand Down