From 158edcfdef2028210f7ccea6d674d8996edc3cd3 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Mon, 14 Oct 2024 15:21:28 +0100 Subject: [PATCH] Emulate `TypeBound`s on parameters via constraints. --- hugr-core/Cargo.toml | 1 + hugr-core/src/export.rs | 45 +++++- hugr-core/src/import.rs | 138 ++++++++++++++---- hugr-core/tests/model.rs | 7 + .../model__roundtrip_constraints.snap | 12 ++ hugr-model/capnp/hugr-v0.capnp | 2 + hugr-model/src/v0/binary/read.rs | 7 + hugr-model/src/v0/binary/write.rs | 8 + hugr-model/src/v0/mod.rs | 12 ++ hugr-model/src/v0/text/hugr.pest | 4 + hugr-model/src/v0/text/parse.rs | 10 ++ hugr-model/src/v0/text/print.rs | 8 + .../tests/fixtures/model-constraints.edn | 9 ++ 13 files changed, 229 insertions(+), 34 deletions(-) create mode 100644 hugr-core/tests/snapshots/model__roundtrip_constraints.snap create mode 100644 hugr-model/tests/fixtures/model-constraints.edn diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index aa7c631a20..81c2bfc810 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -20,6 +20,7 @@ workspace = true extension_inference = [] declarative = ["serde_yaml"] model_unstable = ["hugr-model"] +default = ["model_unstable"] [[test]] name = "model" diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index e7a85c98fa..93e8049f3e 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,6 +44,8 @@ struct Context<'a> { term_map: FxHashMap, model::TermId>, /// The current scope for local variables. local_scope: Option, + /// Constraints to be added to the local scope. + local_constraints: Vec, /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -61,6 +63,7 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } @@ -171,9 +174,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -648,14 +653,23 @@ impl<'a> Context<'a> { t: &PolyFuncTypeBase, ) -> (&'a [model::Param<'a>], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); let param = model::Param::Implicit { name, r#type }; params.push(param) } + params.extend( + self.local_constraints + .drain(..) + .map(|constraint| model::Param::Constraint { constraint }), + ); + let body = self.export_func_type(t.body()); (params.into_bump_slice(), body) @@ -766,20 +780,35 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let copy = self.make_term(model::Term::CopyConstraint { term }); + let discard = self.make_term(model::Term::DiscardConstraint { term }); + self.local_constraints.extend([copy, discard]); + } + + self.make_term(model::Term::Type) + } // This ignores the type bound for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index d981049fba..5b039291a2 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -114,7 +114,7 @@ struct Context<'a> { nodes: FxHashMap, /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -157,16 +157,16 @@ impl<'a> Context<'a> { fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -883,20 +883,65 @@ impl<'a> Context<'a> { let mut imported_params = Vec::with_capacity(decl.params.len()); for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. match param { model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); + ctx.local_variables.insert(name, LocalVar::new(*r#type)); } model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); + ctx.local_variables.insert(name, LocalVar::new(*r#type)); } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + model::Param::Constraint { .. } => {} + } + } + + for param in decl.params { + if let model::Param::Constraint { constraint } = param { + let constraint = ctx.get_term(*constraint)?; + + match constraint { + model::Term::CopyConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables.get_index_mut(var).unwrap().1.copy = true; + } + model::Term::DiscardConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables.get_index_mut(var).unwrap().1.discard = true; + } + _ => { + return Err(error_unsupported!("constraint other than copy or discard")) + } + } + } + } + + let mut index = 0; + + for param in decl.params { + // TODO: `PolyFuncType` should be able to distinguish between implicit and explicit parameters. + match param { + model::Param::Implicit { r#type, .. } => { + let bound = ctx.local_variables.get_index(index).unwrap().1.bound()?; + imported_params.push(ctx.import_type_param(*r#type, bound)?); + index += 1; + } + model::Param::Explicit { r#type, .. } => { + let bound = ctx.local_variables.get_index(index).unwrap().1.bound()?; + imported_params.push(ctx.import_type_param(*r#type, bound)?); + index += 1; } + model::Param::Constraint { constraint: _ } => {} } } @@ -906,17 +951,15 @@ impl<'a> Context<'a> { } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -928,7 +971,7 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -942,7 +985,11 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::CopyConstraint { .. } + | model::Term::DiscardConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -959,8 +1006,9 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let bound = var.bound()?; + let decl = self.import_type_param(var.r#type, bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -998,7 +1046,11 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::CopyConstraint { .. } + | model::Term::DiscardConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1099,7 +1151,11 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::DiscardConstraint { .. } + | model::Term::CopyConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1246,3 +1302,33 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +#[derive(Debug, Clone, Copy)] +struct LocalVar { + r#type: model::TermId, + copy: bool, + discard: bool, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + copy: false, + discard: false, + } + } + + pub fn bound(&self) -> Result { + match (self.copy, self.discard) { + (true, true) => Ok(TypeBound::Copyable), + (false, false) => Ok(TypeBound::Any), + (true, false) => Err(error_unsupported!( + "type that is copyable but not discardable" + )), + (false, true) => Err(error_unsupported!( + "type that is discardable but not copyable" + )), + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660d..d9ef0d2c9a 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 0000000000..c0f959094a --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,12 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func foo.func + (forall ?0 type) + (forall ?1 type) + (where (copy ?0)) + (where (discard ?0)) + [?0 ?1] [?0 ?0 ?1] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205a..644607dc50 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -157,6 +157,8 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + copyConstraint @20 :TermId; + discardConstraint @21 :TermId; } struct Apply { diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9a..5e305f964a 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -332,6 +332,13 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::CopyConstraint(term) => model::Term::CopyConstraint { + term: model::TermId(term), + }, + Which::DiscardConstraint(term) => model::Term::DiscardConstraint { + term: model::TermId(term), + }, }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646c..6da291e44a 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -212,5 +212,13 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::CopyConstraint { term } => { + builder.set_copy_constraint(term.0); + } + + model::Term::DiscardConstraint { term } => { + builder.set_discard_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b322..40f17a4fc0 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -662,6 +662,18 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable. + CopyConstraint { + /// The runtime type that must be copyable. + term: TermId, + }, + + /// Constraint that requires a runtime type to be discardable. + DiscardConstraint { + /// The runtime type that must be discardable. + term: TermId, + }, } /// A parameter to a function or alias. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 33974a76aa..5d181e8a82 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -88,6 +88,8 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_copy + | term_discard } term_wildcard = { "_" } @@ -110,3 +112,5 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_copy = { "(" ~ "copy" ~ term ~ ")" } +term_discard = { "(" ~ "discard" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index b669ce38cc..facb656e50 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -211,6 +211,16 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_copy => { + let term = self.parse_term(inner.next().unwrap())?; + Term::CopyConstraint { term } + } + + Rule::term_discard => { + let term = self.parse_term(inner.next().unwrap())?; + Term::DiscardConstraint { term } + } + r => unreachable!("term: {:?}", r), }; diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 494c10df2c..c0d4afb84e 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -598,6 +598,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::CopyConstraint { term } => self.print_parens(|this| { + this.print_text("copy"); + this.print_term(*term) + }), + Term::DiscardConstraint { term } => self.print_parens(|this| { + this.print_text("discard"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 0000000000..9a232ab109 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,9 @@ +(hugr 0) + +(declare-func foo.func + (forall ?x type) + (forall ?y type) + (where (copy ?x)) + (where (discard ?x)) + [?x ?y] [?x ?x ?y] + (ext))