diff --git a/crates/noirc_frontend/src/ast/expression.rs b/crates/noirc_frontend/src/ast/expression.rs index ac6161ddac1..3f73604f648 100644 --- a/crates/noirc_frontend/src/ast/expression.rs +++ b/crates/noirc_frontend/src/ast/expression.rs @@ -106,6 +106,12 @@ impl Recoverable for Expression { } } +impl Recoverable for Option { + fn error(span: Span) -> Self { + Some(Expression::new(ExpressionKind::Error, span)) + } +} + #[derive(Debug, Eq, Clone)] pub struct Expression { pub kind: ExpressionKind, diff --git a/crates/noirc_frontend/src/ast/statement.rs b/crates/noirc_frontend/src/ast/statement.rs index 5e0dd4e4391..9c8fb6c20cb 100644 --- a/crates/noirc_frontend/src/ast/statement.rs +++ b/crates/noirc_frontend/src/ast/statement.rs @@ -25,6 +25,12 @@ pub enum Statement { Assign(AssignStatement), // This is an expression with a trailing semi-colon Semi(Expression), + Return { + expr: Option, + // Initially `false`, but after semicolon validation it says whether + // the semicolon was present. + semi: bool, + }, // This statement is the result of a recovered parse error. // To avoid issuing multiple errors in later steps, it should // be skipped in any future analysis if possible. @@ -65,6 +71,18 @@ impl Statement { self } + Statement::Return { expr, semi: false } => { + if !last_statement_in_block && semi.is_none() { + let reason = "Expected a ; separating these two statements".to_string(); + emit_error(ParserError::with_reason(reason, span)); + } + Statement::Return { expr, semi: semi.is_some() } + } + + Statement::Return { expr: _, semi: true } => { + unreachable!() + } + Statement::Expression(expr) => { match (&expr.kind, semi, last_statement_in_block) { // Semicolons are optional for these expressions @@ -394,6 +412,10 @@ impl Display for Statement { Statement::Expression(expression) => expression.fmt(f), Statement::Assign(assign) => assign.fmt(f), Statement::Semi(semi) => write!(f, "{semi};"), + Statement::Return { expr: Some(expr), semi: true } => write!(f, "return {expr};"), + Statement::Return { expr: None, semi: true } => write!(f, "return;"), + Statement::Return { expr: Some(expr), semi: false } => write!(f, "return {expr}"), + Statement::Return { expr: None, semi: false } => write!(f, "return"), Statement::Error => write!(f, "Error"), } } diff --git a/crates/noirc_frontend/src/hir/resolution/resolver.rs b/crates/noirc_frontend/src/hir/resolution/resolver.rs index cfb354498ab..f69e8133594 100644 --- a/crates/noirc_frontend/src/hir/resolution/resolver.rs +++ b/crates/noirc_frontend/src/hir/resolution/resolver.rs @@ -803,7 +803,7 @@ impl<'a> Resolver<'a> { let stmt = HirAssignStatement { lvalue: identifier, expression }; HirStatement::Assign(stmt) } - Statement::Error => HirStatement::Error, + Statement::Error | Statement::Return { .. } => HirStatement::Error, } } diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index f4793d06368..24ce8bb34e1 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -430,6 +430,7 @@ where constrain(expr_parser.clone()), declaration(expr_parser.clone()), assignment(expr_parser.clone()), + return_statement(expr_parser.clone()), expr_parser.map(Statement::Expression), )) } @@ -657,6 +658,18 @@ fn expression() -> impl ExprParser { .labelled("expression") } +fn return_statement<'a, P>(expr_parser: P) -> impl NoirParser + 'a +where + P: ExprParser + 'a, +{ + ignore_then_commit(keyword(Keyword::Return), expr_parser.or_not()) + .validate(|expr, span, emit| { + emit(ParserError::with_reason("Early 'return' is unsupported".to_owned(), span)); + Statement::Return { expr, semi: false } + }) + .labelled("return expression") +} + // An expression is a single term followed by 0 or more (OP subexpression)* // where OP is an operator at the given precedence level and subexpression // is an expression at the current precedence level plus one. @@ -1484,4 +1497,40 @@ mod test { ); } } + + #[test] + fn return_validation() { + let cases = vec![ + ("{ return 42; }", 1, "{\n return 42;\n}"), + ("{ return 1; return 2; }", 2, "{\n return 1;\n return 2;\n}"), + ( + "{ return 123; let foo = 4 + 3; }", + 1, + "{\n return 123;\n let foo: unspecified = (4 + 3)\n}", + ), + ("{ return 1 + 2 }", 1, "{\n return (1 + 2)\n}"), + ("{ return; }", 1, "{\n return;\n}"), + ]; + + let show_errors = |v| vecmap(&v, ToString::to_string).join("\n"); + + let results = vecmap(&cases, |&(src, expected_errors, expected_result)| { + let (opt, errors) = parse_recover(block(expression()), src); + let actual = opt.map(|ast| ast.to_string()); + let actual = if let Some(s) = &actual { s.to_string() } else { "(none)".to_string() }; + + let result = + ((errors.len(), actual.clone()), (expected_errors, expected_result.to_string())); + if result.0 != result.1 { + let num_errors = errors.len(); + let shown_errors = show_errors(errors); + eprintln!( + "\nExpected {} error(s) and got {}:\n\n{}\n\nFrom input: {}\nExpected AST: {}\nActual AST: {}\n", + expected_errors, num_errors, shown_errors, src, expected_result, actual); + } + result + }); + + assert_eq!(vecmap(&results, |t| t.0.clone()), vecmap(&results, |t| t.1.clone()),); + } }