Skip to content

Commit

Permalink
fix attributore
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 20, 2023
1 parent f95269c commit 8f1cacf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
34 changes: 22 additions & 12 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,37 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
}

os << " if (byRef) {\n";
for (size_t argPos = 0; argPos < argTypeMap.size(); argPos++) {
int numCharArgs = 0;
size_t numArgs = argTypeMap.size();
for (size_t argPos = 0; argPos < numArgs; argPos++) {
const auto typeOfArg = argTypeMap.lookup(argPos);
if (is_char_arg(typeOfArg))
numCharArgs++;
}

for (size_t argPos = 0; argPos < numArgs; argPos++) {
const auto typeOfArg = argTypeMap.lookup(argPos);
size_t i = (lv23 ? argPos - 1 : argPos);
if (typeOfArg == ArgType::len || typeOfArg == ArgType::vincInc ||
typeOfArg == ArgType::fp || typeOfArg == ArgType::trans ||
typeOfArg == ArgType::mldLD || typeOfArg == ArgType::uplo ||
typeOfArg == ArgType::diag || typeOfArg == ArgType::side) {
os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::ReadNone);\n"
<< " F->addParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::ReadOnly);\n"
<< " F->addParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::NoCapture);\n";

if (is_char_arg(typeOfArg) || typeOfArg == ArgType::len ||
typeOfArg == ArgType::vincInc || typeOfArg == ArgType::fp ||
typeOfArg == ArgType::mldLD) {
if (is_char_arg(typeOfArg) && numArgs - argPos <= numCharArgs) {
os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::ReadNone);\n"
<< " F->addParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::ReadOnly);\n"
<< " F->addParamAttr(" << i << (lv23 ? " + offset" : "")
<< ", llvm::Attribute::NoCapture);\n";
}
}
}

os << " }\n"
<< " // Julia declares double* pointers as Int64,\n"
<< " // so LLVM won't let us add these Attributes.\n"
<< " if (!julia_decl) {\n";
for (size_t argPos = 0; argPos < argTypeMap.size(); argPos++) {
for (size_t argPos = 0; argPos < numArgs; argPos++) {
auto typeOfArg = argTypeMap.lookup(argPos);
size_t i = (lv23 ? argPos - 1 : argPos);
if (typeOfArg == ArgType::vincData || typeOfArg == ArgType::mldData) {
Expand Down
7 changes: 7 additions & 0 deletions enzyme/tools/enzyme-tblgen/datastructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ raw_ostream &operator<<(raw_fd_ostream &os, ArgType arg) {

using namespace llvm;

bool is_char_arg(ArgType ty) {
if (ty == ArgType::side || ty == ArgType::diag || ty == ArgType::trans ||
ty == ArgType::uplo)
return true;
return false;
}

const char *TyToString(ArgType ty) {
switch (ty) {
case ArgType::fp:
Expand Down
2 changes: 2 additions & 0 deletions enzyme/tools/enzyme-tblgen/datastructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ enum class ArgType {
side
};

bool is_char_arg(ArgType ty);

namespace llvm {
raw_ostream &operator<<(raw_ostream &os, ArgType arg);
raw_ostream &operator<<(raw_fd_ostream &os, ArgType arg);
Expand Down

0 comments on commit 8f1cacf

Please sign in to comment.