From 91d7e2383a658f39f1abf12ea3f8a9005dc514ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Silveira?= Date: Sun, 29 Sep 2024 21:29:45 +0100 Subject: [PATCH] feat: add hash-consed types --- include/ADT/NonOwningList.h | 1 + include/AST/AST.h | 30 +++++++-- include/Alloc/Arena.h | 31 +++++++--- include/Analysis/CFA.h | 8 +++ include/Analysis/Resolver.h | 8 +++ include/Analysis/TypeChecker.h | 19 +++++- include/Codegen/Codegen.h | 2 +- include/Lex/Lexer.h | 7 +++ include/Lex/Token.h | 3 + include/Parse/Parser.h | 10 ++- include/Support/Reporting.h | 34 +++++----- include/Support/SourceFile.h | 6 ++ include/Support/Utilities.h | 27 ++++++++ include/Typing/Type.h | 4 ++ include/Typing/TypeContext.h | 109 +++++++++++++++++++++++++-------- samples/valid/1.lang | 7 ++- src/AST/ASTPrinter.cpp | 7 ++- src/Alloc/Arena.cpp | 13 +--- src/Analysis/CFA.cpp | 10 ++- src/Analysis/Resolver.cpp | 4 +- src/Analysis/TypeChecker.cpp | 52 +++++++++------- src/Codegen/Codegen.cpp | 25 ++++---- src/Parse/Parser.cpp | 44 ++++++++----- src/Support/Reporting.cpp | 2 +- src/main.cpp | 81 +++++++++++++++++------- 25 files changed, 394 insertions(+), 150 deletions(-) create mode 100644 include/Support/Utilities.h diff --git a/include/ADT/NonOwningList.h b/include/ADT/NonOwningList.h index 09bb58e..58f5e18 100644 --- a/include/ADT/NonOwningList.h +++ b/include/ADT/NonOwningList.h @@ -9,6 +9,7 @@ template class NonOwningList { struct Node { T data; Node *next; + explicit Node(const T &data) noexcept : data(data), next(nullptr) {} }; diff --git a/include/AST/AST.h b/include/AST/AST.h index 3b090a9..0a6d799 100644 --- a/include/AST/AST.h +++ b/include/AST/AST.h @@ -54,6 +54,7 @@ struct ExprAST { ExprASTKind kind; std::string_view span; Type *type; + ExprAST(ExprASTKind kind, std::string_view span) : kind(kind), span(span), type(nullptr) {} }; @@ -74,6 +75,7 @@ enum class StmtASTKind { struct StmtAST { StmtASTKind kind; std::string_view span; + StmtAST(StmtASTKind kind, std::string_view span) 
: kind(kind), span(span) {} }; @@ -86,6 +88,7 @@ enum class DeclASTKind { struct DeclAST { DeclASTKind kind; std::string_view ident; + DeclAST(DeclASTKind kind, std::string_view ident) : kind(kind), ident(ident) {} }; @@ -95,6 +98,7 @@ struct DeclAST { struct ModuleAST { std::string_view ident; NonOwningList decls; + ModuleAST(std::string_view ident, NonOwningList decls) : ident(ident), decls(decls) {} }; @@ -109,6 +113,7 @@ struct NumberExprAST : public ExprAST { struct UnaryExprAST : public ExprAST { UnOpKind op; ExprAST *expr; + UnaryExprAST(std::string_view span, UnOpKind op, ExprAST *expr) : ExprAST(ExprASTKind::Unary, span), op(op), expr(expr) {} }; @@ -117,6 +122,7 @@ struct BinaryExprAST : public ExprAST { BinOpKind op; ExprAST *lhs; ExprAST *rhs; + BinaryExprAST(std::string_view span, BinOpKind op, ExprAST *lhs, ExprAST *rhs) : ExprAST(ExprASTKind::Binary, span), op(op), lhs(lhs), rhs(rhs) {} @@ -124,15 +130,17 @@ struct BinaryExprAST : public ExprAST { struct CallExprAST : public ExprAST { ExprAST *callee; - ExprAST *arg; // TODO: support multiple arguments - // NOLINTNEXTLINE - CallExprAST(std::string_view span, ExprAST *callee, ExprAST *arg) - : ExprAST(ExprASTKind::Call, span), callee(callee), arg(arg) {} + NonOwningList args; + + CallExprAST(std::string_view span, ExprAST *callee, + NonOwningList args) + : ExprAST(ExprASTKind::Call, span), callee(callee), args(args) {} }; struct IndexExprAST : public ExprAST { ExprAST *base; ExprAST *index; + // NOLINTNEXTLINE IndexExprAST(std::string_view span, ExprAST *base, ExprAST *index) : ExprAST(ExprASTKind::Index, span), base(base), index(index) {} @@ -140,6 +148,7 @@ struct IndexExprAST : public ExprAST { struct GroupedExprAST : public ExprAST { ExprAST *expr; + GroupedExprAST(std::string_view span, ExprAST *expr) : ExprAST(ExprASTKind::Grouped, span), expr(expr) {} }; @@ -148,18 +157,21 @@ struct GroupedExprAST : public ExprAST { struct ExprStmtAST : public StmtAST { ExprAST *expr; + 
ExprStmtAST(std::string_view span, ExprAST *expr) : StmtAST(StmtASTKind::Expr, span), expr(expr) {} }; struct BreakStmtAST : public StmtAST { StmtAST *target; + explicit BreakStmtAST(std::string_view span) : StmtAST(StmtASTKind::Break, span), target(nullptr) {} }; struct ReturnStmtAST : public StmtAST { ExprAST *expr; + ReturnStmtAST(std::string_view span, ExprAST *expr) : StmtAST(StmtASTKind::Return, span), expr(expr) {} }; @@ -168,6 +180,7 @@ struct LocalStmtAST : public StmtAST { bool isConst; Type *type; ExprAST *init; + LocalStmtAST(bool isConst, std::string_view ident, Type *type, ExprAST *init) : StmtAST(StmtASTKind::Local, ident), isConst(isConst), type(type), @@ -177,12 +190,14 @@ struct LocalStmtAST : public StmtAST { struct AssignStmtAST : public StmtAST { ExprAST *lhs; ExprAST *rhs; + AssignStmtAST(std::string_view span, ExprAST *lhs, ExprAST *rhs) : StmtAST(StmtASTKind::Assign, span), lhs(lhs), rhs(rhs) {} }; struct BlockStmtAST : public StmtAST { NonOwningList stmts; + BlockStmtAST(std::string_view span, NonOwningList stmts) : StmtAST(StmtASTKind::Block, span), stmts(stmts) {} }; @@ -191,6 +206,7 @@ struct IfStmtAST : public StmtAST { ExprAST *cond; BlockStmtAST *thenStmt; StmtAST *elseStmt; + IfStmtAST(std::string_view span, ExprAST *cond, BlockStmtAST *thenStmt, StmtAST *elseStmt) : StmtAST(StmtASTKind::If, span), cond(cond), thenStmt(thenStmt), @@ -200,6 +216,7 @@ struct IfStmtAST : public StmtAST { struct WhileStmtAST : public StmtAST { ExprAST *cond; BlockStmtAST *body; + WhileStmtAST(std::string_view span, ExprAST *cond, BlockStmtAST *body) : StmtAST(StmtASTKind::While, span), cond(cond), body(body) {} }; @@ -210,11 +227,13 @@ struct FunctionDeclAST : public DeclAST { NonOwningList params; Type *retType; BlockStmtAST *body; + Type *type; + FunctionDeclAST(std::string_view ident, NonOwningList params, Type *retType, BlockStmtAST *body) : DeclAST(DeclASTKind::Function, ident), params(params), - retType(retType), body(body) {} + retType(retType), 
body(body), type(nullptr) {} }; /// === Identifier Expressions === @@ -224,6 +243,7 @@ using IdentifierDecl = struct IdentifierExprAST : public ExprAST { IdentifierDecl decl; + explicit IdentifierExprAST(std::string_view span) : ExprAST(ExprASTKind::Identifier, span) {} }; diff --git a/include/Alloc/Arena.h b/include/Alloc/Arena.h index 5de4dcc..2542af4 100644 --- a/include/Alloc/Arena.h +++ b/include/Alloc/Arena.h @@ -1,6 +1,7 @@ #ifndef LANG_ARENA_H #define LANG_ARENA_H +#include #include #include @@ -21,7 +22,7 @@ constexpr std::size_t gigaBytes(const std::size_t bytes) { class Arena { public: explicit Arena(std::size_t bytes) - : allocations(0), defaultSize(bytes), allocSize(0) {} + : defaultSize(bytes), numAllocations(0), allocSize(0) {} Arena(const Arena &) = delete; Arena &operator=(const Arena &) = delete; @@ -29,7 +30,9 @@ class Arena { Arena(Arena &&) = default; Arena &operator=(Arena &&) = default; - [[nodiscard]] std::size_t totalAllocations() const { return allocations; } + [[nodiscard]] std::size_t totalAllocations() const { + return numAllocations; + } [[nodiscard]] std::size_t totalAllocated() const { std::size_t total = allocSize; @@ -50,23 +53,35 @@ class Arena { template T *alloc(Args &&...args) { static_assert(std::is_trivially_destructible_v, "T must be trivially destructible"); - ++allocations; + ++numAllocations; return new (allocInternal(sizeof(T))) T(std::forward(args)...); } + template void dealloc(T *ptr) { + static_assert(std::is_trivially_destructible_v, + "T must be trivially destructible"); + std::byte *start = block.data.get() + allocSize - sizeof(T); + assert(reinterpret_cast(start) == ptr); + allocSize -= sizeof(T); + } + private: struct Block { - std::unique_ptr data; std::size_t size; - Block() : data(nullptr), size(0) {} + std::unique_ptr data; + + Block() : size(0), data(nullptr) {} + Block(std::size_t size) - : data(std::make_unique(size)), size(size) {} + : size(size), data(std::make_unique(size)) {} }; - std::size_t 
allocations; std::size_t defaultSize; - std::size_t allocSize; + std::size_t numAllocations; + Block block; + std::size_t allocSize; + std::list used; std::list avail; diff --git a/include/Analysis/CFA.h b/include/Analysis/CFA.h index ec9af8b..8c9f9ea 100644 --- a/include/Analysis/CFA.h +++ b/include/Analysis/CFA.h @@ -10,6 +10,7 @@ namespace lang { enum class CFAErrorKind { + EarlyBreakStmt, EarlyReturnStmt, InvalidBreakStmt, }; @@ -17,14 +18,21 @@ enum class CFAErrorKind { struct CFAError { CFAErrorKind kind; std::string_view span; + CFAError(CFAErrorKind kind, std::string_view span) : kind(kind), span(span) {} + TextError toTextError() const; + JSONError toJSONError() const; }; struct CFAResult { std::vector errors; + + CFAResult(std::vector errors) : errors(std::move(errors)) {} + + [[nodiscard]] bool hasErrors() const { return !errors.empty(); } }; class CFA : public MutableASTVisitor { diff --git a/include/Analysis/Resolver.h b/include/Analysis/Resolver.h index 6cac2db..682f7c6 100644 --- a/include/Analysis/Resolver.h +++ b/include/Analysis/Resolver.h @@ -16,14 +16,22 @@ enum class ResolveErrorKind { struct ResolveError { ResolveErrorKind kind; std::string_view span; + ResolveError(ResolveErrorKind kind, std::string_view span) : kind(kind), span(span) {} + TextError toTextError() const; + JSONError toJSONError() const; }; struct ResolveResult { std::vector errors; + + ResolveResult(std::vector errors) + : errors(std::move(errors)) {} + + [[nodiscard]] bool hasErrors() const { return !errors.empty(); } }; class Resolver : public MutableASTVisitor { diff --git a/include/Analysis/TypeChecker.h b/include/Analysis/TypeChecker.h index dbd1905..3fcb061 100644 --- a/include/Analysis/TypeChecker.h +++ b/include/Analysis/TypeChecker.h @@ -13,7 +13,7 @@ namespace lang { enum class TypeCheckerErrorKind { - InvalidReturn, + InvalidReturnStmt, InvalidAssignment, InvalidBinaryOperation, }; @@ -21,24 +21,37 @@ enum class TypeCheckerErrorKind { struct TypeCheckerError { 
TypeCheckerErrorKind kind; std::string_view span; + + TypeCheckerError(TypeCheckerErrorKind kind, std::string_view span) + : kind(kind), span(span) {} + TextError toTextError() const; + JSONError toJSONError() const; }; struct TypeCheckerResult { std::vector errors; + + TypeCheckerResult(std::vector errors) + : errors(std::move(errors)) {} + + [[nodiscard]] bool hasErrors() const { return !errors.empty(); } }; class TypeChecker : public MutableASTVisitor { friend class ASTVisitor; public: - TypeChecker(TypeContext &typeCtx) : typeCtx(typeCtx), currentFunction(nullptr) {} + TypeChecker(TypeContext &typeCtx) + : typeCtx(&typeCtx), arena(typeCtx.getArena()), + currentFunction(nullptr) {} TypeCheckerResult analyzeModuleAST(ModuleAST &module); private: - TypeContext &typeCtx; + TypeContext *typeCtx; + Arena *arena; FunctionDeclAST *currentFunction; std::vector errors; diff --git a/include/Codegen/Codegen.h b/include/Codegen/Codegen.h index 4554fec..60464fc 100644 --- a/include/Codegen/Codegen.h +++ b/include/Codegen/Codegen.h @@ -17,7 +17,7 @@ class Codegen : public ConstASTVisitor { : deepCodegen(false), context(std::make_unique()), builder(std::make_unique>(*context)) {} - llvm::Module *generate(const ModuleAST &module); + llvm::Module *generateFromModuleAST(const ModuleAST &module); private: bool deepCodegen = false; diff --git a/include/Lex/Lexer.h b/include/Lex/Lexer.h index 6ac3b1b..d3a7a6e 100644 --- a/include/Lex/Lexer.h +++ b/include/Lex/Lexer.h @@ -16,15 +16,22 @@ enum class LexErrorKind { struct LexError { LexErrorKind kind; std::string_view span; + LexError(LexErrorKind kind, std::string_view span) : kind(kind), span(span) {} + TextError toTextError() const; + JSONError toJSONError() const; }; struct LexResult { std::vector tokens; std::vector errors; + + LexResult() = default; + + [[nodiscard]] bool hasErrors() const { return !errors.empty(); } }; class Lexer { diff --git a/include/Lex/Token.h b/include/Lex/Token.h index 74e58ac..0349e7a 100644 --- 
a/include/Lex/Token.h +++ b/include/Lex/Token.h @@ -56,6 +56,9 @@ std::string tokenKindToString(TokenKind kind); struct Token { TokenKind kind; std::string_view span; + + Token(TokenKind kind, std::string_view span) : kind(kind), span(span) {} + std::string toString() const; }; diff --git a/include/Parse/Parser.h b/include/Parse/Parser.h index 6d5dcda..b5cb665 100644 --- a/include/Parse/Parser.h +++ b/include/Parse/Parser.h @@ -19,22 +19,30 @@ enum class ParseErrorKind { UnexpectedEOF, UnexpectedToken, ExpectedType, - ExpectedPrimaryExpression, + ExpectedExpression, }; struct ParseError { ParseErrorKind kind; std::string_view span; TokenKind expected; + ParseError(ParseErrorKind kind, std::string_view span, TokenKind expected) : kind(kind), span(span), expected(expected) {} + TextError toTextError() const; + JSONError toJSONError() const; }; struct ParseResult { ModuleAST *module; std::vector errors; + + ParseResult(ModuleAST *module, std::vector errors) + : module(module), errors(std::move(errors)) {} + + [[nodiscard]] bool hasErrors() const { return !errors.empty(); } }; class Parser { diff --git a/include/Support/Reporting.h b/include/Support/Reporting.h index 91d2602..9c708f8 100644 --- a/include/Support/Reporting.h +++ b/include/Support/Reporting.h @@ -2,41 +2,43 @@ #define LANG_REPORTING_H #include "SourceFile.h" +#include "Utilities.h" #include "llvm/Support/raw_ostream.h" namespace lang { -// TODO: Move this to a different file -constexpr unsigned getNumDigits(unsigned n) { - unsigned digits = 0; - while (n) { - n /= 10; - ++digits; - } - return digits; -} - struct TextError { std::string_view span; std::string_view title; std::string label; -}; -void reportTextError(llvm::raw_ostream &os, const SourceFile &file, - const TextError &error, unsigned lineNoWidthHint = 0); + // NOLINTNEXTLINE + TextError(std::string_view span, std::string_view title, + std::string_view label) + : span(span), title(title), label(label) {} +}; struct JSONError { std::string_view 
span; std::string_view title; + + // NOLINTNEXTLINE + JSONError(std::string_view span, std::string_view title) + : span(span), title(title) {} }; +void reportTextError(llvm::raw_ostream &os, const SourceFile &file, + const TextError &error, unsigned lineNoWidthHint = 0); + void reportJSONError(llvm::raw_ostream &os, const SourceFile &file, const JSONError &error); /// @brief Reports a vector of errors in batch in plain text -/// @pre errors only contains errors from the same file -/// @pre errors is sorted in the order of appearence within the file +/// @note To make proper use of this function, the following preconditions must +/// be met: +/// - errors only contains errors from the same file +/// - errors is sorted in the order of appearence within the file template void reportTextErrors( llvm::raw_ostream &os, const SourceFile &file, const std::vector &errors, @@ -47,7 +49,7 @@ void reportTextErrors( } const std::size_t maxLine = file.getLocation(errors.back().span).line; - const unsigned lineNoMaxWidth = getNumDigits(maxLine); + const unsigned lineNoMaxWidth = numDigits(maxLine); const std::string lineNoSpacesBody = std::string(lineNoMaxWidth + 2, ' '); const std::size_t lastButOne = maxErrors - 1; diff --git a/include/Support/SourceFile.h b/include/Support/SourceFile.h index 1c4efcd..9aa775e 100644 --- a/include/Support/SourceFile.h +++ b/include/Support/SourceFile.h @@ -12,6 +12,12 @@ struct SourceLocation { std::size_t column; std::string_view filename; std::string_view lineText; + + // NOLINTNEXTLINE + SourceLocation(std::size_t line, std::size_t column, + // NOLINTNEXTLINE + std::string_view filename, std::string_view lineText) + : line(line), column(column), filename(filename), lineText(lineText) {} }; class SourceFile { diff --git a/include/Support/Utilities.h b/include/Support/Utilities.h new file mode 100644 index 0000000..92f31d5 --- /dev/null +++ b/include/Support/Utilities.h @@ -0,0 +1,27 @@ +#ifndef LANG_UTILITIES_H +#define LANG_UTILITIES_H + 
+#include +#include + +namespace lang { + +template +int numDigits(T number) +{ + static_assert(std::is_integral_v, "T must be an integral type"); + int digits = 0; + if constexpr (std::is_signed_v) { + if (number < 0) digits = 1; + } + while (number) { + number /= 10; + digits++; + } + return digits; +} + + +} // namespace lang + +#endif // LANG_UTILITIES_H \ No newline at end of file diff --git a/include/Typing/Type.h b/include/Typing/Type.h index 3624524..a4bc3d5 100644 --- a/include/Typing/Type.h +++ b/include/Typing/Type.h @@ -14,11 +14,13 @@ enum class TypeKind { struct Type { TypeKind kind; + explicit Type(TypeKind kind) : kind(kind) {} std::string toString() const; template T *as() { return static_cast(this); } + template const T *as() const { return static_cast(this); } @@ -26,11 +28,13 @@ struct Type { struct PointerType : public Type { Type *pointee; + PointerType(Type *pointee) : Type(TypeKind::Pointer), pointee(pointee) {} }; struct FunctionType : public Type { NonOwningList arrows; + FunctionType(NonOwningList arrows) : Type(TypeKind::Function), arrows(arrows) {} }; diff --git a/include/Typing/TypeContext.h b/include/Typing/TypeContext.h index 0a31770..ece22db 100644 --- a/include/Typing/TypeContext.h +++ b/include/Typing/TypeContext.h @@ -5,10 +5,68 @@ #include "Type.h" -#include +#include namespace lang { +struct HashType { + std::size_t operator()(const Type *type) const { + switch (type->kind) { + case TypeKind::Void: + return 0; + case TypeKind::Number: + return 1; + case TypeKind::Pointer: { + const PointerType *ptrType = type->as(); + return 2 * operator()(ptrType->pointee); + } + case TypeKind::Function: { + const FunctionType *fnType = type->as(); + std::size_t hash = 3; + for (const Type *param : fnType->arrows) { + hash = hash * 31 + operator()(param); + } + return hash; + } + } + return 0; + } +}; + +struct EqualType { + bool operator()(const Type *lhs, const Type *rhs) const { + if (lhs->kind != rhs->kind) { + return false; + } + switch 
(lhs->kind) { + case TypeKind::Void: + case TypeKind::Number: + return true; + case TypeKind::Pointer: { + const PointerType *lhsPtr = lhs->as(); + const PointerType *rhsPtr = rhs->as(); + return operator()(lhsPtr->pointee, rhsPtr->pointee); + } + case TypeKind::Function: { + const FunctionType *lhsFn = lhs->as(); + const FunctionType *rhsFn = rhs->as(); + if (lhsFn->arrows.size() != rhsFn->arrows.size()) { + return false; + } + auto iterLhs = lhsFn->arrows.begin(); + auto iterRhs = rhsFn->arrows.begin(); + for (; iterLhs != lhsFn->arrows.end(); ++iterLhs, ++iterRhs) { + if (!operator()(*iterLhs, *iterRhs)) { + return false; + } + } + return true; + } + } + return false; + } +}; + class TypeContext { public: explicit TypeContext(Arena &arena) : arena(&arena) { @@ -16,51 +74,50 @@ class TypeContext { tyNumber = arena.alloc(TypeKind::Number); } + Arena *getArena() const { return arena; } + Type *getTypeVoid() const { return tyVoid; } Type *getTypeNumber() const { return tyNumber; } - template Type *make(Args &&...args) { - Type *type = arena->alloc(std::forward(args)...); + std::size_t getNumTypes() const { return typeSet.size(); } + + template Type *make(Args &&...args) { + auto [type, inserted] = + hashCons(arena->alloc(std::forward(args)...)); + if (!inserted) { + arena->dealloc(type); + } return type; } private: - struct HashType { - std::size_t operator()(const Type *type) const { return 0; } - }; - - struct EqualType { - bool operator()(const Type *lhs, const Type *rhs) const { - return false; - } - }; - Arena *arena; Type *tyVoid; Type *tyNumber; - std::unordered_map typeMap; + std::unordered_set typeSet; - Type *hashCons(Type *type) { + std::pair hashCons(Type *type) { switch (type->kind) { case TypeKind::Void: - return tyVoid; + return {tyVoid, false}; case TypeKind::Number: - return tyNumber; + return {tyNumber, false}; case TypeKind::Pointer: { - PointerType *pointerType = type->as(); - pointerType->pointee = hashCons(pointerType->pointee); - const auto 
[it, inserted] = typeMap.emplace(type, pointerType); - return it->second; + PointerType *ptrType = type->as(); + ptrType->pointee = hashCons(ptrType->pointee).first; + const auto [it, inserted] = typeSet.insert(type); + return {*it, inserted}; } case TypeKind::Function: - FunctionType *functionType = type->as(); - for (Type *&arrow : functionType->arrows) { - arrow = hashCons(arrow); + FunctionType *fnType = type->as(); + for (Type *¶m : fnType->arrows) { + param = hashCons(param).first; } - const auto [it, inserted] = typeMap.emplace(type, functionType); - return it->second; + const auto [it, inserted] = typeSet.insert(type); + return {*it, inserted}; } + return {nullptr, false}; } }; diff --git a/samples/valid/1.lang b/samples/valid/1.lang index 1a451e6..ebd0225 100644 --- a/samples/valid/1.lang +++ b/samples/valid/1.lang @@ -2,12 +2,13 @@ // // Unambiguous syntax, modern and easy to parse syntax with elements // inspired by C, Kotlin, Rust, and Swift. -fn main(): void { +fn main(): number { 1; - println(0); + println(0, 1); + return 1; } -fn println() : void { +fn println(x: number, y: number): void { // This is an empty block { } diff --git a/src/AST/ASTPrinter.cpp b/src/AST/ASTPrinter.cpp index a9e9475..a018690 100644 --- a/src/AST/ASTPrinter.cpp +++ b/src/AST/ASTPrinter.cpp @@ -1,4 +1,5 @@ #include "AST/ASTPrinter.h" +#include namespace { @@ -112,7 +113,7 @@ void ASTPrinter::visit(const WhileStmtAST &node) { void ASTPrinter::visit(const IdentifierExprAST &node) { INDENT(); os << "IdentifierExprAST: " << node.span; - std::visit(Overloaded{[&](const std::monostate) {}, + std::visit(Overloaded{[&](const std::monostate) { os << " => None"; }, [&](const LocalStmtAST *stmt) { os << " => LocalStmtAST(" << static_cast(stmt) << ')'; @@ -147,7 +148,9 @@ void ASTPrinter::visit(const CallExprAST &node) { INDENT(); os << "CallExprAST\n"; ASTVisitor::visit(*node.callee); - ASTVisitor::visit(*node.arg); + for (const ExprAST *arg : node.args) { + ASTVisitor::visit(*arg); + } 
} void ASTPrinter::visit(const IndexExprAST &node) { diff --git a/src/Alloc/Arena.cpp b/src/Alloc/Arena.cpp index 0333269..dd9cfd1 100644 --- a/src/Alloc/Arena.cpp +++ b/src/Alloc/Arena.cpp @@ -4,16 +4,9 @@ namespace lang { -constexpr std::size_t alignUp(std::size_t size, std::size_t align) { - return (size + align - 1) & ~(align - 1); -} - -} // namespace lang - -namespace lang { - void *Arena::allocInternal(std::size_t size) { - size = alignUp(size, alignof(max_align_t)); + constexpr auto alignMask = alignof(max_align_t) - 1; + size = (size + alignMask) & ~alignMask; if (allocSize + size > block.size) { if (block.data != nullptr) { @@ -29,7 +22,7 @@ void *Arena::allocInternal(std::size_t size) { } if (block.data == nullptr) { - block = Block{std::max(size, defaultSize)}; + block = Block(std::max(size, defaultSize)); } allocSize = 0; diff --git a/src/Analysis/CFA.cpp b/src/Analysis/CFA.cpp index 617b6d4..2059e9c 100644 --- a/src/Analysis/CFA.cpp +++ b/src/Analysis/CFA.cpp @@ -4,6 +4,9 @@ namespace lang { TextError CFAError::toTextError() const { switch (kind) { + case CFAErrorKind::EarlyBreakStmt: + return {span, "Early break statement", + "Code after break statement will never be executed"}; case CFAErrorKind::EarlyReturnStmt: return {span, "Early return statement", "Code after return statement will never be executed"}; @@ -17,6 +20,8 @@ TextError CFAError::toTextError() const { JSONError CFAError::toJSONError() const { switch (kind) { + case CFAErrorKind::EarlyBreakStmt: + return {span, "cfa-early-break-stmt"}; case CFAErrorKind::EarlyReturnStmt: return {span, "cfa-early-return-stmt"}; case CFAErrorKind::InvalidBreakStmt: @@ -51,8 +56,11 @@ void CFA::visit(BlockStmtAST &node) { ++i; if (stmt->kind == StmtASTKind::Break || stmt->kind == StmtASTKind::Return) { + CFAErrorKind kind = stmt->kind == StmtASTKind::Break + ? 
CFAErrorKind::EarlyBreakStmt + : CFAErrorKind::EarlyReturnStmt; if (i != node.stmts.size()) { - errors.push_back({CFAErrorKind::EarlyReturnStmt, stmt->span}); + errors.push_back({kind, stmt->span}); break; } } diff --git a/src/Analysis/Resolver.cpp b/src/Analysis/Resolver.cpp index b3e11e5..4c62d52 100644 --- a/src/Analysis/Resolver.cpp +++ b/src/Analysis/Resolver.cpp @@ -123,7 +123,9 @@ void Resolver::visit(BinaryExprAST &node) { void Resolver::visit(CallExprAST &node) { ASTVisitor::visit(*node.callee); - ASTVisitor::visit(*node.arg); + for (auto *arg : node.args) { + ASTVisitor::visit(*arg); + } } void Resolver::visit(IndexExprAST &node) { diff --git a/src/Analysis/TypeChecker.cpp b/src/Analysis/TypeChecker.cpp index b8efca9..47904b8 100644 --- a/src/Analysis/TypeChecker.cpp +++ b/src/Analysis/TypeChecker.cpp @@ -1,4 +1,5 @@ #include "Analysis/TypeChecker.h" +#include "ADT/NonOwningList.h" namespace { @@ -14,19 +15,20 @@ namespace lang { TextError TypeCheckerError::toTextError() const { switch (kind) { - case TypeCheckerErrorKind::InvalidReturn: + case TypeCheckerErrorKind::InvalidReturnStmt: return {span, "Invalid return statement", "Return type mismatch"}; case TypeCheckerErrorKind::InvalidAssignment: return {span, "Invalid assignment", "Type mismatch"}; case TypeCheckerErrorKind::InvalidBinaryOperation: return {span, "Invalid binary operation", "Type mismatch"}; } - return {span, "Unknown type checking error title", "Unknown type checking error label"}; + return {span, "Unknown type checking error title", + "Unknown type checking error label"}; } JSONError TypeCheckerError::toJSONError() const { switch (kind) { - case TypeCheckerErrorKind::InvalidReturn: + case TypeCheckerErrorKind::InvalidReturnStmt: return {span, "type-check-invalid-return"}; case TypeCheckerErrorKind::InvalidAssignment: return {span, "type-check-invalid-assignment"}; @@ -55,16 +57,16 @@ void TypeChecker::visit(ReturnStmtAST &node) { ASTVisitor::visit(*node.expr); } - if 
(currentFunction->retType == nullptr) { - if (node.expr != nullptr) { - errors.push_back({TypeCheckerErrorKind::InvalidReturn, node.span}); - } + if (currentFunction->retType == nullptr && node.expr != nullptr) { + errors.push_back({TypeCheckerErrorKind::InvalidReturnStmt, node.span}); } else { if (node.expr == nullptr) { - errors.push_back({TypeCheckerErrorKind::InvalidReturn, node.span}); + errors.push_back( + {TypeCheckerErrorKind::InvalidReturnStmt, node.span}); } else { if (node.expr->type != currentFunction->retType) { - errors.push_back({TypeCheckerErrorKind::InvalidReturn, node.span}); + errors.push_back( + {TypeCheckerErrorKind::InvalidReturnStmt, node.span}); } } } @@ -78,12 +80,14 @@ void TypeChecker::visit(LocalStmtAST &node) { node.type = node.init->type; } else { if (node.type != node.init->type) { - errors.push_back({TypeCheckerErrorKind::InvalidAssignment, node.span}); + errors.push_back( + {TypeCheckerErrorKind::InvalidAssignment, node.span}); } } } else { if (node.type == nullptr) { - errors.push_back({TypeCheckerErrorKind::InvalidAssignment, node.span}); + errors.push_back( + {TypeCheckerErrorKind::InvalidAssignment, node.span}); } } } @@ -118,17 +122,18 @@ void TypeChecker::visit(WhileStmtAST &node) { } void TypeChecker::visit(IdentifierExprAST &node) { - std::visit( - Overloaded{ - [&](const std::monostate) {}, - [&](const LocalStmtAST *stmt) { node.type = stmt->type; }, - // TODO: here the type should be function pointer - [&](const FunctionDeclAST *decl) { node.type = decl->retType; }, - }, - node.decl); + std::visit(Overloaded{ + [&](const std::monostate) { node.type = nullptr; }, + [&](const LocalStmtAST *stmt) { node.type = stmt->type; }, + // TODO: continue + [&](const FunctionDeclAST *decl) { node.type = decl->type; }, + }, + node.decl); } -void TypeChecker::visit(NumberExprAST &node) { node.type = typeCtx.getTypeNumber(); } +void TypeChecker::visit(NumberExprAST &node) { + node.type = typeCtx->getTypeNumber(); +} void 
TypeChecker::visit(UnaryExprAST &node) { ASTVisitor::visit(*node.expr); @@ -141,7 +146,8 @@ void TypeChecker::visit(BinaryExprAST &node) { ASTVisitor::visit(*node.rhs); if (node.lhs->type != node.rhs->type) { - errors.push_back({TypeCheckerErrorKind::InvalidBinaryOperation, node.span}); + errors.push_back( + {TypeCheckerErrorKind::InvalidBinaryOperation, node.span}); } node.type = node.lhs->type; @@ -149,7 +155,9 @@ void TypeChecker::visit(BinaryExprAST &node) { void TypeChecker::visit(CallExprAST &node) { ASTVisitor::visit(*node.callee); - ASTVisitor::visit(*node.arg); + for (auto &arg : node.args) { + ASTVisitor::visit(*arg); + } // TODO: Verify that the callee is a function and that the argument types // match the parameter types diff --git a/src/Codegen/Codegen.cpp b/src/Codegen/Codegen.cpp index 592a63f..f1d0093 100644 --- a/src/Codegen/Codegen.cpp +++ b/src/Codegen/Codegen.cpp @@ -12,7 +12,7 @@ template Overloaded(Ts...) -> Overloaded; namespace lang { -llvm::Module *Codegen::generate(const ModuleAST &module) { +llvm::Module *Codegen::generateFromModuleAST(const ModuleAST &module) { llvmModule = std::make_unique("main", *context); for (auto *decl : module.decls) { ASTVisitor::visit(*decl); @@ -120,7 +120,7 @@ void Codegen::visit(const WhileStmtAST &node) { void Codegen::visit(const IdentifierExprAST &node) { std::visit(Overloaded{ - [&](const std::monostate &decl) { exprResult = nullptr; }, + [&](const std::monostate) { exprResult = nullptr; }, [&](const LocalStmtAST *decl) { exprResult = namedValues.at(decl); }, @@ -186,17 +186,16 @@ void Codegen::visit(const BinaryExprAST &node) { void Codegen::visit(const CallExprAST &node) { if (node.callee->kind == ExprASTKind::Identifier) { auto *callee = static_cast(node.callee); - std::visit( - Overloaded{ - [&](const std::monostate &decl) { exprResult = nullptr; }, - [&](const LocalStmtAST *decl) { exprResult = nullptr; }, - [&](const FunctionDeclAST *decl) { - exprResult = builder->CreateCall( - 
llvmModule->getFunction(decl->ident), - std::vector(1, exprResult)); - }, - }, - callee->decl); + std::visit(Overloaded{ + [&](const std::monostate) { exprResult = nullptr; }, + [&](const LocalStmtAST *decl) { exprResult = nullptr; }, + [&](const FunctionDeclAST *decl) { + exprResult = builder->CreateCall( + llvmModule->getFunction(decl->ident), + std::vector(1, exprResult)); + }, + }, + callee->decl); } else { // TODO: // ASTVisitor::visit(*node.callee); diff --git a/src/Parse/Parser.cpp b/src/Parse/Parser.cpp index c2d26ff..f6d3f41 100644 --- a/src/Parse/Parser.cpp +++ b/src/Parse/Parser.cpp @@ -6,8 +6,8 @@ namespace { struct BinOpPair { lang::BinOpKind bind; - int precLHS; - int precRHS; + int precLhs; + int precRhs; }; const std::unordered_set declLevelSyncSet = { @@ -74,9 +74,8 @@ TextError ParseError::toTextError() const { "Expected " + tokenKindToString(expected) + " instead"}; case ParseErrorKind::ExpectedType: return {span, "Unexpected token", "Expected a type instead"}; - case ParseErrorKind::ExpectedPrimaryExpression: - return {span, "Unexpected token", - "Expected a primary expression instead"}; + case ParseErrorKind::ExpectedExpression: + return {span, "Unexpected token", "Expected an expression instead"}; } return {span, "Unknown parsing error title", "Unknown parse error label"}; } @@ -89,7 +88,7 @@ JSONError ParseError::toJSONError() const { return {span, "parse-unexpected-token"}; case ParseErrorKind::ExpectedType: return {span, "parse-expected-type"}; - case ParseErrorKind::ExpectedPrimaryExpression: + case ParseErrorKind::ExpectedExpression: return {span, "parse-expected-primary-expr"}; } return {span, "parser-unknown-error"}; @@ -399,8 +398,8 @@ StmtAST *Parser::parseExprStmtOrAssignStmtAST() { } break; default: - errors.emplace_back(ParseErrorKind::ExpectedPrimaryExpression, - tok->span, TokenKind::Amp); + errors.emplace_back(ParseErrorKind::ExpectedExpression, tok->span, + TokenKind::Amp); } return nullptr; @@ -416,10 +415,25 @@ ExprAST 
*Parser::parseExprAST(int prec) { case TokenKind::LParen: { ++cur; - ExprAST *arg = parseExprAST(); - RETURN_IF_NULL(arg); + NonOwningList args; + + const Token *argTok = peek(); + while (argTok != nullptr && argTok->kind != TokenKind::RParen) { + ExprAST *arg = parseExprAST(); + RETURN_IF_NULL(arg); + + args.emplace_back(arena, arg); - lhs = arena->alloc(tok->span, lhs, arg); + argTok = peek(); + if (argTok == nullptr || argTok->kind != TokenKind::Comma) { + break; + } + + ++cur; + argTok = peek(); + } + + lhs = arena->alloc(tok->span, lhs, args); EXPECT(TokenKind::RParen); } break; @@ -440,12 +454,12 @@ ExprAST *Parser::parseExprAST(int prec) { if (it == binOpMap.end()) { return lhs; } - if (it->second.precLHS <= prec) { + if (it->second.precLhs <= prec) { return lhs; } ++cur; lhs = arena->alloc(tok->span, it->second.bind, lhs, - parseExprAST(it->second.precRHS)); + parseExprAST(it->second.precRhs)); } tok = peek(); @@ -480,8 +494,8 @@ ExprAST *Parser::parsePrimaryExprAST() { return arena->alloc(tok->span, expr); default: - errors.emplace_back(ParseErrorKind::ExpectedPrimaryExpression, - tok->span, TokenKind::Amp); + errors.emplace_back(ParseErrorKind::ExpectedExpression, tok->span, + TokenKind::Amp); } return nullptr; diff --git a/src/Support/Reporting.cpp b/src/Support/Reporting.cpp index af91397..dc0421c 100644 --- a/src/Support/Reporting.cpp +++ b/src/Support/Reporting.cpp @@ -7,7 +7,7 @@ void reportTextError(llvm::raw_ostream &os, const SourceFile &file, const TextError &error, unsigned lineNoWidthHint) { assert(!error.span.empty() && "span cannot be empty"); const SourceLocation loc = file.getLocation(error.span); - const unsigned lineNoWidth = getNumDigits(loc.line); + const unsigned lineNoWidth = numDigits(loc.line); const unsigned lineNoMaxWidth = std::max(lineNoWidth, lineNoWidthHint); const std::string lineNoSpacesTitle(lineNoMaxWidth + 1, ' '); const std::string lineNoSpacesBody(lineNoMaxWidth + 2, ' '); diff --git a/src/main.cpp b/src/main.cpp index 
6e0d8cd..96e48c8 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2,7 +2,6 @@ #include "Support/Reporting.h" #include "Support/SourceFile.h" -#include "AST/AST.h" #include "AST/ASTPrinter.h" #include "Lex/Lexer.h" @@ -30,6 +29,7 @@ enum class CompilerUntilStage { None, Lex, AST, + Sema, }; enum class CompilerEmitAction { @@ -54,10 +54,13 @@ const llvm::cl::opt compilerErrorFormat( const llvm::cl::opt compilerUntilStage( "until", llvm::cl::desc("Select the stage until which the compiler run"), - llvm::cl::values(clEnumValN(CompilerUntilStage::Lex, "lex", - "Run the compiler until the lexing stage"), - clEnumValN(CompilerUntilStage::AST, "ast", - "Run the compiler until the parsing stage")), + llvm::cl::values( + clEnumValN(CompilerUntilStage::Lex, "lex", + "Run the compiler until the lexing stage"), + clEnumValN(CompilerUntilStage::AST, "ast", + "Run the compiler until the parsing stage"), + clEnumValN(CompilerUntilStage::Sema, "sema", + "Run the compiler until the semantic analysis stage")), llvm::cl::init(CompilerUntilStage::None)); const llvm::cl::opt compilerEmitAction( @@ -104,6 +107,10 @@ int main(int argc, char **argv) { const lang::SourceFile source(inputFilename, file.get()->getBuffer()); + // ------------------------------------------------------------------------- + // Lexing + // ------------------------------------------------------------------------- + lang::Lexer lexer(file.get()->getBuffer()); const auto lexResult = lexer.lexAll(); @@ -113,7 +120,7 @@ int main(int argc, char **argv) { } } - if (!lexResult.errors.empty()) { + if (lexResult.hasErrors()) { reportErrors(llvm::errs(), compilerErrorFormat, source, lexResult.errors); return EXIT_FAILURE; @@ -128,14 +135,19 @@ int main(int argc, char **argv) { return EXIT_SUCCESS; } + // ------------------------------------------------------------------------- + // Parsing + // ------------------------------------------------------------------------- + lang::Arena arena(lang::kiloBytes(32)); + 
lang::ASTPrinter astPrinter(llvm::outs()); + lang::TypeContext typeCtx(arena); - lang::ASTPrinter astPrinter(llvm::outs()); lang::Parser parser(arena, typeCtx, lexResult.tokens); - const auto parseResult = parser.parseModuleAST(); - DEBUG("%lu alloc() with %lu bytes", arena.totalAllocations(), + + DEBUG("%lu allocation(s) occupying %lu bytes", arena.totalAllocations(), arena.totalAllocated()); if (compilerEmitAction == CompilerEmitAction::Src) { @@ -144,10 +156,11 @@ int main(int argc, char **argv) { } if (compilerEmitAction == CompilerEmitAction::AST) { + llvm::outs() << "=== Parsed AST ===\n"; astPrinter.visit(*parseResult.module); } - if (!parseResult.errors.empty()) { + if (parseResult.hasErrors()) { reportErrors(llvm::errs(), compilerErrorFormat, source, parseResult.errors); return EXIT_FAILURE; @@ -157,54 +170,78 @@ int main(int argc, char **argv) { return EXIT_SUCCESS; } - lang::ModuleAST *module = parseResult.module; + // ------------------------------------------------------------------------- + // Control Flow Analysis + // ------------------------------------------------------------------------- lang::CFA controlFlowAnalyzer; - const auto cfaResult = controlFlowAnalyzer.analyzeModuleAST(*module); + const auto cfaResult = + controlFlowAnalyzer.analyzeModuleAST(*parseResult.module); - if (!cfaResult.errors.empty()) { + if (cfaResult.hasErrors()) { reportErrors(llvm::errs(), compilerErrorFormat, source, cfaResult.errors); return EXIT_FAILURE; } + // ------------------------------------------------------------------------- + // Resolution + // ------------------------------------------------------------------------- + lang::Resolver resolver; - const auto resolveResult = resolver.resolveModuleAST(*module); + const auto resolveResult = resolver.resolveModuleAST(*parseResult.module); if (compilerEmitAction == CompilerEmitAction::AST) { - astPrinter.visit(*module); + llvm::outs() << "=== Resolved AST ===\n"; + astPrinter.visit(*parseResult.module); } - if 
(!resolveResult.errors.empty()) { + if (resolveResult.hasErrors()) { reportErrors(llvm::errs(), compilerErrorFormat, source, resolveResult.errors); return EXIT_FAILURE; } + // ------------------------------------------------------------------------- + // Type Checking + // ------------------------------------------------------------------------- + lang::TypeChecker typeChecker(typeCtx); - const auto typeCheckerResult = typeChecker.analyzeModuleAST(*module); + const auto typeCheckerResult = + typeChecker.analyzeModuleAST(*parseResult.module); + + DEBUG("%lu custom type(s) created", typeCtx.getNumTypes()); if (compilerEmitAction == CompilerEmitAction::AST) { - astPrinter.visit(*module); + llvm::outs() << "=== Type-checked AST ===\n"; + astPrinter.visit(*parseResult.module); } - if (!resolveResult.errors.empty()) { + if (typeCheckerResult.hasErrors()) { /* gate on the type checker's own errors, which are the ones reported here */ reportErrors(llvm::errs(), compilerErrorFormat, source, typeCheckerResult.errors); return EXIT_FAILURE; } + if (compilerUntilStage == CompilerUntilStage::Sema) { + return EXIT_SUCCESS; + } + + // ------------------------------------------------------------------------- + // Code Generation + // ------------------------------------------------------------------------- + lang::Codegen codegen; - const llvm::Module *llvmModule = codegen.generate(*module); + const llvm::Module *llvmModule = + codegen.generateFromModuleAST(*parseResult.module); if (compilerEmitAction == CompilerEmitAction::LLVM) { llvmModule->print(llvm::outs(), nullptr); } if (llvm::verifyModule(*llvmModule, &llvm::errs())) { - llvm::errs() << "Error: generated LLVM IR is invalid\n"; return EXIT_FAILURE; } - + return EXIT_SUCCESS; }