Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make FuncDecl/FuncDefn polymorphic #692

Merged
merged 28 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0de995f
FuncDefn/FuncDecl store PolyFuncType - some horrendous hacks with // …
acl-cqc Nov 13, 2023
6c0edbb
build_main/define_function take PolyFuncType not Signature, but tail_…
acl-cqc Nov 13, 2023
3078259
extract_subgraph and FunctionBuilder::new take PolyFuncType not Signa…
acl-cqc Nov 13, 2023
271a15a
believe define_declaration is correct, but refactor a bit
acl-cqc Nov 13, 2023
b781155
declare (FuncDecl) takes PolyFuncType
acl-cqc Nov 13, 2023
c56e8c2
builder fn call takes TypeArgs - and ExtensionRegistry, ugh
acl-cqc Nov 13, 2023
9fd164b
validate_children: return early if no children
acl-cqc Nov 13, 2023
78316a1
validate_children=>validate_subtree, breaking more tests (wrong choic…
acl-cqc Nov 13, 2023
9c3048a
Separate validate_subtree
acl-cqc Nov 13, 2023
9cb43a1
Revert "validate_children: return early if no children"
acl-cqc Nov 13, 2023
41f0970
Properly check that args to binary compute_signature have no type vars
acl-cqc Nov 13, 2023
793a00c
Don't implement TypeParametrised for OpDef
acl-cqc Nov 14, 2023
c55019c
Remove comments - those were correct changes
acl-cqc Nov 14, 2023
6f5526f
fix doclink
acl-cqc Nov 14, 2023
2fa23e9
Merge 'origin/main' into new/poly_funcdeclfn
acl-cqc Nov 16, 2023
9a656d9
Add PolyFuncType::new, make validate public
acl-cqc Nov 15, 2023
293b663
validate_port from validate_subtree not validate_node; some tests
acl-cqc Nov 16, 2023
d9da463
Finally give in and add "impl From<TypeBound> for TypeParam"
acl-cqc Nov 16, 2023
3652d56
Test (and comment) re. nested FuncDefns
acl-cqc Nov 16, 2023
d2e281b
Comments re. Dataflow::call()
acl-cqc Nov 16, 2023
de374c5
Avoid Itertools::Either - can use ExternalOp
acl-cqc Nov 16, 2023
9b23560
Update since-1.75 test
acl-cqc Nov 20, 2023
9b0b5b9
Provide a BuildError::SignatureError variant
acl-cqc Nov 20, 2023
44166ce
Merge remote-tracking branch 'origin/main' into new/poly_funcdeclfn
acl-cqc Nov 20, 2023
f069c92
Comment re FuncDefn::validate_op_children
acl-cqc Nov 20, 2023
c3d4c83
CANNOT use type vars from enclosing defn's
acl-cqc Nov 21, 2023
95ddb36
Test no_polymorphic_consts, enforce no TVs on Static edges, add ListV…
acl-cqc Nov 21, 2023
7e8c515
wth -> with
acl-cqc Nov 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pub(crate) mod test {

use crate::hugr::{views::HugrView, HugrMut, NodeType};
use crate::ops;
use crate::types::{FunctionType, Signature, Type};
use crate::types::{FunctionType, PolyFuncType, Type};
use crate::{type_row, Hugr};

use super::handle::BuildHandle;
Expand All @@ -123,7 +123,7 @@ pub(crate) mod test {
}

pub(super) fn build_main(
signature: Signature,
signature: PolyFuncType,
f: impl FnOnce(FunctionBuilder<&mut Hugr>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
) -> Result<Hugr, BuildError> {
let mut module_builder = ModuleBuilder::new();
Expand Down
30 changes: 19 additions & 11 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
};

use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY};
use crate::types::{FunctionType, Signature, Type, TypeRow};
use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow};

use itertools::Itertools;

Expand Down Expand Up @@ -90,22 +90,19 @@ pub trait Container {
fn define_function(
&mut self,
name: impl Into<String>,
signature: Signature,
signature: PolyFuncType,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
let body = signature.body.clone();
let f_node = self.add_child_node(NodeType::new(
ops::FuncDefn {
name: name.into(),
signature: signature.signature.clone(),
signature,
},
signature.input_extensions.clone(),
ExtensionSet::new(),
))?;

let db = DFGBuilder::create_with_io(
self.hugr_mut(),
f_node,
signature.signature,
Some(signature.input_extensions),
)?;
let db =
DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand Down Expand Up @@ -598,11 +595,14 @@ pub trait Dataflow: Container {
fn call<const DEFINED: bool>(
&mut self,
function: &FuncID<DEFINED>,
type_args: &[TypeArg],
input_wires: impl IntoIterator<Item = Wire>,
// Sadly required as we substituting in type_args may result in recomputing bounds of types:
exts: &ExtensionRegistry,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let hugr = self.hugr();
let def_op = hugr.get_optype(function.node());
let signature = match def_op {
let type_scheme = match def_op {
OpType::FuncDefn(ops::FuncDefn { signature, .. })
| OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
_ => {
Expand All @@ -612,6 +612,14 @@ pub trait Dataflow: Container {
})
}
};
let signature = type_scheme.instantiate(type_args, exts).map_err(|e| {
BuildError::InvalidHUGR(ValidationError::SignatureError {
// TODO this is rather a horrendous hack. Do we need some way of returning a SignatureError without a node
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should make a new error constructor for this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, done, hope the comment clarifies

// (as the call node this refers to is not constructed yet)? Or, pass in an "instantiated" FuncID?
node: function.node(),
cause: e,
})
})?;
let op: OpType = ops::Call { signature }.into();
let const_in_port = op.static_input_port().unwrap();
let op_id = self.add_dataflow_op(op, input_wires)?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ mod test {
let build_result = {
let mut module_builder = ModuleBuilder::new();
let mut func_builder = module_builder
.define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).pure())?;
.define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).into())?;
let _f_id = {
let [int] = func_builder.input_wires_arr();

Expand Down
4 changes: 2 additions & 2 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod test {
#[test]
fn simple_linear() {
let build_res = build_main(
FunctionType::new(type_row![QB, QB], type_row![QB, QB]).pure(),
FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(),
|mut f_build| {
let wires = f_build.input_wires().collect();

Expand Down Expand Up @@ -184,7 +184,7 @@ mod test {
.into(),
);
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ mod test {
let mut module_builder = ModuleBuilder::new();
let mut fbuild = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?;
let _fdef = {
Expand Down
28 changes: 12 additions & 16 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::marker::PhantomData;
use crate::hugr::{HugrView, NodeType, ValidationError};
use crate::ops;

use crate::types::{FunctionType, Signature};
use crate::types::{FunctionType, PolyFuncType};

use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::Node;
Expand Down Expand Up @@ -146,21 +146,17 @@ impl FunctionBuilder<Hugr> {
/// # Errors
///
/// Error in adding DFG child nodes.
pub fn new(name: impl Into<String>, signature: Signature) -> Result<Self, BuildError> {
pub fn new(name: impl Into<String>, signature: PolyFuncType) -> Result<Self, BuildError> {
let body = signature.body.clone();
let op = ops::FuncDefn {
signature: signature.clone().into(),
signature,
name: name.into(),
};

let base = Hugr::new(NodeType::new(op, signature.input_extensions.clone()));
let base = Hugr::new(NodeType::new_pure(op));
let root = base.root();

let db = DFGBuilder::create_with_io(
base,
root,
signature.signature,
Some(signature.input_extensions),
)?;
let db = DFGBuilder::create_with_io(base, root, body, Some(ExtensionSet::new()))?;
Ok(Self::from_dfg_builder(db))
}
}
Expand Down Expand Up @@ -239,7 +235,7 @@ pub(crate) mod test {
let _f_id = {
let mut func_builder = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(),
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(),
)?;

let [int, qb] = func_builder.input_wires_arr();
Expand Down Expand Up @@ -273,7 +269,7 @@ pub(crate) mod test {

let f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).pure(),
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(),
)?;

f(f_build)?;
Expand Down Expand Up @@ -323,7 +319,7 @@ pub(crate) mod test {

let f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![QB], type_row![QB, QB]).pure(),
FunctionType::new(type_row![QB], type_row![QB, QB]).into(),
)?;

let [q1] = f_build.input_wires_arr();
Expand All @@ -340,7 +336,7 @@ pub(crate) mod test {
let builder = || -> Result<Hugr, BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
FunctionType::new(type_row![BIT], type_row![BIT]).pure(),
FunctionType::new(type_row![BIT], type_row![BIT]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand All @@ -364,7 +360,7 @@ pub(crate) mod test {
fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
FunctionType::new(type_row![QB], type_row![QB]).pure(),
FunctionType::new(type_row![QB], type_row![QB]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand Down Expand Up @@ -408,7 +404,7 @@ pub(crate) mod test {
let (dfg_node, f_node) = {
let mut f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![BIT], type_row![BIT]).pure(),
FunctionType::new(type_row![BIT], type_row![BIT]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand Down
42 changes: 17 additions & 25 deletions src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ use crate::{
extension::ExtensionRegistry,
hugr::{hugrmut::sealed::HugrMutInternals, views::HugrView, ValidationError},
ops,
types::{Type, TypeBound},
types::{PolyFuncType, Type, TypeBound},
};

use crate::ops::handle::{AliasID, FuncID, NodeHandle};

use crate::types::Signature;

use crate::Node;
use smol_str::SmolStr;

Expand Down Expand Up @@ -86,16 +84,13 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
op_desc: "crate::ops::OpType::FuncDecl",
})?
.clone();

let body = signature.body.clone();
self.hugr_mut().replace_op(
f_node,
NodeType::new_pure(ops::FuncDefn {
name,
signature: signature.clone(),
}),
NodeType::new_pure(ops::FuncDefn { name, signature }),
)?;

let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, signature, None)?;
let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, None)?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand All @@ -108,17 +103,13 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
pub fn declare(
&mut self,
name: impl Into<String>,
signature: Signature,
signature: PolyFuncType,
) -> Result<FuncID<false>, BuildError> {
// TODO add param names to metadata
let rs = signature.input_extensions.clone();
let declare_n = self.add_child_node(NodeType::new(
ops::FuncDecl {
signature: signature.into(),
name: name.into(),
},
rs,
))?;
let declare_n = self.add_child_node(NodeType::new_pure(ops::FuncDecl {
signature,
name: name.into(),
}))?;

Ok(declare_n.into())
}
Expand Down Expand Up @@ -176,7 +167,7 @@ mod test {
test::{n_identity, NAT},
Dataflow, DataflowSubContainer,
},
extension::EMPTY_REG,
extension::{EMPTY_REG, PRELUDE_REGISTRY},
type_row,
types::FunctionType,
};
Expand All @@ -189,11 +180,11 @@ mod test {

let f_id = module_builder.declare(
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;

let mut f_build = module_builder.define_declaration(&f_id)?;
let call = f_build.call(&f_id, f_build.input_wires())?;
let call = f_build.call(&f_id, &[], f_build.input_wires(), &PRELUDE_REGISTRY)?;

f_build.finish_with_outputs(call.outputs())?;
module_builder.finish_prelude_hugr()
Expand All @@ -216,7 +207,7 @@ mod test {
vec![qubit_state_type.get_alias_type()],
vec![qubit_state_type.get_alias_type()],
)
.pure(),
.into(),
)?;
n_identity(f_build)?;
module_builder.finish_hugr(&EMPTY_REG)
Expand All @@ -232,16 +223,17 @@ mod test {

let mut f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(),
)?;
let local_build = f_build.define_function(
"local",
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(),
)?;
let [wire] = local_build.input_wires_arr();
let f_id = local_build.finish_with_outputs([wire, wire])?;

let call = f_build.call(f_id.handle(), f_build.input_wires())?;
let call =
f_build.call(f_id.handle(), &[], f_build.input_wires(), &PRELUDE_REGISTRY)?;

f_build.finish_with_outputs(call.outputs())?;
module_builder.finish_prelude_hugr()
Expand Down
13 changes: 11 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,19 @@ mod test {
let mut fbuild = module_builder.define_function(
"main",
FunctionType::new(type_row![BIT], type_row![NAT])
.with_input_extensions(ExtensionSet::singleton(&PRELUDE_ID)),
.with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID))
.into(),
)?;
let _fdef = {
let [b1] = fbuild.input_wires_arr();
let [b1] = fbuild
.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: PRELUDE_ID,
},
fbuild.input_wires(),
)?
.outputs_arr();
let loop_id = {
let mut loop_b =
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
Expand Down
6 changes: 3 additions & 3 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand All @@ -979,7 +979,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand Down Expand Up @@ -1013,7 +1013,7 @@ fn funcdefn_signature_mismatch2() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand Down
Loading
Loading