Skip to content

Commit

Permalink
fix: Three bugfixes in model import and export. (#1844)
Browse files Browse the repository at this point in the history
This PR fixes three bugs in the import/export between `hugr_model` and
`hugr_core`.

1. The types of the source and target ports of a CFG region must be
wrapped in a `ctrl` type on export.
2. Importing links that have a single input or output but aren't
otherwise connected should be valid.
3. Runtime types are valid `TypeArg`s.
  • Loading branch information
zrho authored Jan 13, 2025
1 parent c72a359 commit 87cb536
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 33 deletions.
29 changes: 26 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeBound, TypeEnum,
TypeBase, TypeBound, TypeEnum, TypeRow,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -728,8 +728,31 @@ impl<'a> Context<'a> {
}

// Get the signature of the control flow region.
// This is the same as the signature of the parent node.
let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap()));
let signature = {
let node_signature = self.hugr.signature(node).unwrap();

let mut wrap_ctrl = |types: &TypeRow| {
let types = self.export_type_row(types);
let types_ctrl = self.make_term(model::Term::Control { values: types });
self.make_term(model::Term::List {
parts: self
.bump
.alloc_slice_copy(&[model::ListPart::Item(types_ctrl)]),
})
};

let inputs = wrap_ctrl(node_signature.input());
let outputs = wrap_ctrl(node_signature.output());
let extensions = self.export_ext_set(&node_signature.runtime_reqs);

let func_type = self.make_term(model::Term::FuncType {
inputs,
outputs,
extensions,
});

Some(func_type)
};

let scope = match closure {
model::ScopeClosure::Closed => {
Expand Down
47 changes: 26 additions & 21 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,24 +216,26 @@ impl<'a> Context<'a> {
}
}

if inputs.is_empty() || outputs.is_empty() {
return Err(error_unsupported!(
"link {}#{} is missing either an input or an output port",
link_id.0,
link_id.1
));
}

// We connect the first output to all the inputs, and the first input to all the outputs
// (except the first one, which we already connected to the first input). This should
// result in the hugr having a (hyper)edge that connects all the ports.
// There should be a better way to do this.
for (node, port) in inputs.iter() {
self.hugr.connect(outputs[0].0, outputs[0].1, *node, *port);
}

for (node, port) in outputs.iter().skip(1) {
self.hugr.connect(*node, *port, inputs[0].0, inputs[0].1);
match (inputs.as_slice(), outputs.as_slice()) {
([], []) => {
unreachable!();
}
(_, [output]) => {
for (node, port) in inputs.iter() {
self.hugr.connect(output.0, output.1, *node, *port);
}
}
([input], _) => {
for (node, port) in outputs.iter() {
self.hugr.connect(*node, *port, input.0, input.1);
}
}
_ => {
return Err(error_unsupported!(
"link {:?} would require hyperedge",
link_id
));
}
}

inputs.clear();
Expand Down Expand Up @@ -996,7 +998,6 @@ impl<'a> Context<'a> {
model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")),
model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")),
model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")),
model::Term::ApplyFull { .. } => Err(error_unsupported!("custom types as `TypeArg`")),
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")),
model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")),
model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")),
Expand All @@ -1010,8 +1011,12 @@ impl<'a> Context<'a> {

model::Term::FuncType { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
| model::Term::ApplyFull { .. } => {
let ty = self.import_type(term_id)?;
Ok(TypeArg::Type { ty })
}

model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
Expand Down
13 changes: 8 additions & 5 deletions hugr-core/tests/snapshots/model__roundtrip_cfg.snap
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.
(signature (-> [?0] [?0] (ext)))
(cfg
[%2] [%3]
(signature (-> [?0] [?0] (ext)))
(signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext)))
(block [%2] [%6]
(signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext)))
(dfg
[%4] [%5]
(signature (-> [?0] [(adt [[?0]])] (ext)))
(tag 0 [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext))))))
(block [%6] [%3]
(signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext)))
(block [%6] [%3 %9]
(signature (-> [(ctrl [?0])] [(ctrl [?0]) (ctrl [?0])] (ext)))
(dfg
[%7] [%8]
(signature (-> [?0] [(adt [[?0]])] (ext)))
(tag 0 [%7] [%8] (signature (-> [?0] [(adt [[?0]])] (ext))))))))))
(signature (-> [?0] [(adt [[?0] [?0]])] (ext)))
(tag
0
[%7] [%8]
(signature (-> [?0] [(adt [[?0] [?0]])] (ext))))))))))
6 changes: 6 additions & 0 deletions hugr-core/tests/snapshots/model__roundtrip_constraints.snap
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons
(forall ?0 type)
(where (nonlinear ?0))
[(@ prelude.Array ?0)] [(@ prelude.Array ?0) (@ prelude.Array ?0)] (ext))

(define-func util.copy
(forall ?0 type)
(where (nonlinear ?0))
[?0] [?0 ?0] (ext)
(dfg [%0] [%0 %0] (signature (-> [?0] [?0 ?0] (ext)))))
8 changes: 4 additions & 4 deletions hugr-model/tests/fixtures/model-cfg.edn
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
(signature (-> [?a] [?a] (ext)))
(cfg [%2] [%4]
(signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext)))
(block [%2] [%4]
(signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext)))
(block [%2] [%4 %2]
(signature (-> [(ctrl [?a])] [(ctrl [?a]) (ctrl [?a])] (ext)))
(dfg [%5] [%6]
(signature (-> [?a] [(adt [[?a]])] (ext)))
(signature (-> [?a] [(adt [[?a] [?a]])] (ext)))
(tag 0 [%5] [%6]
(signature (-> [?a] [(adt [[?a]])] (ext))))))))))
(signature (-> [?a] [(adt [[?a] [?a]])] (ext))))))))))
7 changes: 7 additions & 0 deletions hugr-model/tests/fixtures/model-constraints.edn
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@
(forall ?t type)
(where (nonlinear ?t))
[(@ prelude.Array ?t)] [(@ prelude.Array ?t) (@ prelude.Array ?t)] (ext))

(define-func util.copy
(forall ?t type)
(where (nonlinear ?t))
[?t] [?t ?t] (ext)
(dfg [%0] [%0 %0]
(signature (-> [?t] [?t ?t] (ext)))))

0 comments on commit 87cb536

Please sign in to comment.