From cf8f5199947cbde29986d2e895861cc87b9e8b7b Mon Sep 17 00:00:00 2001 From: Alex Vitkov Date: Thu, 10 Aug 2023 17:07:34 +0300 Subject: [PATCH] fix: properly capture lvalues in closure environments (#2120) --- .../src/hir/resolution/resolver.rs | 7 ++- .../src/monomorphization/mod.rs | 48 ++++++++++++------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index 6b3ea3eed69..4bfb5428ed7 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -919,7 +919,10 @@ impl<'a> Resolver<'a> { fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue { match lvalue { LValue::Ident(ident) => { - HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error) + let ident = self.find_variable_or_default(&ident); + self.resolve_local_variable(ident.0, ident.1); + + HirLValue::Ident(ident.0, Type::Error) } LValue::MemberAccess { object, field_name } => { let object = Box::new(self.resolve_lvalue(*object)); @@ -1018,8 +1021,8 @@ impl<'a> Resolver<'a> { self.interner.push_definition_type(hir_ident.id, typ); } } - // We ignore the above definition kinds because only local variables can be captured by closures. DefinitionKind::Local(_) => { + // only local variables can be captured by closures. self.resolve_local_variable(hir_ident, var_scope_index); } } diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index b93c95efe07..783d9f3133e 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -31,7 +31,7 @@ pub mod ast; pub mod printer; struct LambdaContext { - env_ident: Box, + env_ident: ast::Ident, captures: Vec, } @@ -552,13 +552,26 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Block(definitions) } - /// Find a captured variable in the innermost closure - fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option { + /// Find a captured variable in the innermost closure, and construct an expression + fn lookup_captured_expr(&mut self, id: node_interner::DefinitionId) -> Option { let ctx = self.lambda_envs_stack.last()?; - ctx.captures - .iter() - .position(|capture| capture.ident.id == id) - .map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index)) + ctx.captures.iter().position(|capture| capture.ident.id == id).map(|index| { + ast::Expression::ExtractTupleField( + Box::new(ast::Expression::Ident(ctx.env_ident.clone())), + index, + ) + }) + } + + /// Find a captured variable in the innermost closure construct a LValue + fn lookup_captured_lvalue(&mut self, id: node_interner::DefinitionId) -> Option { + let ctx = self.lambda_envs_stack.last()?; + ctx.captures.iter().position(|capture| capture.ident.id == id).map(|index| { + ast::LValue::MemberAccess { + object: Box::new(ast::LValue::Ident(ctx.env_ident.clone())), + field_index: index, + } + }) } /// A local (ie non-global) ident only @@ -599,7 +612,7 @@ impl<'interner> Monomorphizer<'interner> { } } DefinitionKind::Global(expr_id) => self.expr(*expr_id), - DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| { + DefinitionKind::Local(_) => self.lookup_captured_expr(ident.id).unwrap_or_else(|| { let ident = self.local_ident(&ident).unwrap(); ast::Expression::Ident(ident) }), @@ -961,7 +974,9 @@ impl<'interner> Monomorphizer<'interner> { fn lvalue(&mut self, lvalue: HirLValue) -> ast::LValue { match lvalue { - HirLValue::Ident(ident, _) => ast::LValue::Ident(self.local_ident(&ident).unwrap()), + HirLValue::Ident(ident, _) => self + .lookup_captured_lvalue(ident.id) + .unwrap_or_else(|| ast::LValue::Ident(self.local_ident(&ident).unwrap())), HirLValue::MemberAccess { object, field_index, .. } => { let field_index = field_index.unwrap(); let object = Box::new(self.lvalue(*object)); @@ -1065,7 +1080,7 @@ impl<'interner> Monomorphizer<'interner> { match capture.transitive_capture_index { Some(field_index) => match self.lambda_envs_stack.last() { Some(lambda_ctx) => ast::Expression::ExtractTupleField( - lambda_ctx.env_ident.clone(), + Box::new(ast::Expression::Ident(lambda_ctx.env_ident.clone())), field_index, ), None => unreachable!( @@ -1096,18 +1111,16 @@ impl<'interner> Monomorphizer<'interner> { let mutable = false; let definition = Definition::Local(env_local_id); - let env_ident = ast::Expression::Ident(ast::Ident { + let env_ident = ast::Ident { location, mutable, definition, name: env_name.to_string(), typ: env_typ.clone(), - }); + }; - self.lambda_envs_stack.push(LambdaContext { - env_ident: Box::new(env_ident.clone()), - captures: lambda.captures, - }); + self.lambda_envs_stack + .push(LambdaContext { env_ident: env_ident.clone(), captures: lambda.captures }); let body = self.expr(lambda.body); self.lambda_envs_stack.pop(); @@ -1129,7 +1142,8 @@ impl<'interner> Monomorphizer<'interner> { let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; self.push_function(id, function); - let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]); + let lambda_value = + ast::Expression::Tuple(vec![ast::Expression::Ident(env_ident), lambda_fn]); let block_local_id = self.next_local_id(); let block_ident_name = "closure_variable"; let block_let_stmt = ast::Expression::Let(ast::Let {