Skip to content

Commit

Permalink
Refactor implicit generator coding (#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer authored Jun 16, 2024
1 parent 41e7282 commit 0275aeb
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .run/spice run.run.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="spice run" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="run -O1 -d -ir ../../media/test-project/test.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<configuration default="false" name="spice run" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="run -O0 -d -ir ../../media/test-project/test.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<envs>
<env name="LLVM_ADDITIONAL_FLAGS" value="-lole32 -lws2_32" />
<env name="LLVM_BUILD_INCLUDE_DIR" value="$PROJECT_DIR$/../llvm-project-latest/build/include" />
Expand Down
21 changes: 21 additions & 0 deletions src/ast/ASTNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@

namespace spice::compiler {

// Operator overload function names
constexpr const char *const OP_FCT_PREFIX = "op.";
constexpr const char *const OP_FCT_PLUS = "op.plus";
constexpr const char *const OP_FCT_MINUS = "op.minus";
constexpr const char *const OP_FCT_MUL = "op.mul";
constexpr const char *const OP_FCT_DIV = "op.div";
constexpr const char *const OP_FCT_EQUAL = "op.equal";
constexpr const char *const OP_FCT_NOT_EQUAL = "op.notequal";
constexpr const char *const OP_FCT_SHL = "op.shl";
constexpr const char *const OP_FCT_SHR = "op.shr";
constexpr const char *const OP_FCT_PLUS_EQUAL = "op.plusequal";
constexpr const char *const OP_FCT_MINUS_EQUAL = "op.minusequal";
constexpr const char *const OP_FCT_MUL_EQUAL = "op.mulequal";
constexpr const char *const OP_FCT_DIV_EQUAL = "op.divequal";
constexpr const char *const OP_FCT_POSTFIX_PLUS_PLUS = "op.plusplus.post";
constexpr const char *const OP_FCT_POSTFIX_MINUS_MINUS = "op.minusminus.post";

/**
* Saves a constant value for an AST node to realize features like array-out-of-bounds checks
*/
Expand Down Expand Up @@ -254,6 +271,10 @@ class FctNameNode : public ASTNode {
std::any accept(AbstractASTVisitor *visitor) override { return visitor->visitFctName(this); }
std::any accept(ParallelizableASTVisitor *visitor) const override { return visitor->visitFctName(this); }

// Other methods
[[nodiscard]] constexpr bool isOperatorOverload() const { return name.starts_with(OP_FCT_PREFIX); }
[[nodiscard]] bool supportsInverseOperator() const { return name == OP_FCT_EQUAL || name == OP_FCT_NOT_EQUAL; }

// Public members
std::string name;
std::string structName;
Expand Down
8 changes: 2 additions & 6 deletions src/irgenerator/GenBuiltinFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@ std::any IRGenerator::visitPrintfCall(const PrintfCallNode *node) {
}

// Extend all integer types lower than 32 bit to 32 bit
if (argSymbolType.removeReferenceWrapper().isOneOf({TY_SHORT, TY_BYTE, TY_CHAR, TY_BOOL})) {
if (argSymbolType.removeReferenceWrapper().isSigned())
argVal = builder.CreateSExt(argVal, llvm::Type::getInt32Ty(context));
else
argVal = builder.CreateZExt(argVal, llvm::Type::getInt32Ty(context));
}
if (argSymbolType.removeReferenceWrapper().isOneOf({TY_SHORT, TY_BYTE, TY_CHAR, TY_BOOL}))
argVal = builder.CreateIntCast(argVal, builder.getInt32Ty(), argSymbolType.removeReferenceWrapper().isSigned());

printfArgs.push_back(argVal);
}
Expand Down
63 changes: 44 additions & 19 deletions src/irgenerator/GenImplicit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,49 @@ void IRGenerator::generateScopeCleanup(const StmtLstNode *node) const {
}
}

llvm::Value *IRGenerator::generateFctCall(const Function *fct, const std::vector<llvm::Value *> &args) const {
// Retrieve metadata for the function
const std::string mangledName = fct->getMangledName();

// Function is not defined in the current module -> declare it
if (!module->getFunction(mangledName)) {
std::vector<llvm::Type *> paramTypes;
for (const llvm::Value *argValue : args)
paramTypes.push_back(argValue->getType());
llvm::Type *returnType = fct->returnType.toLLVMType(sourceFile);
llvm::FunctionType *fctType = llvm::FunctionType::get(returnType, paramTypes, false);
module->getOrInsertFunction(mangledName, fctType);
}

// Get callee function
llvm::Function *callee = module->getFunction(mangledName);
assert(callee != nullptr);

// Generate function call
return builder.CreateCall(callee, args);
}

void IRGenerator::generateProcCall(const Function *proc, std::vector<llvm::Value *> &args) const {
// Retrieve metadata for the function
const std::string mangledName = proc->getMangledName();

// Function is not defined in the current module -> declare it
if (!module->getFunction(mangledName)) {
std::vector<llvm::Type *> paramTypes;
for (const llvm::Value *argValue : args)
paramTypes.push_back(argValue->getType());
llvm::FunctionType *fctType = llvm::FunctionType::get(builder.getVoidTy(), paramTypes, false);
module->getOrInsertFunction(mangledName, fctType);
}

// Get callee function
llvm::Function *callee = module->getFunction(mangledName);
assert(callee != nullptr);

// Generate function call
builder.CreateCall(callee, args);
}

void IRGenerator::generateCtorOrDtorCall(SymbolTableEntry *entry, const Function *ctorOrDtor,
const std::vector<llvm::Value *> &args) const {
// Retrieve address of the struct variable. For fields this is the 'this' variable, otherwise use the normal address
Expand All @@ -95,30 +138,12 @@ void IRGenerator::generateCtorOrDtorCall(SymbolTableEntry *entry, const Function

void IRGenerator::generateCtorOrDtorCall(llvm::Value *structAddr, const Function *ctorOrDtor,
const std::vector<llvm::Value *> &args) const {
assert(ctorOrDtor != nullptr);

// Retrieve metadata for the function
const std::string mangledName = ctorOrDtor->getMangledName();

// Function is not defined in the current module -> declare it
if (!module->getFunction(mangledName)) {
std::vector<llvm::Type *> paramTypes = {builder.getPtrTy()};
for (llvm::Value *argValue : args)
paramTypes.push_back(argValue->getType());
llvm::FunctionType *fctType = llvm::FunctionType::get(builder.getVoidTy(), paramTypes, false);
module->getOrInsertFunction(mangledName, fctType);
}

// Get callee function
llvm::Function *callee = module->getFunction(mangledName);
assert(callee != nullptr);

// Build parameter list
std::vector<llvm::Value *> argValues = {structAddr};
argValues.insert(argValues.end(), args.begin(), args.end());

// Generate function call
builder.CreateCall(callee, argValues);
generateProcCall(ctorOrDtor, argValues);
}

void IRGenerator::generateDeallocCall(llvm::Value *variableAddress) const {
Expand Down
2 changes: 2 additions & 0 deletions src/irgenerator/IRGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ class IRGenerator : private CompilerPass, public ParallelizableASTVisitor {
// Generate implicit
llvm::Value *doImplicitCast(llvm::Value *src, QualType dstSTy, QualType srcSTy);
void generateScopeCleanup(const StmtLstNode *node) const;
llvm::Value *generateFctCall(const Function *fct, const std::vector<llvm::Value *> &args) const;
void generateProcCall(const Function *proc, std::vector<llvm::Value *> &args) const;
void generateCtorOrDtorCall(SymbolTableEntry *entry, const Function *ctorOrDtor, const std::vector<llvm::Value *> &args) const;
void generateCtorOrDtorCall(llvm::Value *structAddr, const Function *ctorOrDtor, const std::vector<llvm::Value *> &args) const;
void generateDeallocCall(llvm::Value *variableAddress) const;
Expand Down
4 changes: 2 additions & 2 deletions src/irgenerator/OpRuleConversionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ LLVMExprResult OpRuleConversionManager::getPlusEqualInst(const ASTNode *node, LL
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG): {
llvm::Type *elementTy = lhsSTy.getContained().toLLVMType(irGenerator->sourceFile);
llvm::Value *rhsVExt = builder.CreateSExt(rhsV(), builder.getInt64Ty());
llvm::Value *rhsVExt = builder.CreateIntCast(rhsV(), builder.getInt64Ty(), rhsSTy.isSigned());
return {.value = builder.CreateGEP(elementTy, lhsV(), rhsVExt)};
}
default: // GCOV_EXCL_LINE
Expand Down Expand Up @@ -120,7 +120,7 @@ LLVMExprResult OpRuleConversionManager::getMinusEqualInst(const ASTNode *node, L
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG): {
llvm::Type *elementTy = lhsSTy.getContained().toLLVMType(irGenerator->sourceFile);
llvm::Value *rhsVExt = builder.CreateSExt(rhsV(), builder.getInt64Ty());
llvm::Value *rhsVExt = builder.CreateIntCast(rhsV(), builder.getInt64Ty(), rhsSTy.isSigned());
llvm::Value *rhsVNeg = builder.CreateNeg(rhsVExt);
return {.value = builder.CreateGEP(elementTy, lhsV(), rhsVNeg)};
}
Expand Down
16 changes: 0 additions & 16 deletions src/typechecker/OpRuleManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@ class GlobalResourceManager;
// Helper macro to get the length of an array
#define ARRAY_LENGTH(array) sizeof(array) / sizeof(*array)

// Operator overload function names
const char *const OP_FCT_PLUS = "op.plus";
const char *const OP_FCT_MINUS = "op.minus";
const char *const OP_FCT_MUL = "op.mul";
const char *const OP_FCT_DIV = "op.div";
const char *const OP_FCT_EQUAL = "op.equal";
const char *const OP_FCT_NOT_EQUAL = "op.notequal";
const char *const OP_FCT_SHL = "op.shl";
const char *const OP_FCT_SHR = "op.shr";
const char *const OP_FCT_PLUS_EQUAL = "op.plusequal";
const char *const OP_FCT_MINUS_EQUAL = "op.minusequal";
const char *const OP_FCT_MUL_EQUAL = "op.mulequal";
const char *const OP_FCT_DIV_EQUAL = "op.divequal";
const char *const OP_FCT_POSTFIX_PLUS_PLUS = "op.plusplus.post";
const char *const OP_FCT_POSTFIX_MINUS_MINUS = "op.minusminus.post";

// Custom error message prefixes
const char *const ERROR_MSG_RETURN = "Passed wrong data type to return statement";
const char *const ERROR_FOREACH_ITEM = "Passed wrong data type to foreach item";
Expand Down
8 changes: 2 additions & 6 deletions src/typechecker/TypeCheckerPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ std::any TypeChecker::visitFctDefPrepare(FctDefNode *node) {
usedGenericTypes.emplace_back(templateType);
usedGenericTypes.back().used = true;
}
}

// Set type of 'this' variable
if (node->isMethod) {
// Set type of 'this' variable
SymbolTableEntry *thisEntry = currentScope->lookupStrict(THIS_VARIABLE_NAME);
assert(thisEntry != nullptr);
thisEntry->updateType(thisPtrType, false);
Expand Down Expand Up @@ -242,10 +240,8 @@ std::any TypeChecker::visitProcDefPrepare(ProcDefNode *node) {
usedGenericTypes.emplace_back(templateType);
usedGenericTypes.back().used = true;
}
}

// Set type of 'this' variable
if (node->isMethod) {
// Set type of 'this' variable
SymbolTableEntry *thisEntry = currentScope->lookupStrict(THIS_VARIABLE_NAME);
assert(thisEntry != nullptr);
thisEntry->updateType(thisPtrType, false);
Expand Down
6 changes: 4 additions & 2 deletions std/data/queue.spice
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ public f<bool> operator==<T>(const Queue<T>& lhs, const Queue<T>& rhs) {
// Compare the sizes
if lhs.size != rhs.size { return false; }
// Compare the contents
for unsigned long i = 0l; index < lhs.size; i++ {
if lhs.contents[i] != rhs.contents[i] { return false; }
unsafe {
for unsigned long index = 0l; index < lhs.size; index++ {
if lhs.contents[index] != rhs.contents[index] { return false; }
}
}
return true;
}
Expand Down
6 changes: 4 additions & 2 deletions std/data/stack.spice
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ public f<bool> operator==<T>(const Stack<T>& lhs, const Stack<T>& rhs) {
// Compare the sizes
if lhs.size != rhs.size { return false; }
// Compare the contents
for unsigned long i = 0l; index < lhs.size; i++ {
if lhs.contents[i] != rhs.contents[i] { return false; }
unsafe {
for unsigned long i = 0l; index < lhs.size; i++ {
if lhs.contents[i] != rhs.contents[i] { return false; }
}
}
return true;
}
Expand Down
6 changes: 4 additions & 2 deletions std/data/vector.spice
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,10 @@ public f<bool> operator==<T>(const Vector<T>& lhs, const Vector<T>& rhs) {
// Compare the sizes
if lhs.size != rhs.size { return false; }
// Compare the contents
for unsigned long index = 0l; index < lhs.size; index++ {
if lhs.contents[index] != rhs.contents[index] { return false; }
unsafe {
for unsigned long index = 0l; index < lhs.size; index++ {
if lhs.contents[index] != rhs.contents[index] { return false; }
}
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,11 @@ attributes #3 = { cold noreturn nounwind }
!42 = !{!"branch_weights", i32 2000, i32 1}
!43 = !DILocation(line: 13, column: 14, scope: !15)
!44 = !DILocalVariable(name: "it", scope: !15, file: !5, line: 13, type: !45)
!45 = !DICompositeType(tag: DW_TAG_structure_type, name: "VectorIterator", scope: !5, file: !5, line: 258, size: 192, align: 8, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !46, identifier: "struct.VectorIterator")
!45 = !DICompositeType(tag: DW_TAG_structure_type, name: "VectorIterator", scope: !5, file: !5, line: 260, size: 192, align: 8, flags: DIFlagTypePassByReference | DIFlagNonTrivial, elements: !46, identifier: "struct.VectorIterator")
!46 = !{!47, !49}
!47 = !DIDerivedType(tag: DW_TAG_member, name: "vector", scope: !45, file: !5, line: 259, baseType: !48, size: 64, offset: 64)
!47 = !DIDerivedType(tag: DW_TAG_member, name: "vector", scope: !45, file: !5, line: 261, baseType: !48, size: 64, offset: 64)
!48 = !DIDerivedType(tag: DW_TAG_reference_type, baseType: !28, size: 64)
!49 = !DIDerivedType(tag: DW_TAG_member, name: "cursor", scope: !45, file: !5, line: 260, baseType: !33, size: 64, offset: 128)
!49 = !DIDerivedType(tag: DW_TAG_member, name: "cursor", scope: !45, file: !5, line: 262, baseType: !33, size: 64, offset: 128)
!50 = !DILocation(line: 13, column: 5, scope: !15)
!51 = !DILocation(line: 14, column: 12, scope: !15)
!52 = !DILocation(line: 15, column: 12, scope: !15)
Expand Down

0 comments on commit 0275aeb

Please sign in to comment.