diff --git a/aztec_macros/src/utils/ast_utils.rs b/aztec_macros/src/utils/ast_utils.rs index 1731dfab49c..d0d7077a946 100644 --- a/aztec_macros/src/utils/ast_utils.rs +++ b/aztec_macros/src/utils/ast_utils.rs @@ -66,6 +66,7 @@ pub fn mutable_assignment(name: &str, assigned_to: Expression) -> Statement { pattern: mutable(name), r#type: make_type(UnresolvedTypeData::Unspecified), expression: assigned_to, + comptime: false, attributes: vec![], })) } @@ -90,6 +91,7 @@ pub fn assignment_with_type( pattern: pattern(name), r#type: make_type(typ), expression: assigned_to, + comptime: false, attributes: vec![], })) } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 753b5a31d32..c46ab39df1e 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -38,6 +38,8 @@ pub enum StatementKind { For(ForLoopStatement), Break, Continue, + /// This statement should be executed at compile-time + Comptime(Box), // This is an expression with a trailing semi-colon Semi(Expression), // This statement is the result of a recovered parse error. @@ -47,6 +49,19 @@ pub enum StatementKind { } impl Statement { + pub fn add_semicolon( + mut self, + semi: Option, + span: Span, + last_statement_in_block: bool, + emit_error: &mut dyn FnMut(ParserError), + ) -> Self { + self.kind = self.kind.add_semicolon(semi, span, last_statement_in_block, emit_error); + self + } +} + +impl StatementKind { pub fn add_semicolon( self, semi: Option, @@ -57,7 +72,7 @@ impl Statement { let missing_semicolon = ParserError::with_reason(ParserErrorReason::MissingSeparatingSemi, span); - let kind = match self.kind { + match self { StatementKind::Let(_) | StatementKind::Constrain(_) | StatementKind::Assign(_) @@ -69,10 +84,15 @@ impl Statement { if semi.is_none() { emit_error(missing_semicolon); } - self.kind + self + } + StatementKind::Comptime(mut statement) => { + *statement = + statement.add_semicolon(semi, span, last_statement_in_block, emit_error); + StatementKind::Comptime(statement) } // A semicolon on a for loop is optional and does nothing - StatementKind::For(_) => self.kind, + StatementKind::For(_) => self, StatementKind::Expression(expr) => { match (&expr.kind, semi, last_statement_in_block) { @@ -92,9 +112,7 @@ impl Statement { (_, None, true) => StatementKind::Expression(expr), } } - }; - - Statement { kind, span: self.span } + } } } @@ -108,7 +126,13 @@ impl StatementKind { pub fn new_let( ((pattern, r#type), expression): ((Pattern, UnresolvedType), Expression), ) -> StatementKind { - StatementKind::Let(LetStatement { pattern, r#type, expression, attributes: vec![] }) + StatementKind::Let(LetStatement { + pattern, + r#type, + expression, + comptime: false, + attributes: vec![], + }) } /// Create a Statement::Assign value, desugaring any combined operators like += if needed. @@ -407,17 +431,9 @@ pub struct LetStatement { pub r#type: UnresolvedType, pub expression: Expression, pub attributes: Vec, -} -impl LetStatement { - pub fn new_let( - (((pattern, r#type), expression), attributes): ( - ((Pattern, UnresolvedType), Expression), - Vec, - ), - ) -> LetStatement { - LetStatement { pattern, r#type, expression, attributes } - } + // True if this should only be run during compile-time + pub comptime: bool, } #[derive(Debug, PartialEq, Eq, Clone)] @@ -573,6 +589,7 @@ impl ForRange { pattern: Pattern::Identifier(array_ident.clone()), r#type: UnresolvedType::unspecified(), expression: array, + comptime: false, attributes: vec![], }), span: array_span, @@ -616,6 +633,7 @@ impl ForRange { pattern: Pattern::Identifier(identifier), r#type: UnresolvedType::unspecified(), expression: Expression::new(loop_element, array_span), + comptime: false, attributes: vec![], }), span: array_span, @@ -666,6 +684,7 @@ impl Display for StatementKind { StatementKind::For(for_loop) => for_loop.fmt(f), StatementKind::Break => write!(f, "break"), StatementKind::Continue => write!(f, "continue"), + StatementKind::Comptime(statement) => write!(f, "comptime {statement}"), StatementKind::Semi(semi) => write!(f, "{semi};"), StatementKind::Error => write!(f, "Error"), } diff --git a/compiler/noirc_frontend/src/debug/mod.rs b/compiler/noirc_frontend/src/debug/mod.rs index 67b52071d7b..3e7d123398b 100644 --- a/compiler/noirc_frontend/src/debug/mod.rs +++ b/compiler/noirc_frontend/src/debug/mod.rs @@ -145,6 +145,7 @@ impl DebugInstrumenter { pattern: ast::Pattern::Identifier(ident("__debug_expr", ret_expr.span)), r#type: ast::UnresolvedType::unspecified(), expression: ret_expr.clone(), + comptime: false, attributes: vec![], }), span: ret_expr.span, @@ -243,6 +244,7 @@ impl DebugInstrumenter { kind: ast::StatementKind::Let(ast::LetStatement { pattern: ast::Pattern::Tuple(vars_pattern, let_stmt.pattern.span()), r#type: ast::UnresolvedType::unspecified(), + comptime: false, expression: ast::Expression { kind: ast::ExpressionKind::Block(ast::BlockExpression { statements: block_stmts, @@ -275,6 +277,7 @@ impl DebugInstrumenter { pattern: ast::Pattern::Identifier(ident("__debug_expr", assign_stmt.expression.span)), r#type: ast::UnresolvedType::unspecified(), expression: assign_stmt.expression.clone(), + comptime: false, attributes: vec![], }); let expression_span = assign_stmt.expression.span; diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs index 8ffcbce7d62..9f9004b5ad1 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs @@ -36,6 +36,7 @@ impl StmtId { pattern, r#type, expression, + comptime: false, attributes: Vec::new(), }) } @@ -64,6 +65,9 @@ impl StmtId { HirStatement::Expression(expr) => StatementKind::Expression(expr.to_ast(interner)), HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)), HirStatement::Error => StatementKind::Error, + HirStatement::Comptime(statement) => { + StatementKind::Comptime(Box::new(statement.to_ast(interner).kind)) + } }; Statement { kind, span } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 81050073008..54f1c321a3b 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -44,6 +44,10 @@ pub(crate) struct Interpreter<'interner> { changed_globally: bool, in_loop: bool, + + /// True if we're currently in a compile-time context. + /// If this is false code is skipped over instead of executed. + in_comptime_context: bool, } #[allow(unused)] @@ -121,6 +125,7 @@ impl<'a> Interpreter<'a> { changed_functions: HashSet::default(), changed_globally: false, in_loop: false, + in_comptime_context: false, } } @@ -1074,6 +1079,7 @@ impl<'a> Interpreter<'a> { HirStatement::Break => self.evaluate_break(), HirStatement::Continue => self.evaluate_continue(), HirStatement::Expression(expression) => self.evaluate(expression), + HirStatement::Comptime(statement) => self.evaluate_comptime(statement), HirStatement::Semi(expression) => { self.evaluate(expression)?; Ok(Value::Unit) @@ -1246,6 +1252,13 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::ContinueNotInLoop) } } + + fn evaluate_comptime(&mut self, statement: StmtId) -> IResult { + let was_in_comptime = std::mem::replace(&mut self.in_comptime_context, true); + let result = self.evaluate_statement(statement); + self.in_comptime_context = was_in_comptime; + result + } } impl Value { diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 01d477e9d4c..ab0490a80e4 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -1275,6 +1275,10 @@ impl<'a> Resolver<'a> { HirStatement::Continue } StatementKind::Error => HirStatement::Error, + StatementKind::Comptime(statement) => { + let statement = self.resolve_stmt(*statement, span); + HirStatement::Comptime(self.interner.push_stmt(statement)) + } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index fb57aa75f89..4dfc896fb29 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -51,6 +51,7 @@ impl<'interner> TypeChecker<'interner> { HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt), HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id), HirStatement::For(for_loop) => self.check_for_loop(for_loop), + HirStatement::Comptime(statement) => return self.check_statement(&statement), HirStatement::Break | HirStatement::Continue | HirStatement::Error => (), } Type::Unit diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 37e3651a9b2..605f25ebfbf 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -1,6 +1,6 @@ use super::expr::HirIdent; use crate::macros_api::SecondaryAttribute; -use crate::node_interner::ExprId; +use crate::node_interner::{ExprId, StmtId}; use crate::{Ident, Type}; use fm::FileId; use noirc_errors::{Location, Span}; @@ -19,6 +19,7 @@ pub enum HirStatement { Continue, Expression(ExprId), Semi(ExprId), + Comptime(StmtId), Error, } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 20b9c0885bf..477ed3fe951 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -624,6 +624,9 @@ impl<'interner> Monomorphizer<'interner> { HirStatement::Break => Ok(ast::Expression::Break), HirStatement::Continue => Ok(ast::Expression::Continue), HirStatement::Error => unreachable!(), + + // All `comptime` statements & expressions should be removed before runtime. + HirStatement::Comptime(_) => unreachable!("comptime statement in runtime code"), } } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 5706c3ef12f..ba5b9bf1ab0 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -234,14 +234,16 @@ fn implementation() -> impl NoirParser { /// global_declaration: 'global' ident global_type_annotation '=' literal fn global_declaration() -> impl NoirParser { let p = attributes::attributes() + .then(maybe_comp_time()) .then_ignore(keyword(Keyword::Global).labelled(ParsingRuleLabel::Global)) .then(ident().map(Pattern::Identifier)); + let p = then_commit(p, optional_type_annotation()); let p = then_commit_ignore(p, just(Token::Assign)); let p = then_commit(p, expression()); - p.validate(|(((attributes, pattern), r#type), expression), span, emit| { + p.validate(|((((attributes, comptime), pattern), r#type), expression), span, emit| { let global_attributes = attributes::validate_secondary_attributes(attributes, span, emit); - LetStatement { pattern, r#type, expression, attributes: global_attributes } + LetStatement { pattern, r#type, comptime, expression, attributes: global_attributes } }) .map(TopLevelStatement::Global) } @@ -498,10 +500,11 @@ where assertion::assertion_eq(expr_parser.clone()), declaration(expr_parser.clone()), assignment(expr_parser.clone()), - for_loop(expr_no_constructors, statement), + for_loop(expr_no_constructors.clone(), statement.clone()), break_statement(), continue_statement(), return_statement(expr_parser.clone()), + comptime_statement(expr_parser.clone(), expr_no_constructors, statement), expr_parser.map(StatementKind::Expression), )) }) @@ -519,6 +522,35 @@ fn continue_statement() -> impl NoirParser { keyword(Keyword::Continue).to(StatementKind::Continue) } +fn comptime_statement<'a, P1, P2, S>( + expr: P1, + expr_no_constructors: P2, + statement: S, +) -> impl NoirParser + 'a +where + P1: ExprParser + 'a, + P2: ExprParser + 'a, + S: NoirParser + 'a, +{ + keyword(Keyword::CompTime) + .ignore_then(choice(( + declaration(expr), + for_loop(expr_no_constructors, statement.clone()), + block(statement).map_with_span(|block, span| { + StatementKind::Expression(Expression::new(ExpressionKind::Block(block), span)) + }), + ))) + .map(|statement| StatementKind::Comptime(Box::new(statement))) +} + +/// Comptime in an expression position only accepts entire blocks +fn comptime_expr<'a, S>(statement: S) -> impl NoirParser + 'a +where + S: NoirParser + 'a, +{ + keyword(Keyword::CompTime).ignore_then(block(statement)).map(ExpressionKind::Block) +} + fn declaration<'a, P>(expr_parser: P) -> impl NoirParser + 'a where P: ExprParser + 'a, @@ -700,24 +732,25 @@ fn optional_distinctness() -> impl NoirParser { }) } -fn maybe_comp_time() -> impl NoirParser<()> { +fn maybe_comp_time() -> impl NoirParser { keyword(Keyword::CompTime).or_not().validate(|opt, span, emit| { if opt.is_some() { - emit(ParserError::with_reason(ParserErrorReason::ComptimeDeprecated, span)); + emit(ParserError::with_reason( + ParserErrorReason::ExperimentalFeature("comptime"), + span, + )); } + opt.is_some() }) } fn field_type() -> impl NoirParser { - maybe_comp_time() - .then_ignore(keyword(Keyword::Field)) + keyword(Keyword::Field) .map_with_span(|_, span| UnresolvedTypeData::FieldElement.with_span(span)) } fn bool_type() -> impl NoirParser { - maybe_comp_time() - .then_ignore(keyword(Keyword::Bool)) - .map_with_span(|_, span| UnresolvedTypeData::Bool.with_span(span)) + keyword(Keyword::Bool).map_with_span(|_, span| UnresolvedTypeData::Bool.with_span(span)) } fn string_type() -> impl NoirParser { @@ -744,21 +777,20 @@ fn format_string_type( } fn int_type() -> impl NoirParser { - maybe_comp_time() - .then(filter_map(|span, token: Token| match token { - Token::IntType(int_type) => Ok(int_type), - unexpected => { - Err(ParserError::expected_label(ParsingRuleLabel::IntegerType, unexpected, span)) - } - })) - .validate(|(_, token), span, emit| { - UnresolvedTypeData::from_int_token(token) - .map(|data| data.with_span(span)) - .unwrap_or_else(|err| { - emit(ParserError::with_reason(ParserErrorReason::InvalidBitSize(err.0), span)); - UnresolvedType::error(span) - }) - }) + filter_map(|span, token: Token| match token { + Token::IntType(int_type) => Ok(int_type), + unexpected => { + Err(ParserError::expected_label(ParsingRuleLabel::IntegerType, unexpected, span)) + } + }) + .validate(|token, span, emit| { + UnresolvedTypeData::from_int_token(token).map(|data| data.with_span(span)).unwrap_or_else( + |err| { + emit(ParserError::with_reason(ParserErrorReason::InvalidBitSize(err.0), span)); + UnresolvedType::error(span) + }, + ) + }) } fn named_type(type_parser: impl NoirParser) -> impl NoirParser { @@ -1236,6 +1268,7 @@ where }, lambdas::lambda(expr_parser.clone()), block(statement.clone()).map(ExpressionKind::Block), + comptime_expr(statement.clone()), quote(statement), variable(), literal(), diff --git a/compiler/noirc_frontend/src/parser/parser/function.rs b/compiler/noirc_frontend/src/parser/parser/function.rs index 18f17065038..074c902ff7b 100644 --- a/compiler/noirc_frontend/src/parser/parser/function.rs +++ b/compiler/noirc_frontend/src/parser/parser/function.rs @@ -1,8 +1,8 @@ use super::{ attributes::{attributes, validate_attributes}, - block, fresh_statement, ident, keyword, nothing, optional_distinctness, optional_visibility, - parameter_name_recovery, parameter_recovery, parenthesized, parse_type, pattern, - self_parameter, where_clause, NoirParser, + block, fresh_statement, ident, keyword, maybe_comp_time, nothing, optional_distinctness, + optional_visibility, parameter_name_recovery, parameter_recovery, parenthesized, parse_type, + pattern, self_parameter, where_clause, NoirParser, }; use crate::parser::labels::ParsingRuleLabel; use crate::parser::spanned; @@ -36,8 +36,8 @@ pub(super) fn function_definition(allow_self: bool) -> impl NoirParser impl NoirParser { /// function_modifiers: 'unconstrained'? (visibility)? /// /// returns (is_unconstrained, visibility) for whether each keyword was present -fn function_modifiers() -> impl NoirParser<(bool, ItemVisibility)> { +fn function_modifiers() -> impl NoirParser<(bool, ItemVisibility, bool)> { keyword(Keyword::Unconstrained) .or_not() .then(visibility_modifier()) - .map(|(unconstrained, visibility)| (unconstrained.is_some(), visibility)) + .then(maybe_comp_time()) + .map(|((unconstrained, visibility), comptime)| { + (unconstrained.is_some(), visibility, comptime) + }) } /// non_empty_ident_list: ident ',' non_empty_ident_list @@ -172,8 +175,8 @@ mod test { "fn f(f: pub Field, y : T, z : Field) -> u8 { x + a }", "fn func_name(x: [Field], y : [Field;2],y : pub [Field;2], z : pub [u8;5]) {}", "fn main(x: pub u8, y: pub u8) -> distinct pub [u8; 2] { [x, y] }", - "fn f(f: pub Field, y : Field, z : comptime Field) -> u8 { x + a }", - "fn f(f: pub Field, y : T, z : comptime Field) -> u8 { x + a }", + "fn f(f: pub Field, y : Field, z : Field) -> u8 { x + a }", + "fn f(f: pub Field, y : T, z : Field) -> u8 { x + a }", "fn func_name(f: Field, y : T) where T: SomeTrait {}", "fn func_name(f: Field, y : T) where T: SomeTrait + SomeTrait2 {}", "fn func_name(f: Field, y : T) where T: SomeTrait, T: SomeTrait2 {}", diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index c4f0a8d67ba..31bf2245b1f 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -780,6 +780,7 @@ mod test { HirStatement::Error => panic!("Invalid HirStatement!"), HirStatement::Break => panic!("Unexpected break"), HirStatement::Continue => panic!("Unexpected continue"), + HirStatement::Comptime(_) => panic!("Unexpected comptime"), }; let expr = interner.expression(&expr_id); diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index ee8cc990e0e..612330ad3a3 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -1,7 +1,8 @@ use std::iter::zip; use noirc_frontend::{ - ConstrainKind, ConstrainStatement, ExpressionKind, ForRange, Statement, StatementKind, + macros_api::Span, ConstrainKind, ConstrainStatement, ExpressionKind, ForRange, Statement, + StatementKind, }; use crate::{rewrite, visitor::expr::wrap_exprs}; @@ -14,92 +15,94 @@ impl super::FmtVisitor<'_> { for (Statement { kind, span }, index) in zip(stmts, 1..) { let is_last = index == len; + self.visit_stmt(kind, span, is_last); + self.last_position = span.end(); + } + } - match kind { - StatementKind::Expression(expr) => self.visit_expr( - expr, - if is_last { ExpressionType::SubExpression } else { ExpressionType::Statement }, - ), - StatementKind::Semi(expr) => { - self.visit_expr(expr, ExpressionType::Statement); - self.push_str(";"); - } - StatementKind::Let(let_stmt) => { - let let_str = - self.slice(span.start()..let_stmt.expression.span.start()).trim_end(); - - let expr_str = rewrite::sub_expr(self, self.shape(), let_stmt.expression); - - self.push_rewrite(format!("{let_str} {expr_str};"), span); - } - StatementKind::Constrain(ConstrainStatement(expr, message, kind)) => { - let mut nested_shape = self.shape(); - let shape = nested_shape; - - nested_shape.indent.block_indent(self.config); - - let message = message.map_or(String::new(), |message| { - let message = rewrite::sub_expr(self, nested_shape, message); - format!(", {message}") - }); - - let (callee, args) = match kind { - ConstrainKind::Assert | ConstrainKind::Constrain => { - let assertion = rewrite::sub_expr(self, nested_shape, expr); - let args = format!("{assertion}{message}"); - - ("assert", args) - } - ConstrainKind::AssertEq => { - if let ExpressionKind::Infix(infix) = expr.kind { - let lhs = rewrite::sub_expr(self, nested_shape, infix.lhs); - let rhs = rewrite::sub_expr(self, nested_shape, infix.rhs); + fn visit_stmt(&mut self, kind: StatementKind, span: Span, is_last: bool) { + match kind { + StatementKind::Expression(expr) => self.visit_expr( + expr, + if is_last { ExpressionType::SubExpression } else { ExpressionType::Statement }, + ), + StatementKind::Semi(expr) => { + self.visit_expr(expr, ExpressionType::Statement); + self.push_str(";"); + } + StatementKind::Let(let_stmt) => { + let let_str = self.slice(span.start()..let_stmt.expression.span.start()).trim_end(); - let args = format!("{lhs}, {rhs}{message}"); + let expr_str = rewrite::sub_expr(self, self.shape(), let_stmt.expression); - ("assert_eq", args) - } else { - unreachable!() - } + self.push_rewrite(format!("{let_str} {expr_str};"), span); + } + StatementKind::Constrain(ConstrainStatement(expr, message, kind)) => { + let mut nested_shape = self.shape(); + let shape = nested_shape; + + nested_shape.indent.block_indent(self.config); + + let message = message.map_or(String::new(), |message| { + let message = rewrite::sub_expr(self, nested_shape, message); + format!(", {message}") + }); + + let (callee, args) = match kind { + ConstrainKind::Assert | ConstrainKind::Constrain => { + let assertion = rewrite::sub_expr(self, nested_shape, expr); + let args = format!("{assertion}{message}"); + + ("assert", args) + } + ConstrainKind::AssertEq => { + if let ExpressionKind::Infix(infix) = expr.kind { + let lhs = rewrite::sub_expr(self, nested_shape, infix.lhs); + let rhs = rewrite::sub_expr(self, nested_shape, infix.rhs); + + let args = format!("{lhs}, {rhs}{message}"); + + ("assert_eq", args) + } else { + unreachable!() } - }; - - let args = wrap_exprs( - "(", - ")", - args, - nested_shape, - shape, - NewlineMode::IfContainsNewLineAndWidth, - ); - let constrain = format!("{callee}{args};"); - - self.push_rewrite(constrain, span); - } - StatementKind::For(for_stmt) => { - let identifier = self.slice(for_stmt.identifier.span()); - let range = match for_stmt.range { - ForRange::Range(start, end) => format!( - "{}..{}", - rewrite::sub_expr(self, self.shape(), start), - rewrite::sub_expr(self, self.shape(), end) - ), - ForRange::Array(array) => rewrite::sub_expr(self, self.shape(), array), - }; - let block = rewrite::sub_expr(self, self.shape(), for_stmt.block); - - let result = format!("for {identifier} in {range} {block}"); - self.push_rewrite(result, span); - } - StatementKind::Assign(_) => { - self.push_rewrite(self.slice(span).to_string(), span); - } - StatementKind::Error => unreachable!(), - StatementKind::Break => self.push_rewrite("break;".into(), span), - StatementKind::Continue => self.push_rewrite("continue;".into(), span), + } + }; + + let args = wrap_exprs( + "(", + ")", + args, + nested_shape, + shape, + NewlineMode::IfContainsNewLineAndWidth, + ); + let constrain = format!("{callee}{args};"); + + self.push_rewrite(constrain, span); } - - self.last_position = span.end(); + StatementKind::For(for_stmt) => { + let identifier = self.slice(for_stmt.identifier.span()); + let range = match for_stmt.range { + ForRange::Range(start, end) => format!( + "{}..{}", + rewrite::sub_expr(self, self.shape(), start), + rewrite::sub_expr(self, self.shape(), end) + ), + ForRange::Array(array) => rewrite::sub_expr(self, self.shape(), array), + }; + let block = rewrite::sub_expr(self, self.shape(), for_stmt.block); + + let result = format!("for {identifier} in {range} {block}"); + self.push_rewrite(result, span); + } + StatementKind::Assign(_) => { + self.push_rewrite(self.slice(span).to_string(), span); + } + StatementKind::Error => unreachable!(), + StatementKind::Break => self.push_rewrite("break;".into(), span), + StatementKind::Continue => self.push_rewrite("continue;".into(), span), + StatementKind::Comptime(statement) => self.visit_stmt(*statement, span, is_last), } } }