Skip to content

Commit

Permalink
feat: codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
jsilll committed Sep 25, 2024
1 parent 26af1f6 commit 14e4048
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 1 deletion.
65 changes: 65 additions & 0 deletions include/Codegen/Codegen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef LANG_CODEGEN_H
#define LANG_CODEGEN_H

#include "AST/ASTVisitor.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"

namespace lang {

class Codegen : public ConstASTVisitor<Codegen> {
friend class ASTVisitor<Codegen, true>;

public:
Codegen()
: context(std::make_unique<llvm::LLVMContext>()),
builder(std::make_unique<llvm::IRBuilder<>>(*context)) {}

llvm::Module *generate(const ModuleAST &module);

private:
std::unique_ptr<llvm::LLVMContext> context;
std::unique_ptr<llvm::IRBuilder<>> builder;
std::unique_ptr<llvm::Module> llvmModule = nullptr;

llvm::Value *exprResult = nullptr;
std::unordered_map<const StmtAST *, llvm::BasicBlock *> blocks;
std::unordered_map<const LocalStmtAST *, llvm::Value *> locals;

void visit(const FunctionDeclAST &node);

void visit(const ExprStmtAST &node);

void visit(const BreakStmtAST &node);

void visit(const ReturnStmtAST &node);

void visit(const LocalStmtAST &node);

void visit(const AssignStmtAST &node);

void visit(const BlockStmtAST &node);

void visit(const IfStmtAST &node);

void visit(const WhileStmtAST &node);

void visit(const IdentifierExprAST &node);

void visit(const NumberExprAST &node);

void visit(const UnaryExprAST &node);

void visit(const BinaryExprAST &node);

void visit(const CallExprAST &node);

void visit(const IndexExprAST &node);

void visit(const GroupedExprAST &node);
};

} // namespace lang

#endif // LANG_CODEGEN_H
109 changes: 109 additions & 0 deletions src/Codegen/Codegen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "Codegen/Codegen.h"

namespace {

template <class... Ts> struct Overloaded : Ts... {
using Ts::operator()...;
};

template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;

} // namespace

namespace lang {

llvm::Module *Codegen::generate(const ModuleAST &module) {
llvmModule = std::make_unique<llvm::Module>("main", *context);

for (auto *decl : module.decls) {
ASTVisitor::visit(*decl);
}

return llvmModule.get();
}

void Codegen::visit(const FunctionDeclAST &node) {
// TODO: handle function arguments and return type
auto *funcType = llvm::FunctionType::get(llvm::Type::getVoidTy(*context),
false /* isVarArg */);

auto *func =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage,
node.ident, llvmModule.get());

auto *entry = llvm::BasicBlock::Create(*context, "entry", func);
builder->SetInsertPoint(entry);
ASTVisitor::visit(*node.body);
builder->CreateRetVoid();
}

void Codegen::visit(const ExprStmtAST &node) { ASTVisitor::visit(*node.expr); }

void Codegen::visit(const BreakStmtAST &node) { /* TODO */
}

void Codegen::visit(const ReturnStmtAST &node) {
if (node.expr) {
ASTVisitor::visit(*node.expr);
builder->CreateRet(exprResult);
} else {
builder->CreateRetVoid();
}
}

void Codegen::visit(const LocalStmtAST &node) {
auto *alloca = builder->CreateAlloca(llvm::Type::getFloatTy(*context),
nullptr /* ArraySize */, node.span);

locals[&node] = alloca;

if (node.init) {
ASTVisitor::visit(*node.init);
builder->CreateStore(exprResult, alloca);
}
}

void Codegen::visit(const AssignStmtAST &node) { /* TODO */
}

void Codegen::visit(const BlockStmtAST &node) {
for (auto *stmt : node.stmts) {
ASTVisitor::visit(*stmt);
}
}

void Codegen::visit(const IfStmtAST &node) { /* TODO */
}

void Codegen::visit(const WhileStmtAST &node) { /* TODO */
}

void Codegen::visit(const IdentifierExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const NumberExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const UnaryExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const BinaryExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const CallExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const IndexExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

void Codegen::visit(const GroupedExprAST &node) { /* TODO */
exprResult = llvm::UndefValue::get(llvm::Type::getFloatTy(*context));
}

} // namespace lang
14 changes: 13 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "Analysis/Resolver.h"
#include "Analysis/TypeChecker.h"

#include "Codegen/Codegen.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"

Expand All @@ -34,6 +36,7 @@ enum class CompilerEmitAction {
Lex,
Src,
AST,
LLVM,
};

const llvm::cl::opt<std::string>
Expand Down Expand Up @@ -64,7 +67,9 @@ const llvm::cl::opt<CompilerEmitAction> compilerEmitAction(
clEnumValN(CompilerEmitAction::Src, "src",
"Emit the original source code of the input file"),
clEnumValN(CompilerEmitAction::AST, "ast",
"Emit the abstract syntax tree of the input file")),
"Emit the abstract syntax tree of the input file"),
clEnumValN(CompilerEmitAction::LLVM, "llvm",
"Emit the LLVM IR of the input file")),
llvm::cl::init(CompilerEmitAction::None));

template <typename T>
Expand Down Expand Up @@ -187,4 +192,11 @@ int main(int argc, char **argv) {
typeCheckerResult.errors);
return EXIT_FAILURE;
}

lang::Codegen codegen;
const llvm::Module *llvmModule = codegen.generate(*module);

if (compilerEmitAction == CompilerEmitAction::LLVM) {
llvmModule->print(llvm::outs(), nullptr);
}
}

0 comments on commit 14e4048

Please sign in to comment.