diff --git a/ast/ast.go b/ast/ast.go index bf44dc9..a97abc3 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -41,8 +41,16 @@ type LetStatement struct { Value Expression } +type ReturnStatement struct { + Token token.Token + ReturnValue Expression +} + func (i *Identifier) expressionNode() {} func (i *Identifier) TokenLiteral() string { return i.Token.Literal } func (ls *LetStatement) statementNode() {} func (ls *LetStatement) TokenLiteral() string { return ls.Token.Literal } + +func (rs *ReturnStatement) statementNode() {} +func (rs *ReturnStatement) TokenLiteral() string { return rs.Token.Literal } diff --git a/parser/parser.go b/parser/parser.go index e39d660..4589c59 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -81,6 +81,8 @@ func (p *Parser) parseStatement() ast.Statement { switch p.curToken.Type { case token.LET: return p.parseLetStatement() + case token.RETURN: + return p.parseReturnStatment() default: return nil } @@ -105,3 +107,12 @@ func (p *Parser) parseLetStatement() *ast.LetStatement { } return stmt } + +func (p *Parser) parseReturnStatment() *ast.ReturnStatement { + stmt := &ast.ReturnStatement{Token: p.curToken} + p.nextToken() + for !p.curTokenIs(token.SEMICOLON) { + p.nextToken() + } + return stmt +} diff --git a/parser/parser_test.go b/parser/parser_test.go index e027e8a..10bda8f 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -56,3 +56,30 @@ func testLetStatement(t *testing.T, s ast.Statement, name string) bool { } return true } + +func TestReturnStatements(t *testing.T) { + input := ` + return 5; + return 10; + return 993322; + ` + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + // checkParserErrors(t, p) + if len(program.Statements) != 3 { + t.Fatalf("program.Statements does not contain 3 statements. got=%d", + len(program.Statements)) + } + for _, stmt := range program.Statements { + returnStmt, ok := stmt.(*ast.ReturnStatement) + if !ok { + t.Errorf("stmt not *ast.returnStatement. got=%T", stmt) + continue + } + if returnStmt.TokenLiteral() != "return" { + t.Errorf("returnStmt.TokenLiteral not 'return', got %q", + returnStmt.TokenLiteral()) + } + } +}