Skip to content

Commit

Permalink
feat: Allow specifying type signatures in do bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Marwes committed Dec 29, 2020
1 parent d3bfc59 commit fac08dc
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 24 deletions.
5 changes: 5 additions & 0 deletions base/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ pub struct ExprField<'ast, Id, E> {
#[derive(Eq, PartialEq, Debug, AstClone)]
pub struct Do<'ast, Id> {
pub id: Option<SpannedPattern<'ast, Id>>,
pub typ: Option<AstType<'ast, Id>>,
pub bound: &'ast mut SpannedExpr<'ast, Id>,
pub body: &'ast mut SpannedExpr<'ast, Id>,
pub flat_map_id: Option<&'ast mut SpannedExpr<'ast, Id>>,
Expand Down Expand Up @@ -872,13 +873,17 @@ pub fn walk_expr<'a, 'ast, V>(v: &mut V, e: &'a $($mut)* SpannedExpr<'ast, V::Id

Expr::Do(Do {
ref $($mut)* id,
ref $($mut)* typ,
ref $($mut)* bound,
ref $($mut)* body,
ref $($mut)* flat_map_id,
}) => {
if let Some(id) = id {
v.visit_pattern(id);
}
if let Some(ast_type) = typ {
v.visit_ast_type(ast_type)
}
v.visit_expr(bound);
v.visit_expr(body);
if let Some(ref $($mut)* flat_map_id) = *flat_map_id {
Expand Down
40 changes: 24 additions & 16 deletions check/src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ impl<'a, 'ast> Typecheck<'a, 'ast> {
}
Expr::Do(Do {
ref mut id,
ref mut typ,
ref mut bound,
ref mut body,
ref mut flat_map_id,
Expand Down Expand Up @@ -1222,10 +1223,10 @@ impl<'a, 'ast> Typecheck<'a, 'ast> {
_ => flat_map_type.clone(),
};

let id_var = self.subs.new_var();
let id_type = self.resolve_type_signature(typ.as_mut());
let arg1 = self
.subs
.function(Some(id_var.clone()), self.subs.new_var());
.function(Some(id_type.concrete.clone()), self.subs.new_var());
let arg2 = self.subs.new_var();

let ret = expected_type
Expand All @@ -1241,7 +1242,7 @@ impl<'a, 'ast> Typecheck<'a, 'ast> {
self.unify_span(do_span, &flat_map_type, func_type);

if let Some(ref mut id) = *id {
self.typecheck_pattern(id, ModType::wobbly(id_var.clone()), id_var);
self.typecheck_pattern(id, id_type.clone(), id_type.concrete);
}

let body_type = self.typecheck(body, ret.as_ref());
Expand Down Expand Up @@ -1890,6 +1891,25 @@ impl<'a, 'ast> Typecheck<'a, 'ast> {
}
}

fn resolve_type_signature(
&mut self,
mut signature: Option<&mut AstType<'_, Symbol>>,
) -> ModType {
let mut mod_type = if let Some(ref mut typ) = signature {
self.kindcheck(typ);
let rc_type = self.translate_ast_type(typ);

ModType::rigid(rc_type)
} else {
ModType::wobbly(self.subs.hole())
};

if let Some(typ) = self.create_unifiable_signature(&mod_type) {
mod_type.concrete = typ;
}
mod_type
}

fn typecheck_let_bindings(&mut self, bindings: &mut ValueBindings<'ast, Symbol>) {
self.enter_scope();
self.environment.skolem_variables.enter_scope();
Expand All @@ -1902,19 +1922,7 @@ impl<'a, 'ast> Typecheck<'a, 'ast> {
if is_recursive {
for (i, bind) in bindings.iter_mut().enumerate() {
let typ = {
if let Some(ref mut typ) = bind.typ {
self.kindcheck(typ);
let rc_type = self.translate_ast_type(typ);

resolved_types.push(ModType::rigid(rc_type));
} else {
resolved_types.push(ModType::wobbly(self.subs.hole()));
}

let typ = self.create_unifiable_signature(&resolved_types[i]);
if let Some(typ) = typ {
resolved_types[i].concrete = typ;
}
resolved_types.push(self.resolve_type_signature(bind.typ.as_mut()));

resolved_types[i].concrete.clone()
};
Expand Down
13 changes: 13 additions & 0 deletions check/tests/fail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,16 @@ match Test 0 with
"#,
PatternError { .. }
}

test_check_err! {
do_type_signature_error,
r#"
type List a = | Cons a (List a) | Nil
let flat_map f x : (a -> List b) -> List a -> List b = Nil
do writer : Int = Cons "" Nil
Cons "" Nil
"#,
Unification { .. }
}
13 changes: 13 additions & 0 deletions check/tests/pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1172,3 +1172,16 @@ viewl
"#,
"forall a . test.FingerTree a -> test.View a"
}

test_check! {
do_type_signature,
r#"
type List a = | Cons a (List a) | Nil
let flat_map f x : (a -> List b) -> List a -> List b = Nil
do writer : String = Nil
Cons 0 Nil
"#,
"test.List Int"
}
26 changes: 19 additions & 7 deletions parser/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -889,18 +889,30 @@ Expr: Expr<'ast, Id> = {

DoExpression: Expr<'ast, Id> = {
"do" <bind: DoBinding> <body: SpExpr> => {
let bind = *bind;
Expr::Do(arena.alloc(Do { id: Some(bind.0), bound: arena.alloc(bind.1), body: arena.alloc(body), flat_map_id: None }))
let (id, typ, bound) = *bind;
Expr::Do(arena.alloc(Do {
id: Some(id),
typ,
bound: arena.alloc(bound),
body: arena.alloc(body),
flat_map_id: None,
}))
},

"seq" <bound: SpExpr> <body: InExpr> => {
Expr::Do(arena.alloc(Do { id: None, bound: arena.alloc(bound), body: arena.alloc(body), flat_map_id: None }))
Expr::Do(arena.alloc(Do {
id: None,
typ: None,
bound: arena.alloc(bound),
body: arena.alloc(body),
flat_map_id: None,
}))
},
};


DoBinding: Box<(SpannedPattern<'ast, Id>, SpannedExpr<'ast, Id>)> = {
<id: Sp<Pattern>> "=" <bound: SpExpr> "in" => Box::new((id, bound)),
DoBinding: Box<(SpannedPattern<'ast, Id>, Option<AstType<'ast, Id>>, SpannedExpr<'ast, Id>)> = {
<id: Sp<Pattern>> <typ: (":" <Type>)?> "=" <bound: SpExpr> "in" => Box::new((id, typ, bound)),

// Error recovery

Expand All @@ -910,7 +922,7 @@ DoBinding: Box<(SpannedPattern<'ast, Id>, SpannedExpr<'ast, Id>)> = {
token: (span.start(), Token::In, span.end()),
expected: ["="].iter().map(|s| s.to_string()).collect(),
});
Box::new((id, pos::spanned(span, Expr::Error(None))))
Box::new((id, None, pos::spanned(span, Expr::Error(None))))
},
};

Expand All @@ -930,7 +942,7 @@ BlockExpr: Expr<'ast, Id> = {
pos::spanned2(
expr.span.start(),
body.span.end(),
Expr::Do(arena.alloc(Do { id: None, bound: arena.alloc(expr), body: arena.alloc(body), flat_map_id: None })),
Expr::Do(arena.alloc(Do { id: None, typ: None, bound: arena.alloc(expr), body: arena.alloc(body), flat_map_id: None })),
)
}).value
}
Expand Down
24 changes: 23 additions & 1 deletion vm/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,7 @@ impl<'a, 'e> Translator<'a, 'e> {
ref bound,
ref body,
ref flat_map_id,
..
}) => {
let flat_map_id = flat_map_id
.as_ref()
Expand Down Expand Up @@ -2371,7 +2372,7 @@ pub mod tests {
use codespan_reporting::files::Files;

use crate::base::{
ast,
ast, pos,
source::Source,
symbol::{Symbol, SymbolModule, Symbols},
types::TypeCache,
Expand All @@ -2393,6 +2394,27 @@ pub mod tests {
)
.unwrap(),
);

struct Visitor<'ast>(&'ast ast::Arena<'ast, Symbol>);

impl<'ast> ast::MutVisitor<'_, 'ast> for Visitor<'ast> {
type Ident = Symbol;
fn visit_expr(&mut self, expr: &mut SpannedExpr<'ast, Symbol>) {
match &mut expr.value {
ast::Expr::Do(d) => {
d.flat_map_id = Some(self.0.alloc(pos::spanned(
expr.span,
ast::Expr::Ident(TypedIdent::new(Symbol::from("flat_map"))),
)))
}
_ => (),
}
ast::walk_mut_expr(self, expr);
}
}

ast::MutVisitor::visit_expr(&mut Visitor(&arena), expr);

ast::RootExpr::new(arena.clone(), expr)
}

Expand Down

0 comments on commit fac08dc

Please sign in to comment.