Skip to content

Commit

Permalink
Emulate TypeBounds on parameters via constraints.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Oct 31, 2024
1 parent 9962f97 commit 158edcf
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 34 deletions.
1 change: 1 addition & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ workspace = true
extension_inference = []
declarative = ["serde_yaml"]
model_unstable = ["hugr-model"]
default = ["model_unstable"]

[[test]]
name = "model"
Expand Down
45 changes: 37 additions & 8 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeEnum,
TypeBase, TypeBound, TypeEnum,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -44,6 +44,8 @@ struct Context<'a> {
term_map: FxHashMap<model::Term<'a>, model::TermId>,
/// The current scope for local variables.
local_scope: Option<model::NodeId>,
/// Constraints to be added to the local scope.
local_constraints: Vec<model::TermId>,
/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}
Expand All @@ -61,6 +63,7 @@ impl<'a> Context<'a> {
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
local_constraints: Vec::new(),
}
}

Expand Down Expand Up @@ -171,9 +174,11 @@ impl<'a> Context<'a> {
}

fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let old_scope = self.local_scope.replace(node);
let prev_local_scope = self.local_scope.replace(node);
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
let result = f(self);
self.local_scope = old_scope;
self.local_scope = prev_local_scope;
self.local_constraints = prev_local_constraints;
result
}

Expand Down Expand Up @@ -648,14 +653,23 @@ impl<'a> Context<'a> {
t: &PolyFuncTypeBase<RV>,
) -> (&'a [model::Param<'a>], model::TermId) {
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
let scope = self
.local_scope
.expect("exporting poly func type outside of local scope");

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param);
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
let param = model::Param::Implicit { name, r#type };
params.push(param)
}

params.extend(
self.local_constraints
.drain(..)
.map(|constraint| model::Param::Constraint { constraint }),
);

let body = self.export_func_type(t.body());

(params.into_bump_slice(), body)
Expand Down Expand Up @@ -766,20 +780,35 @@ impl<'a> Context<'a> {
self.make_term(model::Term::List { items, tail: None })
}

pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
pub fn export_type_param(
&mut self,
t: &TypeParam,
var: Option<model::LocalRef<'static>>,
) -> model::TermId {
match t {
// This ignores the type bound for now.
TypeParam::Type { .. } => self.make_term(model::Term::Type),
TypeParam::Type { b } => {
if let (Some(var), TypeBound::Copyable) = (var, b) {
let term = self.make_term(model::Term::Var(var));
let copy = self.make_term(model::Term::CopyConstraint { term });
let discard = self.make_term(model::Term::DiscardConstraint { term });
self.local_constraints.extend([copy, discard]);
}

self.make_term(model::Term::Type)
}
// This ignores the type bound for now.
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
TypeParam::String => self.make_term(model::Term::StrType),
TypeParam::List { param } => {
let item_type = self.export_type_param(param);
let item_type = self.export_type_param(param, None);
self.make_term(model::Term::ListType { item_type })
}
TypeParam::Tuple { params } => {
let items = self.bump.alloc_slice_fill_iter(
params.iter().map(|param| self.export_type_param(param)),
params
.iter()
.map(|param| self.export_type_param(param, None)),
);
let types = self.make_term(model::Term::List { items, tail: None });
self.make_term(model::Term::ApplyFull {
Expand Down
138 changes: 112 additions & 26 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct Context<'a> {
nodes: FxHashMap<model::NodeId, Node>,

/// The types of the local variables that are currently in scope.
local_variables: FxIndexMap<&'a str, model::TermId>,
local_variables: FxIndexMap<&'a str, LocalVar>,

custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>,
}
Expand Down Expand Up @@ -157,16 +157,16 @@ impl<'a> Context<'a> {
fn resolve_local_ref(
&self,
local_ref: &model::LocalRef,
) -> Result<(usize, model::TermId), ImportError> {
) -> Result<(usize, LocalVar), ImportError> {
let term = match local_ref {
model::LocalRef::Index(_, index) => self
.local_variables
.get_index(*index as usize)
.map(|(_, term)| (*index as usize, *term)),
.map(|(_, v)| (*index as usize, *v)),
model::LocalRef::Named(name) => self
.local_variables
.get_full(name)
.map(|(index, _, term)| (index, *term)),
.map(|(index, _, v)| (index, *v)),
};

term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into())
Expand Down Expand Up @@ -883,20 +883,65 @@ impl<'a> Context<'a> {
let mut imported_params = Vec::with_capacity(decl.params.len());

for param in decl.params {
// TODO: `PolyFuncType` should be able to handle constraints
// and distinguish between implicit and explicit parameters.
match param {
model::Param::Implicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
ctx.local_variables.insert(name, LocalVar::new(*r#type));
}
model::Param::Explicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
ctx.local_variables.insert(name, LocalVar::new(*r#type));
}
model::Param::Constraint { constraint: _ } => {
return Err(error_unsupported!("constraints"));
model::Param::Constraint { .. } => {}
}
}

for param in decl.params {
if let model::Param::Constraint { constraint } = param {
let constraint = ctx.get_term(*constraint)?;

match constraint {
model::Term::CopyConstraint { term } => {
let model::Term::Var(var) = ctx.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

let var = ctx.resolve_local_ref(var)?.0;
ctx.local_variables.get_index_mut(var).unwrap().1.copy = true;
}
model::Term::DiscardConstraint { term } => {
let model::Term::Var(var) = ctx.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

let var = ctx.resolve_local_ref(var)?.0;
ctx.local_variables.get_index_mut(var).unwrap().1.discard = true;
}
_ => {
return Err(error_unsupported!("constraint other than copy or discard"))
}
}
}
}

let mut index = 0;

for param in decl.params {
// TODO: `PolyFuncType` should be able to distinguish between implicit and explicit parameters.
match param {
model::Param::Implicit { r#type, .. } => {
let bound = ctx.local_variables.get_index(index).unwrap().1.bound()?;
imported_params.push(ctx.import_type_param(*r#type, bound)?);
index += 1;
}
model::Param::Explicit { r#type, .. } => {
let bound = ctx.local_variables.get_index(index).unwrap().1.bound()?;
imported_params.push(ctx.import_type_param(*r#type, bound)?);
index += 1;
}
model::Param::Constraint { constraint: _ } => {}
}
}

Expand All @@ -906,17 +951,15 @@ impl<'a> Context<'a> {
}

/// Import a [`TypeParam`] from a term that represents a static type.
fn import_type_param(&mut self, term_id: model::TermId) -> Result<TypeParam, ImportError> {
fn import_type_param(
&mut self,
term_id: model::TermId,
bound: TypeBound,
) -> Result<TypeParam, ImportError> {
match self.get_term(term_id)? {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),

model::Term::Type => {
// As part of the migration from `TypeBound`s to constraints, we pretend that all
// `TypeBound`s are copyable.
Ok(TypeParam::Type {
b: TypeBound::Copyable,
})
}
model::Term::Type => Ok(TypeParam::Type { b: bound }),

model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")),
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")),
Expand All @@ -928,7 +971,7 @@ impl<'a> Context<'a> {
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

model::Term::ListType { item_type } => {
let param = Box::new(self.import_type_param(*item_type)?);
let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?);
Ok(TypeParam::List { param })
}

Expand All @@ -942,7 +985,11 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::ExtSet { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::CopyConstraint { .. }
| model::Term::DiscardConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}

model::Term::ControlType => {
Err(error_unsupported!("type of control types as `TypeParam`"))
Expand All @@ -959,8 +1006,9 @@ impl<'a> Context<'a> {
}

model::Term::Var(var) => {
let (index, var_type) = self.resolve_local_ref(var)?;
let decl = self.import_type_param(var_type)?;
let (index, var) = self.resolve_local_ref(var)?;
let bound = var.bound()?;
let decl = self.import_type_param(var.r#type, bound)?;
Ok(TypeArg::new_var_use(index, decl))
}

Expand Down Expand Up @@ -998,7 +1046,11 @@ impl<'a> Context<'a> {

model::Term::FuncType { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::CopyConstraint { .. }
| model::Term::DiscardConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1099,7 +1151,11 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Nat(_)
| model::Term::DiscardConstraint { .. }
| model::Term::CopyConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1246,3 +1302,33 @@ impl<'a> Names<'a> {
Ok(Self { items })
}
}

#[derive(Debug, Clone, Copy)]
struct LocalVar {
r#type: model::TermId,
copy: bool,
discard: bool,
}

impl LocalVar {
pub fn new(r#type: model::TermId) -> Self {
Self {
r#type,
copy: false,
discard: false,
}
}

pub fn bound(&self) -> Result<TypeBound, ImportError> {
match (self.copy, self.discard) {
(true, true) => Ok(TypeBound::Copyable),
(false, false) => Ok(TypeBound::Any),
(true, false) => Err(error_unsupported!(
"type that is copyable but not discardable"
)),
(false, true) => Err(error_unsupported!(
"type that is discardable but not copyable"
)),
}
}
}
7 changes: 7 additions & 0 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() {
"../../hugr-model/tests/fixtures/model-params.edn"
)));
}

#[test]
pub fn test_roundtrip_constraints() {
insta::assert_snapshot!(roundtrip(include_str!(
"../../hugr-model/tests/fixtures/model-constraints.edn"
)));
}
12 changes: 12 additions & 0 deletions hugr-core/tests/snapshots/model__roundtrip_constraints.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))"
---
(hugr 0)

(declare-func foo.func
(forall ?0 type)
(forall ?1 type)
(where (copy ?0))
(where (discard ?0))
[?0 ?1] [?0 ?0 ?1] (ext))
2 changes: 2 additions & 0 deletions hugr-model/capnp/hugr-v0.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ struct Term {
funcType @17 :FuncType;
control @18 :TermId;
controlType @19 :Void;
copyConstraint @20 :TermId;
discardConstraint @21 :TermId;
}

struct Apply {
Expand Down
7 changes: 7 additions & 0 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
Which::Control(values) => model::Term::Control {
values: model::TermId(values),
},

Which::CopyConstraint(term) => model::Term::CopyConstraint {
term: model::TermId(term),
},
Which::DiscardConstraint(term) => model::Term::DiscardConstraint {
term: model::TermId(term),
},
})
}

Expand Down
8 changes: 8 additions & 0 deletions hugr-model/src/v0/binary/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,13 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) {
builder.set_outputs(outputs.0);
builder.set_extensions(extensions.0);
}

model::Term::CopyConstraint { term } => {
builder.set_copy_constraint(term.0);
}

model::Term::DiscardConstraint { term } => {
builder.set_discard_constraint(term.0);
}
}
}
Loading

0 comments on commit 158edcf

Please sign in to comment.