Skip to content

Commit

Permalink
Optimize mechanism for dynamically sized arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer committed Oct 7, 2022
1 parent b256739 commit aa1e0ac
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 43 deletions.
14 changes: 13 additions & 1 deletion media/test-project/os-test.spice
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
const int SIZE = 9;

p print(int[SIZE][SIZE] grid) {
foreach int[SIZE] row : grid {
int[SIZE][SIZE] grid1 = {
{ 3, 0, 6, 5, 0, 8, 4, 0, 0 },
{ 5, 2, 0, 0, 0, 0, 0, 0, 0 },
{ 0, 8, 7, 0, 0, 0, 0, 3, 1 },
{ 0, 0, 3, 0, 1, 0, 0, 8, 0 },
{ 9, 0, 0, 8, 6, 3, 0, 0, 5 },
{ 0, 5, 0, 0, 9, 0, 6, 0, 0 },
{ 1, 3, 0, 0, 0, 0, 2, 5, 0 },
{ 0, 0, 0, 0, 0, 0, 0, 7, 4 },
{ 0, 0, 5, 2, 0, 6, 3, 0, 0 }
};

foreach int[SIZE] row : grid1 {
foreach int cell : row {
printf("%d ", cell);
}
Expand Down
4 changes: 2 additions & 2 deletions src/analyzer/AnalyzerVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,8 +2175,8 @@ std::any AnalyzerVisitor::visitDataType(DataTypeNode *node) {
break;
}
case DataTypeNode::TYPE_ARRAY: {
if (typeModifier.hasSize) {
if (typeModifier.isSizeHardcoded) {
if (typeModifier.hasSizeAttached()) {
if (typeModifier.hasHardcodedSize()) {
if (typeModifier.hardcodedSize <= 1)
throw SemanticError(node->codeLoc, ARRAY_SIZE_INVALID, "The size of an array must be > 1 and explicitly stated");
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/analyzer/OpRuleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SymbolType OpRuleManager::getAssignResultType(const CodeLoc &codeLoc, const Symb
if (lhs.isOneOf({TY_PTR, TY_ARRAY, TY_STRUCT}) && lhs == rhs)
return rhs;
// Allow array to pointer
if (lhs.is(TY_PTR) && rhs.is(TY_ARRAY) && lhs.getContainedTy() == rhs.getContainedTy())
if (lhs.isPointer() && rhs.isArray() && lhs.getContainedTy() == rhs.getContainedTy())
return lhs;
// Allow char* = string
if (lhs.isPointerOf(TY_CHAR) && rhs.is(TY_STRING))
Expand Down
8 changes: 5 additions & 3 deletions src/ast/AstNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1222,9 +1222,11 @@ class DataTypeNode : public AstNode {
// Structs
struct TypeModifier {
TypeModifierType modifierType = TYPE_PTR;
bool hasSize = false;
bool isSizeHardcoded = false;
int hardcodedSize = 0;
int hardcodedSize = 0; // 0: no size attached, -1: dynamically sized, >=1: hardcoded size

[[nodiscard]] bool hasDynamicSize() const { return hardcodedSize == -1; }
[[nodiscard]] bool hasSizeAttached() const { return hardcodedSize != 0; }
[[nodiscard]] bool hasHardcodedSize() const { return hardcodedSize >= 1; }
};

// Constructors
Expand Down
28 changes: 13 additions & 15 deletions src/generator/GeneratorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ std::any GeneratorVisitor::visitForeachLoop(ForeachLoopNode *node) {
// Get array variable entry
llvm::Value *arrayValuePtr = resolveAddress(node->arrayAssign());
SymbolTableEntry *arrayVarEntry = currentScope->lookup(currentVarName);
bool dynamicallySized = arrayVarEntry && arrayVarEntry->type.is(TY_PTR) && arrayVarEntry->type.getDynamicArraySize() != nullptr;
bool dynamicallySized = arrayVarEntry && arrayVarEntry->type.isDynamicallySizedArray();

// Initialize loop variables
llvm::Value *idxVarPtr;
Expand Down Expand Up @@ -1336,7 +1336,7 @@ std::any GeneratorVisitor::visitDeclStmt(DeclStmtNode *node) {
if (node->assignExpr()) { // Declaration with assignment
memAddress = resolveAddress(node->assignExpr());
} else { // Declaration with default value
if (entry->type.is(TY_PTR) && entry->type.getDynamicArraySize() != nullptr) {
if (entry->type.isDynamicallySizedArray()) {
llvm::Type *itemType = entry->type.getContainedTy().toLLVMType(*context, nullptr);
dynamicArraySize = entry->type.getDynamicArraySize();
llvm::Value *value = allocateDynamicallySizedArray(itemType);
Expand Down Expand Up @@ -2718,21 +2718,19 @@ std::any GeneratorVisitor::visitArrayInitialization(ArrayInitializationNode *nod
size_t actualItemCount = node->itemLst() ? node->itemLst()->args().size() : 0;
size_t arraySize = lhsType != nullptr && lhsType->isArrayTy() ? lhsType->getArrayNumElements() : actualItemCount;

bool dynamicallySized = false;
bool outermostArray = true;
if (!arraySymbolType.is(TY_INVALID)) {
dynamicallySized = arraySymbolType.is(TY_PTR) && arraySymbolType.getDynamicArraySize() != nullptr;
outermostArray = false;
} else if (!lhsVarName.empty()) {
SymbolTableEntry *arrayEntry = currentScope->lookupStrict(lhsVarName);
SymbolTableEntry *arrayEntry = currentScope->lookup(lhsVarName);
assert(arrayEntry != nullptr);
arraySymbolType = arrayEntry->type;
dynamicallySized = arraySymbolType.isPointer() && arraySymbolType.getDynamicArraySize() != nullptr;
if (dynamicallySized)
dynamicArraySize = arraySymbolType.getDynamicArraySize();
} else {
arraySymbolType = node->getEvaluatedSymbolType();
}
bool dynamicallySized = arraySymbolType.isDynamicallySizedArray();
if (dynamicallySized)
dynamicArraySize = arraySymbolType.getDynamicArraySize();

SymbolType itemSymbolType = arraySymbolType.getContainedTy();
llvm::Type *arrayType = arraySymbolType.toLLVMType(*context, currentScope);
Expand Down Expand Up @@ -2931,17 +2929,17 @@ std::any GeneratorVisitor::visitDataType(DataTypeNode *node) {
break;
}
case DataTypeNode::TYPE_ARRAY: {
if (!typeModifier.hasSize) {
symbolType = symbolType.toPointer(node->codeLoc);
} else if (typeModifier.isSizeHardcoded) {
if (typeModifier.hasHardcodedSize()) { // hardcoded size
symbolType = symbolType.toArray(node->codeLoc, typeModifier.hardcodedSize);
} else {
} else if (!typeModifier.hasSizeAttached()) { // no size
symbolType = symbolType.toPointer(node->codeLoc);
} else { // dynamic size
AssignExprNode *indexExpr = arraySizeExpr[assignExprCounter++];
assert(indexExpr != nullptr);
auto sizeValuePtr = any_cast<llvm::Value *>(visit(indexExpr));
llvm::Type *sizeType = indexExpr->getEvaluatedSymbolType().toLLVMType(*context, currentScope);
dynamicArraySize = builder->CreateLoad(sizeType, sizeValuePtr);
symbolType = symbolType.toPointer(node->codeLoc, dynamicArraySize);
symbolType = symbolType.toArray(node->codeLoc, typeModifier.hardcodedSize, dynamicArraySize);
}
break;
}
Expand Down Expand Up @@ -3137,13 +3135,13 @@ llvm::Constant *GeneratorVisitor::getDefaultValueForSymbolType(const SymbolType
return currentConstValue = builder->getFalse();

// Pointer
if (symbolType.is(TY_PTR)) {
if (symbolType.isPointer() || symbolType.isDynamicallySizedArray()) {
llvm::Type *baseType = symbolType.getContainedTy().toLLVMType(*context, currentScope);
return currentConstValue = llvm::Constant::getNullValue(baseType->getPointerTo());
}

// Array
if (symbolType.is(TY_ARRAY)) {
if (symbolType.isArray()) {
size_t arraySize = symbolType.getArraySize();

llvm::Type *itemType = symbolType.getContainedTy().toLLVMType(*context, currentScope);
Expand Down
9 changes: 2 additions & 7 deletions src/parser/AstBuilderVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1389,26 +1389,21 @@ std::any AstBuilderVisitor::visitDataType(SpiceParser::DataTypeContext *ctx) {
if (rule = dynamic_cast<SpiceParser::BaseDataTypeContext *>(subTree); rule != nullptr) // BaseDataType
currentNode = dataTypeNode->createChild<BaseDataTypeNode>(CodeLoc(rule->start, fileName));
else if (auto t = dynamic_cast<antlr4::tree::TerminalNode *>(subTree); t->getSymbol()->getType() == SpiceParser::MUL)
dataTypeNode->tmQueue.push({DataTypeNode::TYPE_PTR, false, false, 0});
dataTypeNode->tmQueue.emplace(DataTypeNode::TYPE_PTR, 0);
else if (auto t = dynamic_cast<antlr4::tree::TerminalNode *>(subTree); t->getSymbol()->getType() == SpiceParser::LBRACKET) {
i++; // Consume LBRACKET
subTree = ctx->children[i];
bool hasSize = false;
bool isHardcoded = false;
int hardCodedSize = 0;
if (rule = dynamic_cast<SpiceParser::AssignExprContext *>(subTree); rule != nullptr) { // AssignExpr
hasSize = true;
hardCodedSize = -1;
currentNode = dataTypeNode->createChild<AssignExprNode>(CodeLoc(rule->start, fileName));
i++; // Consume INTEGER
} else if (auto t = dynamic_cast<antlr4::tree::TerminalNode *>(subTree);
t->getSymbol()->getType() == SpiceParser::INT_LIT) {
hasSize = true;
isHardcoded = true;
hardCodedSize = std::stoi(t->getSymbol()->getText());
i++; // Consume INTEGER
}
dataTypeNode->tmQueue.push({DataTypeNode::TYPE_ARRAY, hasSize, isHardcoded, hardCodedSize});
dataTypeNode->tmQueue.emplace(DataTypeNode::TYPE_ARRAY, hardCodedSize);
} else
assert(dynamic_cast<antlr4::tree::TerminalNode *>(subTree)); // Fail if we did not get a terminal

Expand Down
28 changes: 18 additions & 10 deletions src/symbol/SymbolType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ SymbolType::TypeChain SymbolType::getTypeChain() const { return typeChain; }
*
* @return Pointer type of the current type
*/
SymbolType SymbolType::toPointer(const CodeLoc &codeLoc, llvm::Value *dynamicSize) const {
SymbolType SymbolType::toPointer(const CodeLoc &codeLoc) const {
// Do not allow pointers of dyn
if (typeChain.top().superType == TY_DYN)
throw SemanticError(codeLoc, DYN_POINTERS_NOT_ALLOWED, "Just use the dyn type without '*' instead");

TypeChain newTypeChain = typeChain;
newTypeChain.push({TY_PTR, "", 0, {}, dynamicSize});
newTypeChain.push({TY_PTR, "", {}, {}, nullptr});
return SymbolType(newTypeChain);
}

Expand All @@ -38,13 +38,13 @@ SymbolType SymbolType::toPointer(const CodeLoc &codeLoc, llvm::Value *dynamicSiz
*
* @return Array type of the current type
*/
SymbolType SymbolType::toArray(const CodeLoc &codeLoc, int size) const {
SymbolType SymbolType::toArray(const CodeLoc &codeLoc, int size, llvm::Value *dynamicSize) const {
// Do not allow arrays of dyn
if (typeChain.top().superType == TY_DYN)
throw SemanticError(codeLoc, DYN_ARRAYS_NOT_ALLOWED, "Just use the dyn type without '[]' instead");

TypeChain newTypeChain = typeChain;
newTypeChain.push({TY_ARRAY, "", {.arraySize = size}, {}, nullptr});
newTypeChain.push({TY_ARRAY, "", {.arraySize = size}, {}, dynamicSize});
return SymbolType(newTypeChain);
}

Expand Down Expand Up @@ -148,7 +148,7 @@ llvm::Type *SymbolType::toLLVMType(llvm::LLVMContext &context, SymbolTable *acce
if (is(TY_ENUM))
return llvm::Type::getInt32Ty(context);

if (isPointer() || (isArray() && getArraySize() <= 0)) {
if (isPointer() || isDynamicallySizedArray()) {
llvm::PointerType *pointerType = getContainedTy().toLLVMType(context, accessScope)->getPointerTo();
return static_cast<llvm::Type *>(pointerType);
}
Expand Down Expand Up @@ -202,6 +202,13 @@ bool SymbolType::isArrayOf(SymbolSuperType elementSuperType) const { return isAr
*/
bool SymbolType::isArrayOf(const SymbolType &otherSymbolType) const { return isArray() && getContainedTy() == otherSymbolType; }

/**
* Check if the current type is a dynamically sized array.
*
* @return Dynamically sized array or not
*/
bool SymbolType::isDynamicallySizedArray() const { return isArray() && getArraySize() == -1; }

/**
* Check if the current type is of a certain super type
*
Expand Down Expand Up @@ -328,7 +335,8 @@ std::string SymbolType::getName(bool withSize, bool mangledName) const {
* Get the size of the current type
*
* Special cases:
* - 0: Array size was not defined
* - 0: Array size was not defined
* - -1: Array is dynamically sized
*
* @return Size
*/
Expand All @@ -345,9 +353,9 @@ int SymbolType::getArraySize() const {
* @return Dynamic array size
*/
llvm::Value *SymbolType::getDynamicArraySize() const {
if (typeChain.top().superType != TY_PTR) // GCOV_EXCL_LINE
if (typeChain.top().superType != TY_ARRAY) // GCOV_EXCL_LINE
throw std::runtime_error("Internal compiler error: Cannot get dynamic sized of non-array type"); // GCOV_EXCL_LINE
if (typeChain.top().data.arraySize > 0) // GCOV_EXCL_LINE
if (typeChain.top().data.arraySize != -1) // GCOV_EXCL_LINE
throw std::runtime_error("Cannot retrieve dynamic size of non-dynamically-sized array"); // GCOV_EXCL_LINE

return typeChain.top().dynamicArraySize;
Expand All @@ -360,8 +368,8 @@ bool operator!=(const SymbolType &lhs, const SymbolType &rhs) { return lhs.typeC
/**
* Compares the type chains of two symbol types without taking array sizes into account
*
* @param lhs Lhs symbol type
* @param rhs Rhs symbol type
* @param lhs Left hand symbol type
* @param rhs Right hand symbol type
* @return Equal or not
*/
bool equalsIgnoreArraySizes(SymbolType lhs, SymbolType rhs) {
Expand Down
9 changes: 5 additions & 4 deletions src/symbol/SymbolType.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ enum SymbolSuperType {
};

union TypeChainElementData {
bool isStringStruct; // TY_STRING
int arraySize; // TY_ARRAY
bool isStringStruct = false; // TY_STRING
int arraySize; // TY_ARRAY
};

class SymbolType {
Expand Down Expand Up @@ -89,8 +89,8 @@ class SymbolType {

// Public methods
[[nodiscard]] TypeChain getTypeChain() const;
SymbolType toPointer(const CodeLoc &codeLoc, llvm::Value *dynamicSize = nullptr) const;
[[nodiscard]] SymbolType toArray(const CodeLoc &codeLoc, int size = 0) const;
[[nodiscard]] SymbolType toPointer(const CodeLoc &codeLoc) const;
[[nodiscard]] SymbolType toArray(const CodeLoc &codeLoc, int size = 0, llvm::Value *dynamicSize = nullptr) const;
[[nodiscard]] SymbolType getContainedTy() const;
[[nodiscard]] SymbolType replaceBaseSubType(const std::string &newSubType) const;
[[nodiscard]] SymbolType replaceBaseType(const SymbolType &newBaseType) const;
Expand All @@ -100,6 +100,7 @@ class SymbolType {
[[nodiscard]] bool isArray() const;
[[nodiscard]] bool isArrayOf(SymbolSuperType superType) const;
[[nodiscard]] bool isArrayOf(const SymbolType &symbolType) const;
[[nodiscard]] bool isDynamicallySizedArray() const;
[[nodiscard]] bool is(SymbolSuperType superType) const;
[[nodiscard]] bool is(SymbolSuperType superType, const std::string &subType) const;
[[nodiscard]] bool isPrimitive() const;
Expand Down

0 comments on commit aa1e0ac

Please sign in to comment.