Skip to content

Commit

Permalink
Use string objects for extending strings (#199)
Browse files Browse the repository at this point in the history
* Refactoring

* Add ctor for plus operator

* Start implementing the plus operator for strings

* Fix tests

* Materialize string object

* Restructure code & add equality operators for strings

* Fix bug and add test
  • Loading branch information
marcauberer authored Sep 3, 2022
1 parent 41e267a commit c04d310
Show file tree
Hide file tree
Showing 26 changed files with 500 additions and 135 deletions.
14 changes: 10 additions & 4 deletions media/test-project/os-test.spice
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@ f<int> main() {
import "std/runtime/string_rt" as _rt_str;

f<int> main() {
_rt_str::String s1 = _rt_str::String('H');
s1.append("ello");

printf("Equals: %d", s1.opEquals("Hell2"));
// Plus
printf("String: %s\n", "Hello " + "World!");
string s = "Hello " + "World!";
printf("String: %s\n", s);
// Equals
printf("String: %d\n", "Hello World!" == "Hello Programmers!");
printf("String: %d\n", "Hello" == "Hell2");
// Not equals
printf("String: %d\n", "Hello World!" != "Hello Programmers!");
printf("String: %d\n", "Hello" != "Hell2");
}
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ set(SOURCES
analyzer/OpRuleManager.h
generator/GeneratorVisitor.cpp
generator/GeneratorVisitor.h
generator/StdFunctionManager.cpp
generator/StdFunctionManager.h
generator/OpRuleConversionsManager.cpp
generator/OpRuleConversionsManager.h
linker/LinkerInterface.cpp
Expand Down
6 changes: 3 additions & 3 deletions src/analyzer/AnalyzerVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ std::any AnalyzerVisitor::visitStructDef(StructDefNode *node) {
}

// Build struct specifiers
SymbolType symbolType = SymbolType(TY_STRUCT, node->structName, templateTypes);
SymbolType symbolType = SymbolType(TY_STRUCT, node->structName, {}, templateTypes);
auto structSymbolSpecifiers = SymbolSpecifiers(symbolType);
if (SpecifierLstNode *specifierLst = node->specifierLst(); specifierLst) {
for (const auto &specifier : specifierLst->specifiers()) {
Expand Down Expand Up @@ -1324,7 +1324,7 @@ std::any AnalyzerVisitor::visitAssignExpr(AssignExprNode *node) {
"The variable '" + variableName + "' was referenced before defined");

// Perform type inference
if (lhsTy.is(TY_DYN))
if (lhsTy.is(TY_DYN) || (lhsTy.is(TY_STRING) && rhsTy.is(TY_STRING)))
currentEntry->updateType(rhsTy, false);

// Update state in symbol table
Expand Down Expand Up @@ -2296,7 +2296,7 @@ std::any AnalyzerVisitor::visitCustomDataType(CustomDataTypeNode *node) {
throw SemanticError(node->codeLoc, UNKNOWN_DATATYPE, "Unknown datatype '" + identifier + "'");
structSymbol->setUsed();

return node->setEvaluatedSymbolType(SymbolType(TY_STRUCT, identifier, concreteTemplateTypes));
return node->setEvaluatedSymbolType(SymbolType(TY_STRUCT, identifier, {.arraySize = 0}, concreteTemplateTypes));
}

void AnalyzerVisitor::insertDestructorCall(const CodeLoc &codeLoc, SymbolTableEntry *varEntry) {
Expand Down
7 changes: 7 additions & 0 deletions src/analyzer/OpRuleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ SymbolType OpRuleManager::getAssignResultType(const CodeLoc &codeLoc, const Symb
// Allow char* = string
if (lhs.isPointerOf(TY_CHAR) && rhs.is(TY_STRING))
return lhs;
// Allow string = string_object
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return rhs;
// Check primitive type combinations
return validateBinaryOperation(codeLoc, ASSIGN_OP_RULES, "=", lhs, rhs);
}
Expand Down Expand Up @@ -155,6 +158,10 @@ SymbolType OpRuleManager::getPlusResultType(const CodeLoc &codeLoc, const Symbol
throw printErrorMessageUnsafe(codeLoc, "+", lhs, rhs);
}

// Allow string + string
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return SymbolType(TY_STRING, "", { .isStringStruct = true }, {});

return validateBinaryOperation(codeLoc, PLUS_OP_RULES, "+", lhs, rhs);
}

Expand Down
2 changes: 0 additions & 2 deletions src/analyzer/OpRuleManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ const std::vector<BinaryOpRule> ASSIGN_OP_RULES = {
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long = long -> long
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte = byte -> byte
BinaryOpRule(TY_CHAR, TY_CHAR, TY_CHAR, false), // char = char -> char
BinaryOpRule(TY_STRING, TY_STRING, TY_STRING, false), // string = string -> string
BinaryOpRule(TY_BOOL, TY_BOOL, TY_BOOL, false), // bool = bool -> bool
};

Expand Down Expand Up @@ -412,7 +411,6 @@ const std::vector<BinaryOpRule> PLUS_OP_RULES = {
BinaryOpRule(TY_LONG, TY_SHORT, TY_LONG, false), // long + short -> long
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long + long -> long
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte + byte -> byte
BinaryOpRule(TY_STRING, TY_STRING, TY_STRING, false), // string + string -> string
};

// Minus op rules
Expand Down
73 changes: 22 additions & 51 deletions src/generator/GeneratorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <symbol/Function.h>
#include <symbol/Struct.h>
#include <symbol/SymbolTable.h>
#include <util/CommonUtil.h>
#include <util/FileUtil.h>
#include <util/ThreadFactory.h>

Expand Down Expand Up @@ -42,7 +41,10 @@ GeneratorVisitor::GeneratorVisitor(const std::shared_ptr<llvm::LLVMContext> &con

// Create LLVM base components
module = std::make_unique<llvm::Module>(FileUtil::getFileName(sourceFile.filePath), *context);
conversionsManager = std::make_unique<OpRuleConversionsManager>(context, builder);

// Initialize generator helper objects
stdFunctionManager = std::make_unique<StdFunctionManager>(this);
conversionsManager = std::make_unique<OpRuleConversionsManager>(this);

// Initialize LLVM
llvm::InitializeAllTargetInfos();
Expand Down Expand Up @@ -286,7 +288,7 @@ std::any GeneratorVisitor::visitMainFctDef(MainFctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -473,7 +475,7 @@ std::any GeneratorVisitor::visitFctDef(FctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -660,7 +662,7 @@ std::any GeneratorVisitor::visitProcDef(ProcDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -1300,12 +1302,12 @@ std::any GeneratorVisitor::visitAssertStmt(AssertStmtNode *node) {
parentFct->getBasicBlockList().push_back(bThen);
moveInsertPointToBlock(bThen);
// Generate IR for assertion error
llvm::Function *printfFct = retrievePrintfFct();
llvm::Function *printfFct = stdFunctionManager->getPrintfFct();
std::string errorMsg = "Assertion failed: Condition '" + node->expressionString + "' evaluated to false.";
llvm::Value *templateString = builder->CreateGlobalStringPtr(errorMsg);
builder->CreateCall(printfFct, templateString);
// Generate call to exit
llvm::Function *exitFct = retrieveExitFct();
llvm::Function *exitFct = stdFunctionManager->getExitFct();
builder->CreateCall(exitFct, builder->getInt32(1));
// Create unreachable instruction
builder->CreateUnreachable();
Expand Down Expand Up @@ -1456,7 +1458,7 @@ std::any GeneratorVisitor::visitContinueStmt(ContinueStmtNode *node) {

std::any GeneratorVisitor::visitPrintfCall(PrintfCallNode *node) {
// Declare if not declared already
llvm::Function *printfFct = retrievePrintfFct();
llvm::Function *printfFct = stdFunctionManager->getPrintfFct();

std::vector<llvm::Value *> printfArgs;
printfArgs.push_back(builder->CreateGlobalStringPtr(node->templatedString));
Expand All @@ -1471,6 +1473,9 @@ std::any GeneratorVisitor::visitPrintfCall(PrintfCallNode *node) {
if (argSymbolType.isArray()) { // Convert array type to pointer type
llvm::Value *indices[2] = {builder->getInt32(0), builder->getInt32(0)};
argVal = builder->CreateInBoundsGEP(targetType, argValPtr, indices);
} else if (argSymbolType.isStringStruct()) {
argValPtr = materializeString(argValPtr);
argVal = builder->CreateLoad(targetType, argValPtr);
} else {
argVal = builder->CreateLoad(targetType, argValPtr);
}
Expand Down Expand Up @@ -1982,6 +1987,10 @@ std::any GeneratorVisitor::visitAdditiveExpr(AdditiveExprNode *node) {

switch (opQueue.front().first) {
case AdditiveExprNode::OP_PLUS:
/*if (lhsSymbolType.isStringStruct())
lhs = materializeString(lhsPtr);
if (rhsSymbolType.isStringStruct())
rhs = materializeString(rhsPtr);*/
lhs = conversionsManager->getPlusInst(lhs, rhs, lhsSymbolType, rhsSymbolType, currentScope, node->codeLoc);
break;
case AdditiveExprNode::OP_MINUS:
Expand Down Expand Up @@ -2976,7 +2985,7 @@ llvm::Value *GeneratorVisitor::insertAlloca(llvm::Type *llvmType, const std::str

llvm::Value *GeneratorVisitor::allocateDynamicallySizedArray(llvm::Type *itemType) {
// Call llvm.stacksave intrinsic
llvm::Function *stackSaveFct = retrieveStackSaveFct();
llvm::Function *stackSaveFct = stdFunctionManager->getStackSaveFct();
if (stackState == nullptr)
stackState = builder->CreateCall(stackSaveFct);
// Allocate array
Expand Down Expand Up @@ -3043,48 +3052,10 @@ bool GeneratorVisitor::insertDestructorCall(const CodeLoc &codeLoc, SymbolTableE
return true;
}

llvm::Function *GeneratorVisitor::retrievePrintfFct() {
std::string printfFctName = "printf";
llvm::Function *printfFct = module->getFunction(printfFctName);
if (printfFct)
return printfFct;
// Not found -> declare it for linkage
llvm::FunctionType *printfFctTy = llvm::FunctionType::get(builder->getInt32Ty(), builder->getInt8PtrTy(), true);
module->getOrInsertFunction(printfFctName, printfFctTy);
return module->getFunction(printfFctName);
}

llvm::Function *GeneratorVisitor::retrieveExitFct() {
std::string exitFctName = "exit";
llvm::Function *exitFct = module->getFunction(exitFctName);
if (exitFct)
return exitFct;
// Not found -> declare it for linkage
llvm::FunctionType *exitFctTy = llvm::FunctionType::get(builder->getVoidTy(), builder->getInt32Ty(), false);
module->getOrInsertFunction(exitFctName, exitFctTy);
return module->getFunction(exitFctName);
}

llvm::Function *GeneratorVisitor::retrieveStackSaveFct() {
std::string stackSaveFctName = "llvm.stacksave";
llvm::Function *stackSaveFct = module->getFunction(stackSaveFctName);
if (stackSaveFct)
return stackSaveFct;
// Not found -> declare it for linkage
llvm::FunctionType *stackSaveFctTy = llvm::FunctionType::get(builder->getInt8PtrTy(), {}, false);
module->getOrInsertFunction(stackSaveFctName, stackSaveFctTy);
return module->getFunction(stackSaveFctName);
}

llvm::Function *GeneratorVisitor::retrieveStackRestoreFct() {
std::string stackRestoreFctName = "llvm.stackrestore";
llvm::Function *stackRestoreFct = module->getFunction(stackRestoreFctName);
if (stackRestoreFct)
return stackRestoreFct;
// Not found -> declare it for linkage
llvm::FunctionType *stackRestoreFctTy = llvm::FunctionType::get(builder->getVoidTy(), builder->getInt8PtrTy(), false);
module->getOrInsertFunction(stackRestoreFctName, stackRestoreFctTy);
return module->getFunction(stackRestoreFctName);
llvm::Value *GeneratorVisitor::materializeString(llvm::Value *stringStructPtr) {
assert(stringStructPtr->getType()->isPointerTy());
llvm::Value *rawStringValue = builder->CreateCall(stdFunctionManager->getStringRawFct(), stringStructPtr);
return rawStringValue;
}

llvm::Constant *GeneratorVisitor::getDefaultValueForSymbolType(const SymbolType &symbolType) {
Expand Down
13 changes: 8 additions & 5 deletions src/generator/GeneratorVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <ast/AstNodes.h>
#include <ast/AstVisitor.h>

#include <generator/StdFunctionManager.h>
#include <generator/OpRuleConversionsManager.h>
#include <symbol/ScopePath.h>
#include <symbol/SymbolType.h>
Expand All @@ -22,7 +23,6 @@ class ThreadFactory;
class LinkerInterface;
struct CliOptions;
class LinkerInterface;
class OpRuleConversionsManager;
class SymbolTable;
class SymbolTableEntry;
class Function;
Expand All @@ -43,6 +43,10 @@ class GeneratorVisitor : public AstVisitor {
ThreadFactory &threadFactory, const LinkerInterface &linker, const CliOptions &cliOptions,
const SourceFile &sourceFile, const std::string &objectFile);

// Friend classes
friend class StdFunctionManager;
friend class OpRuleConversionsManager;

// Public methods
void optimize();
void emit();
Expand Down Expand Up @@ -101,6 +105,7 @@ class GeneratorVisitor : public AstVisitor {

private:
// Members
std::unique_ptr<StdFunctionManager> stdFunctionManager;
std::unique_ptr<OpRuleConversionsManager> conversionsManager;
const std::string &objectFile;
llvm::TargetMachine *targetMachine{};
Expand Down Expand Up @@ -163,10 +168,8 @@ class GeneratorVisitor : public AstVisitor {
llvm::Value *allocateDynamicallySizedArray(llvm::Type *itemType);
llvm::Value *createGlobalArray(llvm::Constant *constArray);
bool insertDestructorCall(const CodeLoc &codeLoc, SymbolTableEntry *varEntry);
llvm::Function *retrievePrintfFct();
llvm::Function *retrieveExitFct();
llvm::Function *retrieveStackSaveFct();
llvm::Function *retrieveStackRestoreFct();

llvm::Value *materializeString(llvm::Value *stringStruct);
llvm::Constant *getDefaultValueForSymbolType(const SymbolType &symbolType);
SymbolTableEntry *initExtGlobal(const std::string &globalName, const std::string &fqGlobalName);
llvm::Value *doImplicitCast(llvm::Value *src, llvm::Type *dstTy, SymbolType srcType);
Expand Down
37 changes: 28 additions & 9 deletions src/generator/OpRuleConversionsManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
#include <stdexcept>

#include <exception/IRError.h>
#include <generator/GeneratorVisitor.h>
#include <generator/StdFunctionManager.h>
#include <util/CodeLoc.h>

OpRuleConversionsManager::OpRuleConversionsManager(GeneratorVisitor *generator) : generator(generator) {
builder = generator->builder.get();
context = generator->context.get();
stdFunctionManager = generator->stdFunctionManager.get();
}

llvm::Value *OpRuleConversionsManager::getPlusEqualInst(llvm::Value *lhs, llvm::Value *rhs, const SymbolType &lhsSTy,
const SymbolType &rhsSTy, SymbolTable *accessScope,
const CodeLoc &codeLoc) {
Expand Down Expand Up @@ -465,9 +473,12 @@ llvm::Value *OpRuleConversionsManager::getEqualInst(llvm::Value *lhs, llvm::Valu
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateICmpEQ(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to opEquals in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '==' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the function isRawEqual(string, string) of the string std
llvm::Function *opFct = stdFunctionManager->getStringLitEqualsOpStringLitFct();
llvm::Value *result = builder->CreateCall(opFct, {lhs, rhs});
return result;
}
case COMB(TY_BOOL, TY_BOOL):
return builder->CreateICmpEQ(lhs, rhs);
}
Expand Down Expand Up @@ -570,9 +581,13 @@ llvm::Value *OpRuleConversionsManager::getNotEqualInst(llvm::Value *lhs, llvm::V
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateICmpNE(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to opNotEquals in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '!=' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the function isRawEqual(string, string) of the string std
llvm::Function *opFct = stdFunctionManager->getStringLitEqualsOpStringLitFct();
llvm::Value *result = builder->CreateCall(opFct, {lhs, rhs});
// Negate the result
return builder->CreateNot(result);
}
case COMB(TY_BOOL, TY_BOOL):
return builder->CreateICmpNE(lhs, rhs);
}
Expand Down Expand Up @@ -943,9 +958,13 @@ llvm::Value *OpRuleConversionsManager::getPlusInst(llvm::Value *lhs, llvm::Value
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateAdd(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to append in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '+' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the constructor ctor(string, string) of the String struct
llvm::Function *opFct = stdFunctionManager->getStringLitPlusOpStringLitFct();
llvm::Value *thisPtr = generator->insertAlloca(stdFunctionManager->getStringStructType());
builder->CreateCall(opFct, {thisPtr, lhs, rhs});
return thisPtr;
}
case COMB(TY_PTR, TY_INT): // fallthrough
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG):
Expand Down
13 changes: 8 additions & 5 deletions src/generator/OpRuleConversionsManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
#include <llvm/IR/Value.h>

// Forward declarations
class ErrorFactory;
class GeneratorVisitor;
class StdFunctionManager;
struct CodeLoc;

#define COMB(en1, en2) ((en1) | ((en2) << 16))

class OpRuleConversionsManager {
public:
explicit OpRuleConversionsManager(const std::shared_ptr<llvm::LLVMContext> &context, std::shared_ptr<llvm::IRBuilder<>> builder)
: context(context), builder(std::move(builder)) {}
// Constructors
explicit OpRuleConversionsManager(GeneratorVisitor *generator);

// Public methods
llvm::Value *getPlusEqualInst(llvm::Value *lhs, llvm::Value *rhs, const SymbolType &lhsTy, const SymbolType &rhsTy,
Expand Down Expand Up @@ -64,6 +65,8 @@ class OpRuleConversionsManager {

private:
// Members
const std::shared_ptr<llvm::LLVMContext> &context;
std::shared_ptr<llvm::IRBuilder<>> builder;
GeneratorVisitor *generator;
llvm::LLVMContext *context;
llvm::IRBuilder<> *builder;
const StdFunctionManager *stdFunctionManager;
};
Loading

0 comments on commit c04d310

Please sign in to comment.