Skip to content

Commit

Permalink
Better support for generic catch blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
melvic-ybanez committed Nov 17, 2023
1 parent bb3253d commit 960319d
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 72 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ in [BNF](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form):
<return> ::= "return" <expression>? ";"
<import> ::= "import" <identifier>("."<identifier>)* ";"
<delete> ::= "del" <call><index> ";"
<try-catch> ::= "try" <block> ("catch" "(" (<identifier> ":")? <identifier> ")" <block>)+
<try-catch> ::= "try" <block> ("catch" "(" <identifier>? ":" <identifier>? ")" <block>)+
<expression> ::= <assignment> | <lambda>
<assignment> ::= <call> "=" <expression>
<call> ::= <primary> ("(" (<expression> | ("," <expression>)*)? ")" | "." <identifier> | <index>)*
Expand Down
21 changes: 16 additions & 5 deletions src/main/scala/com/melvic/dry/ast/Stmt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,27 @@ object Stmt {
}
}

final case class CatchBlock(instance: Option[Variable], kind: Variable, block: BlockStmt, paren: Token)
sealed trait CatchBlock

object CatchBlock {
final case class CatchType(kind: Variable, block: BlockStmt, paren: Token) extends CatchBlock
final case class CatchUntypedVar(instance: Variable, block: BlockStmt, paren: Token) extends CatchBlock
final case class CatchTypedVar(instance: Variable, kind: Variable, blockStmt: BlockStmt, paren: Token)
extends CatchBlock
final case class CatchAll(blockStmt: BlockStmt, paren: Token) extends CatchBlock

implicit val implicitShow: Show[CatchBlock] = show

def show: Show[CatchBlock] = {
case CatchBlock(None, kind, block, _) =>
show"${Lexemes.Catch} ($kind) $block"
case CatchBlock(Some(instance), kind, block, _) =>
show"${Lexemes.Catch} ($instance: $kind) $block"
def show(insideParens: String, block: BlockStmt): String =
show"${Lexemes.Catch} ($insideParens) $block"

{
case CatchUntypedVar(instance, block, _) => show(show"$instance:", block)
case CatchType(kind, block, _) => show(show": $kind", block)
case CatchTypedVar(instance, kind, block, _) => show(show"$instance: $kind", block)
case CatchAll(block, _) => show(":", block)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private[eval] trait EvalDecl extends EvalStmt {
node match {
case ClassDecl(name, methods) =>
env.define(name, Value.None)
val klass = new DClass(
val klass = DClass(
name.lexeme,
methods
.map(method =>
Expand Down
30 changes: 20 additions & 10 deletions src/main/scala/com/melvic/dry/interpreter/eval/EvalStmt.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.melvic.dry.interpreter.eval

import com.melvic.dry.Token
import com.melvic.dry.Token.TokenType
import com.melvic.dry.ast.Expr.Variable
import com.melvic.dry.ast.Stmt.CatchBlock.{CatchAll, CatchType, CatchTypedVar, CatchUntypedVar}
import com.melvic.dry.ast.Stmt.IfStmt.{IfThen, IfThenElse}
import com.melvic.dry.ast.Stmt.Loop.While
import com.melvic.dry.ast.Stmt._
Expand Down Expand Up @@ -102,23 +104,24 @@ private[eval] trait EvalStmt {
def invalidArg(got: String, paren: Token): Failure =
RuntimeError.invalidArgument(Types.Exception, got, paren.line)

lazy val dExceptionKind: Option[String] = DException.kindOf(raised.instance)
lazy val raisedKind = DException.kindOf(raised.instance)

def catchGeneric(paren: Token)(catchFromRaised: String => CatchBlock) =
raisedKind.toRight(One(invalidArg("unknown exception", paren))).flatMap { raised =>
evalCatchBlock(catchFromRaised(raised))
}

def evalCatchBlock: CatchBlock => Result[Option[Value]] = {
case CatchBlock(None, exceptionKind, block, paren) =>
case CatchType(exceptionKind, block, paren) =>
Evaluate.variable(exceptionKind).flatMap {
case DException(kind, _) if dExceptionKind.contains(kind.exceptionName) =>
case DException(kind, _) if raisedKind.contains(kind.exceptionName) =>
Evaluate.blockStmt(block).map(Some(_))
case DException(_, _) => Right(None)
case arg => invalidArg(Value.typeOf(arg), paren).fail
}
case catchBlock @ CatchBlock(_, Variable(token @ Token(_, "GenericError", _)), _, paren) =>
dExceptionKind.toRight(One(invalidArg("unknown exception", paren))).flatMap { dExceptionKind =>
evalCatchBlock(catchBlock.copy(kind = Variable(token.copy(lexeme = dExceptionKind))))
}
case CatchBlock(Some(Variable(instance)), exceptionKind, block, paren) =>
case CatchTypedVar(Variable(instance), exceptionKind, block, paren) =>
Evaluate.variable(exceptionKind).flatMap {
case DException(kind, _) if dExceptionKind.contains(kind.exceptionName) =>
case DException(kind, _) if raisedKind.contains(kind.exceptionName) =>
Evaluate
.blockStmtWith { env =>
val localEnv = Env.fromEnclosing(env)
Expand All @@ -128,7 +131,14 @@ private[eval] trait EvalStmt {
case DException(_, _) => Right(None)
case arg => invalidArg(Value.typeOf(arg), paren).fail
}
case CatchBlock(_, _, _, paren) => invalidArg("expression", paren).fail
case CatchAll(block, paren) =>
catchGeneric(paren)(raised =>
CatchType(Variable(Token(TokenType.Identifier, raised, paren.line)), block, paren)
)
case CatchUntypedVar(instance, block, paren) =>
catchGeneric(paren)(raised =>
CatchTypedVar(instance, Variable(Token(TokenType.Identifier, raised, paren.line)), block, paren)
)
}

def findCatchBlock(catchBlocks: Nel[CatchBlock]): Out =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.melvic.dry.interpreter.values

object Attributes {
val Class = "__class__"
val Name = "__name__"
}
6 changes: 3 additions & 3 deletions src/main/scala/com/melvic/dry/interpreter/values/DClass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.melvic.dry.result.Result.implicits.ToResult
import scala.collection.mutable
import scala.util.chaining.scalaUtilChainingOps

class DClass(val name: String, val methods: Methods, val enclosing: Env)
class DClass private[values] (val name: String, val methods: Methods, val enclosing: Env)
extends Callable
with Metaclass
with DObject {
Expand All @@ -24,14 +24,14 @@ class DClass(val name: String, val methods: Methods, val enclosing: Env)

override def klass = Metaclass

override val fields = mutable.Map("__name__" -> Value.Str(name))
override val fields = mutable.Map(Attributes.Name -> Value.Str(name))
}

object DClass {
type Methods = Map[String, DFunction]

def apply(name: String, methods: Methods, enclosing: Env): DClass =
new DClass(name, methods, enclosing)
new DClass(name, methods, enclosing).addField(Attributes.Class, Metaclass)

def unapply(klass: DClass): Option[(String, Methods, Env)] =
Some(klass.name, klass.methods, klass.enclosing)
Expand Down
22 changes: 11 additions & 11 deletions src/main/scala/com/melvic/dry/interpreter/values/DException.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.melvic.dry.interpreter.values

import com.melvic.dry.Token
import com.melvic.dry.interpreter.Env
import com.melvic.dry.interpreter.values.DException.Fields
import com.melvic.dry.interpreter.values.DException.Attributes
import com.melvic.dry.interpreter.values.Value.{Types, typeOf}
import com.melvic.dry.result.Failure.RuntimeError
import com.melvic.dry.result.Failure.RuntimeError.Kind
Expand All @@ -15,9 +15,9 @@ class DException(val kind: Kind, val env: Env) extends DClass(kind.exceptionName
case args @ ((message: Value.Str) :: _) =>
super.call(token)(args).flatMap { case instance: DInstance =>
instance
.addField(Fields.Kind, Value.Str(kind.exceptionName))
.addField(Fields.Message, message)
.addField(Fields.Line, Value.Num(token.line))
.addField(Attributes.Kind, Value.Str(kind.exceptionName))
.addField(Attributes.Message, message)
.addField(Attributes.Line, Value.Num(token.line))
.ok
}
case arg :: _ => RuntimeError.invalidArgument(s"${Types.String}", typeOf(arg), token.line).fail
Expand All @@ -39,10 +39,10 @@ object DException {
new NoArgDException(kind, env)
}

object Fields {
val Message: String = "message"
val Kind: String = "kind"
val Line: String = "line"
object Attributes {
val Message: String = "__message__"
val Kind: String = "__kind__"
val Line: String = "__line__"
}

def apply(kind: Kind, env: Env): DException =
Expand All @@ -52,13 +52,13 @@ object DException {
Some(exception.kind, exception.env)

def kindOf(instance: DInstance): Option[String] =
asString(instance.getField(Fields.Kind))
asString(instance.getField(Attributes.Kind))

def messageOf(instance: DInstance): Option[String] =
asString(instance.getField(Fields.Message))
asString(instance.getField(Attributes.Message))

def lineOf(instance: DInstance): Option[Int] =
instance.getField(Fields.Line).flatMap(_.toNum).map(_.value.toInt)
instance.getField(Attributes.Line).flatMap(_.toNum).map(_.value.toInt)

private def asString(value: Option[Value]): Option[String] =
value.flatMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ import scala.collection.mutable
final case class DInstance private (klass: Metaclass, fields: mutable.Map[String, Value]) extends DObject

object DInstance {
def fromClass(klass: DClass): DInstance = DInstance(klass, mutable.Map.empty)
def fromClass(klass: DClass): DInstance =
DInstance(klass, mutable.Map.empty).addField(Attributes.Class, klass)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.melvic.dry.interpreter.values

trait Metaclass {
//noinspection SpellCheckingInspection
trait Metaclass extends Value {
def methods: Map[String, Callable.Function]

def findMethod(name: String): Option[Callable.Function] =
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/com/melvic/dry/parsers/Step.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ final case class Step[+A](value: A, next: Parser) {
Step(f(value), next)

def toParseResult: ParseResult[A] =
ParseResult.succeed(value, next)
ParseResult.fromStep(this)
}

object Step {
Expand Down
42 changes: 21 additions & 21 deletions src/main/scala/com/melvic/dry/parsers/StmtParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.melvic.dry.Token
import com.melvic.dry.Token.TokenType
import com.melvic.dry.ast.Decl.StmtDecl
import com.melvic.dry.ast.Expr.{IndexGet, Literal, Variable}
import com.melvic.dry.ast.Stmt.CatchBlock.{CatchAll, CatchType, CatchTypedVar, CatchUntypedVar}
import com.melvic.dry.ast.Stmt.IfStmt._
import com.melvic.dry.ast.Stmt.Loop.While
import com.melvic.dry.ast.Stmt._
Expand All @@ -12,6 +13,8 @@ import com.melvic.dry.aux.Nel
import com.melvic.dry.aux.Nel.{Many, One}
import com.melvic.dry.lexer.Lexemes

import scala.util.chaining.scalaUtilChainingOps

//noinspection ScalaWeakerAccess
private[parsers] trait StmtParser { _: Parser with DeclParser =>

Expand Down Expand Up @@ -177,32 +180,29 @@ private[parsers] trait StmtParser { _: Parser with DeclParser =>
.flatMapParser(_.consumeAfter(TokenType.Semicolon, ";", "]"))

/**
* {{{<try-catch> ::= "try" <block> ("catch" "(" (<identifier> ":")? <identifier> ")" <block>)+}}}
* {{{<try-catch> ::= "try" <block> ("catch" "(" <identifier>? ":" <identifier>? ")" <block>)+}}}
*/
def tryStatement: ParseResult[Stmt] = {
def consumeIdentifier(parser: Parser, after: String): ParseResult[Option[Variable]] =
parser
.matchAny(TokenType.Identifier)
.map(p => Step(Some(Variable(p.previousToken)), p).toParseResult)
.getOrElse(ParseResult.succeed(None, parser))

def catchStmt(parser: Parser): ParseResult[CatchBlock] =
for {
leftParen <- parser.consumeAfter(TokenType.LeftParen, "(", Lexemes.Catch)
variable <- leftParen
.consumeAfter(TokenType.Identifier, "identifier", "(")
.map(p => Step(Variable(p.next.previousToken), p))
exception <- variable
.matchAny(TokenType.Colon)
.fold[ParseResult[(Option[Variable], Variable)]](
// the parsed variable becomes the exception type
ParseResult.fromStep(variable.map(v => (None, v)))
) {
// the parsed variable becomes the exception instance, and we parse
// another variable for the type
_.consumeAfter(TokenType.Identifier, "identifier", "(")
.map(p => Step((Some(variable.value), Variable(p.next.previousToken)), p))
}
rightParen <- exception.consumeAfter(TokenType.RightParen, ")", "identifier")
leftParen <- parser.consumeAfter(TokenType.LeftParen, "(", Lexemes.Catch)
instance <- consumeIdentifier(leftParen, "(")
colon <- instance.consumeAfter(TokenType.Colon, ":", "identifier")
kind <- consumeIdentifier(colon, ":")
rightParen <- kind.consumeAfter(TokenType.RightParen, ")", "identifier")
block <- rightParen.blockAfter(")")
} yield Step(
CatchBlock(exception.value._1, exception.value._2, block.value, leftParen.value),
block.next
)
} yield ((instance.value, kind.value) match {
case (Some(instance), Some(kind)) => CatchTypedVar(instance, kind, block.value, leftParen.value)
case (Some(instance), None) => CatchUntypedVar(instance, block.value, leftParen.value)
case (_, Some(kind)) => CatchType(kind, block.value, leftParen.value)
case _ => CatchAll(block.value, leftParen.value)
}).pipe(Step(_, block.next))

for {
tryBlock <- blockAfter(Lexemes.Try)
Expand Down
12 changes: 9 additions & 3 deletions src/main/scala/com/melvic/dry/resolver/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.melvic.dry.Token
import com.melvic.dry.ast.Decl.Let.{LetDecl, LetInit}
import com.melvic.dry.ast.Decl.{ClassDecl, Def, StmtDecl}
import com.melvic.dry.ast.Expr.{List => _, _}
import com.melvic.dry.ast.Stmt.CatchBlock.{CatchType, CatchTypedVar, CatchUntypedVar}
import com.melvic.dry.ast.Stmt.IfStmt.{IfThen, IfThenElse}
import com.melvic.dry.ast.Stmt.Loop.While
import com.melvic.dry.ast.Stmt._
Expand Down Expand Up @@ -52,10 +53,15 @@ object Resolve {
}

private def tryCatch: TryCatch => Resolve = { case TryCatch(tryBlock, catchBlocks) =>
def catchVar(instance: Token, declarations: List[Decl]) =
Resolve.blockStmt(BlockStmt(LetDecl(instance) :: declarations))

Resolve.blockStmt(tryBlock) >=> sequence(catchBlocks.toList) {
case CatchBlock(None, _, block, _) => Resolve.blockStmt(block)
case CatchBlock(Some(Variable(instance)), _, BlockStmt(declarations), _) =>
Resolve.blockStmt(BlockStmt(LetInit(instance, Literal.None) :: declarations))
case CatchType(_, block, _) => Resolve.blockStmt(block)
case CatchUntypedVar(Variable(instance), BlockStmt(declarations), _) =>
catchVar(instance, declarations)
case CatchTypedVar(Variable(instance), _, BlockStmt(declarations), _) =>
catchVar(instance, declarations)
case _ => _.ok
}
}
Expand Down
12 changes: 10 additions & 2 deletions stdlib/assert.dry
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ def assert_equals(description, expected, got) {
def assert_error(description, expected_error, code) {
try {
code();
} catch (error: GenericError) {
assert_equals(description, expected_error.kind, error.kind);
} catch (error:) {
assert_equals(description, expected_error, error);
}
}

def assert_error_type(description, expected_error_type, code) {
try {
code();
} catch (error:) {
assert_equals(description, error.__kind__, error.__kind__);
}
}
24 changes: 12 additions & 12 deletions tests/test_try_catch.dry
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ class TestTryCatch {
try {
19 / 0;
x = 20; // shouldn't be executed
} catch (DivisionByZeroError) {}
} catch (: DivisionByZeroError) {}

assert_equals("Handle an exception with a single empty catch-block", 10, x);

try {
19 / 0;
x = 20;
} catch (DivisionByZeroError) {
} catch (: DivisionByZeroError) {
x = 30;
}

Expand All @@ -26,32 +26,32 @@ class TestTryCatch {

try {
xs[3];
} catch (DivisionByZeroError) {
} catch (: DivisionByZeroError) {
y = 20; // skipped
} catch (IndexOutOfBoundsError) {
} catch (: IndexOutOfBoundsError) {
x = 50;
}

assert_equals("Skip catch-blocks that don't capture the exception", (50, 10), (x, y));

try {
raise(UndefinedVariableError("x"));
} catch (UndefinedVariableError) {
} catch (: UndefinedVariableError) {
x = 20;
} catch (IncorrectArityError) {
} catch (: IncorrectArityError) {
x = 70; // ignored
}

assert_equals("Stop at the first catch-block that captures the exception", 20, x);
}

def test_no_match() {
let exception = UndefinedVariableError("y");
assert.assert_error("Throw the error if no catch-blocks capture it", exception, lambda() {
try {
raise(exception);
} catch (IncorrectArityError) {} catch (DivisionByZeroError) {}
});
assert.assert_error_type("Throw the error if no catch-blocks capture it", UndefinedVariableError,
lambda() {
try {
raise(UndefinedVariableError("y"));
} catch (: IncorrectArityError) {} catch (: DivisionByZeroError) {}
});
}
}

Expand Down

0 comments on commit 960319d

Please sign in to comment.