From a675db715f74b664ade912fb1e32acf6db0abeaf Mon Sep 17 00:00:00 2001 From: Melvic Ybanez Date: Sat, 17 Sep 2022 23:23:20 +0800 Subject: [PATCH] Add support to parse function declarations --- .scalafmt.conf | 4 ++ src/main/scala/com/melvic/dry/ast/Decl.scala | 2 + .../melvic/dry/interpreter/Interpreter.scala | 4 +- .../com/melvic/dry/parsers/DeclParser.scala | 39 +++++++++++++++---- .../com/melvic/dry/parsers/StmtParser.scala | 14 +++---- 5 files changed, 46 insertions(+), 17 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index afe8c91..caffbcf 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -9,6 +9,10 @@ align.tokens = [ { code = "<-" owners = [{ regex = "Enumerator.Generator" }] + }, + { + code = "->", + owners = [{ regex = "Term.Apply" }] } ] diff --git a/src/main/scala/com/melvic/dry/ast/Decl.scala b/src/main/scala/com/melvic/dry/ast/Decl.scala index 082e65c..10d6bd8 100644 --- a/src/main/scala/com/melvic/dry/ast/Decl.scala +++ b/src/main/scala/com/melvic/dry/ast/Decl.scala @@ -12,6 +12,8 @@ object Decl { final case class StmtDecl(stmt: Stmt) extends Decl + final case class Def(name: Token, params: List[Token], body: List[Decl]) extends Decl + object StmtDecl { def fromExpr(expr: Expr): StmtDecl = StmtDecl(ExprStmt(expr)) diff --git a/src/main/scala/com/melvic/dry/interpreter/Interpreter.scala b/src/main/scala/com/melvic/dry/interpreter/Interpreter.scala index d23a58d..b3427a4 100644 --- a/src/main/scala/com/melvic/dry/interpreter/Interpreter.scala +++ b/src/main/scala/com/melvic/dry/interpreter/Interpreter.scala @@ -20,10 +20,10 @@ object Interpreter { })(env) } - recurse(declarations, LocalEnv(env.table, globals), Value.Unit) + recurse(declarations, LocalEnv(env.table, natives), Value.Unit) } - def globals: Env = Env.empty + def natives: Env = Env.empty .define("print", Callable(1, { case arg :: _ => print(Value.show(arg)).unit })) // we don't support user-defined functions yet, so we are building a dedicated function for println for now. // Once, user-defined functions are supported, we can just replace this with a call to `print`, applied diff --git a/src/main/scala/com/melvic/dry/parsers/DeclParser.scala b/src/main/scala/com/melvic/dry/parsers/DeclParser.scala index aa07cdb..317874a 100644 --- a/src/main/scala/com/melvic/dry/parsers/DeclParser.scala +++ b/src/main/scala/com/melvic/dry/parsers/DeclParser.scala @@ -3,19 +3,21 @@ package com.melvic.dry.parsers import com.melvic.dry.Token import com.melvic.dry.Token.TokenType import com.melvic.dry.ast.Decl -import com.melvic.dry.ast.Decl.{Let, LetDecl, LetInit, StmtDecl} +import com.melvic.dry.ast.Decl._ +import com.melvic.dry.parsers.Step._ import scala.util.chaining.scalaUtilChainingOps private[parsers] trait DeclParser extends StmtParser { _: Parser => def declaration: ParseResult[Decl] = - (matchAny(TokenType.Let) match { - case None => statement.mapValue(StmtDecl(_)) - case Some(parser) => parser.letDecl - }).pipe { - case result @ ParseResult(Left(_), _) => result.mapParser(_.synchronize) - case result => result - } + matchAny(TokenType.Def) + .map(_.defDecl) + .orElse(matchAny(TokenType.Let).map(_.letDecl)) + .getOrElse(statement.mapValue(StmtDecl(_))) + .pipe { + case result @ ParseResult(Left(_), _) => result.mapParser(_.synchronize) + case result => result + } def letDecl: ParseResult[Let] = { def consumeSemicolon(parser: Parser): ParseResult[Token] = @@ -31,4 +33,25 @@ private[parsers] trait DeclParser extends StmtParser { _: Parser => } } } + + def defDecl: ParseResult[Def] = + for { + name <- consume(TokenType.Identifier, "identifier", "def keyword") + leftParen <- name.consume(TokenType.LeftParen, "(", "function name") + params <- + if (!leftParen.check(TokenType.RightParen)) { + def recurse(params: List[Token], parser: Parser): ParseResult[List[Token]] = + parser + .matchAny(TokenType.Comma) + .fold(ParseResult.succeed(params.reverse, parser))( + _.consume(TokenType.Identifier, "parameter name", ",").mapValue(_ :: params) + ) + + leftParen.consume(TokenType.Identifier, "parameter name", "(").flatMap { case Step(param, parser) => + recurse(param :: Nil, parser).flatMapParser(_.consume(TokenType.RightParen, ")", "parameters")) + } + } else leftParen.consume(TokenType.RightParen, ")", "(").mapValue(_ :: Nil) + leftBrace <- params.consume(TokenType.LeftBrace, "{", "function signature") + body <- leftBrace.block + } yield Step(Def(name.value, params.value, body.value.declarations), body.next) } diff --git a/src/main/scala/com/melvic/dry/parsers/StmtParser.scala b/src/main/scala/com/melvic/dry/parsers/StmtParser.scala index 83ae75e..72cac5b 100644 --- a/src/main/scala/com/melvic/dry/parsers/StmtParser.scala +++ b/src/main/scala/com/melvic/dry/parsers/StmtParser.scala @@ -13,9 +13,9 @@ private[parsers] trait StmtParser { _: Parser with DeclParser => select( expressionStatement, TokenType.LeftBrace -> { _.block }, - TokenType.If -> { _.ifStatement }, - TokenType.While -> { _.whileStatement }, - TokenType.For -> { _.forStatement } + TokenType.If -> { _.ifStatement }, + TokenType.While -> { _.whileStatement }, + TokenType.For -> { _.forStatement } ) def expressionStatement: ParseResult[Stmt] = @@ -41,10 +41,10 @@ private[parsers] trait StmtParser { _: Parser with DeclParser => def ifStatement: ParseResult[Stmt] = for { - step <- consume(TokenType.LeftParen, "(", "if") - cond <- step.expression - body <- cond.consume(TokenType.RightParen, ")", "if condition") - thenBranch <- body.statement + leftParen <- consume(TokenType.LeftParen, "(", "if") + cond <- leftParen.expression + rightParen <- cond.consume(TokenType.RightParen, ")", "if condition") + thenBranch <- rightParen.statement ifStmt <- thenBranch .matchAny(TokenType.Else) .fold[ParseResult[Stmt]](