Skip to content

Commit

Permalink
Add support for more string operators (#200)
Browse files Browse the repository at this point in the history
* Support += operator for strings

* Support * operator for strings

* Support * operator for chars

* Update tests
  • Loading branch information
marcauberer authored Sep 4, 2022
1 parent c04d310 commit cfc1f08
Show file tree
Hide file tree
Showing 14 changed files with 473 additions and 204 deletions.
13 changes: 3 additions & 10 deletions media/test-project/os-test.spice
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,7 @@ f<int> main() {
import "std/runtime/string_rt" as _rt_str;

f<int> main() {
// Plus
printf("String: %s\n", "Hello " + "World!");
string s = "Hello " + "World!";
printf("String: %s\n", s);
// Equals
printf("String: %d\n", "Hello World!" == "Hello Programmers!");
printf("String: %d\n", "Hello" == "Hell2");
// Not equals
printf("String: %d\n", "Hello World!" != "Hello Programmers!");
printf("String: %d\n", "Hello" != "Hell2");
string s = "Test";
s *= 3;
printf("%s", s);
}
43 changes: 42 additions & 1 deletion src/analyzer/OpRuleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ SymbolType OpRuleManager::getPlusEqualResultType(const CodeLoc &codeLoc, const S
throw printErrorMessageUnsafe(codeLoc, "+=", lhs, rhs);
}

// Allow string += char
if (lhs.is(TY_STRING) && rhs.is(TY_CHAR))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});
// Allow string += string
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

return validateBinaryOperation(codeLoc, PLUS_EQUAL_OP_RULES, "+=", lhs, rhs);
}

Expand All @@ -45,6 +52,16 @@ SymbolType OpRuleManager::getMinusEqualResultType(const CodeLoc &codeLoc, const
}

SymbolType OpRuleManager::getMulEqualResultType(const CodeLoc &codeLoc, const SymbolType &lhs, const SymbolType &rhs) {
// Allow string *= int
if (lhs.is(TY_STRING) && rhs.is(TY_INT))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});
// Allow string *= long
if (lhs.is(TY_STRING) && rhs.is(TY_LONG))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});
// Allow string *= short
if (lhs.is(TY_STRING) && rhs.is(TY_SHORT))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

return validateBinaryOperation(codeLoc, MUL_EQUAL_OP_RULES, "*=", lhs, rhs);
}

Expand Down Expand Up @@ -160,7 +177,7 @@ SymbolType OpRuleManager::getPlusResultType(const CodeLoc &codeLoc, const Symbol

// Allow string + string
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return SymbolType(TY_STRING, "", { .isStringStruct = true }, {});
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

return validateBinaryOperation(codeLoc, PLUS_OP_RULES, "+", lhs, rhs);
}
Expand All @@ -185,6 +202,30 @@ SymbolType OpRuleManager::getMinusResultType(const CodeLoc &codeLoc, const Symbo
}

SymbolType OpRuleManager::getMulResultType(const CodeLoc &codeLoc, const SymbolType &lhs, const SymbolType &rhs) {
// Allow string * int and int * string
if ((lhs.is(TY_STRING) && rhs.is(TY_INT)) || (lhs.is(TY_INT) && rhs.is(TY_STRING)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

// Allow string * short and short * string
if ((lhs.is(TY_STRING) && rhs.is(TY_SHORT)) || (lhs.is(TY_SHORT) && rhs.is(TY_STRING)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

// Allow string * long and long * string
if ((lhs.is(TY_STRING) && rhs.is(TY_LONG)) || (lhs.is(TY_LONG) && rhs.is(TY_STRING)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

// Allow char * int and int * char
if ((lhs.is(TY_CHAR) && rhs.is(TY_INT)) || (lhs.is(TY_INT) && rhs.is(TY_CHAR)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

// Allow char * short and short * char
if ((lhs.is(TY_CHAR) && rhs.is(TY_SHORT)) || (lhs.is(TY_SHORT) && rhs.is(TY_CHAR)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

// Allow char * long and long * char
if ((lhs.is(TY_CHAR) && rhs.is(TY_LONG)) || (lhs.is(TY_LONG) && rhs.is(TY_CHAR)))
return SymbolType(TY_STRING, "", {.isStringStruct = true}, {});

return validateBinaryOperation(codeLoc, MUL_OP_RULES, "*", lhs, rhs);
}

Expand Down
14 changes: 0 additions & 14 deletions src/analyzer/OpRuleManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ const std::vector<BinaryOpRule> PLUS_EQUAL_OP_RULES = {
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long += long -> long
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte += byte -> byte
BinaryOpRule(TY_CHAR, TY_CHAR, TY_CHAR, false), // char += char -> char
BinaryOpRule(TY_STRING, TY_CHAR, TY_STRING, false), // string += char -> string
BinaryOpRule(TY_STRING, TY_STRING, TY_STRING, false), // string += string -> string
};

// Minus equal op rules
Expand Down Expand Up @@ -444,27 +442,15 @@ const std::vector<BinaryOpRule> MUL_OP_RULES = {
BinaryOpRule(TY_INT, TY_INT, TY_INT, false), // int * int -> int
BinaryOpRule(TY_INT, TY_SHORT, TY_INT, false), // int * short -> int
BinaryOpRule(TY_INT, TY_LONG, TY_LONG, false), // int * long -> long
BinaryOpRule(TY_INT, TY_CHAR, TY_STRING, false), // int * char -> string
BinaryOpRule(TY_INT, TY_STRING, TY_STRING, false), // int * string -> string
BinaryOpRule(TY_SHORT, TY_DOUBLE, TY_DOUBLE, false), // short * double -> double
BinaryOpRule(TY_SHORT, TY_INT, TY_INT, false), // short * int -> int
BinaryOpRule(TY_SHORT, TY_SHORT, TY_SHORT, false), // short * short -> short
BinaryOpRule(TY_SHORT, TY_LONG, TY_LONG, false), // short * long -> long
BinaryOpRule(TY_SHORT, TY_CHAR, TY_STRING, false), // short * char -> string
BinaryOpRule(TY_SHORT, TY_STRING, TY_STRING, false), // short * string -> string
BinaryOpRule(TY_LONG, TY_DOUBLE, TY_DOUBLE, false), // long * double -> double
BinaryOpRule(TY_LONG, TY_INT, TY_LONG, false), // long * int -> long
BinaryOpRule(TY_LONG, TY_SHORT, TY_LONG, false), // long * short -> long
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long * long -> long
BinaryOpRule(TY_LONG, TY_CHAR, TY_STRING, false), // long * char -> string
BinaryOpRule(TY_LONG, TY_STRING, TY_STRING, false), // long * string -> string
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte * byte -> byte
BinaryOpRule(TY_CHAR, TY_INT, TY_STRING, false), // char * int -> string
BinaryOpRule(TY_CHAR, TY_SHORT, TY_STRING, false), // char * short -> string
BinaryOpRule(TY_CHAR, TY_LONG, TY_STRING, false), // char * long -> string
BinaryOpRule(TY_STRING, TY_INT, TY_STRING, false), // string * int -> string
BinaryOpRule(TY_STRING, TY_SHORT, TY_STRING, false), // string * short -> string
BinaryOpRule(TY_STRING, TY_LONG, TY_STRING, false), // string * long -> string
};

// Div op rules
Expand Down
10 changes: 5 additions & 5 deletions src/generator/GeneratorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ std::any GeneratorVisitor::visitMainFctDef(MainFctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreIntrinsic(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -475,7 +475,7 @@ std::any GeneratorVisitor::visitFctDef(FctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreIntrinsic(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -662,7 +662,7 @@ std::any GeneratorVisitor::visitProcDef(ProcDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreIntrinsic(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -2985,7 +2985,7 @@ llvm::Value *GeneratorVisitor::insertAlloca(llvm::Type *llvmType, const std::str

llvm::Value *GeneratorVisitor::allocateDynamicallySizedArray(llvm::Type *itemType) {
// Call llvm.stacksave intrinsic
llvm::Function *stackSaveFct = stdFunctionManager->getStackSaveFct();
llvm::Function *stackSaveFct = stdFunctionManager->getStackSaveIntrinsic();
if (stackState == nullptr)
stackState = builder->CreateCall(stackSaveFct);
// Allocate array
Expand Down Expand Up @@ -3054,7 +3054,7 @@ bool GeneratorVisitor::insertDestructorCall(const CodeLoc &codeLoc, SymbolTableE

llvm::Value *GeneratorVisitor::materializeString(llvm::Value *stringStructPtr) {
assert(stringStructPtr->getType()->isPointerTy());
llvm::Value *rawStringValue = builder->CreateCall(stdFunctionManager->getStringRawFct(), stringStructPtr);
llvm::Value *rawStringValue = builder->CreateCall(stdFunctionManager->getStringGetRawFct(), stringStructPtr);
return rawStringValue;
}

Expand Down
Loading

0 comments on commit cfc1f08

Please sign in to comment.