Skip to content

Commit

Permalink
Implement concat (#1450)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 20, 2023
1 parent b14e0ab commit 21dcb51
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 130 deletions.
13 changes: 13 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,19 @@ llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
llvm::Value *arg_ld, llvm::Value *dim_1,
llvm::Value *dim_2, bool cacheMat,
bool byRef);

template <typename... T> static inline void nothing(T...){};
template <typename... T>
static inline llvm::SmallVector<llvm::Value *, 1> concat_values(T... t) {
llvm::SmallVector<llvm::Value *, 1> res;
auto append = [&](llvm::ArrayRef<llvm::Value *> V) {
res.append(V.begin(), V.end());
return 0;
};
nothing(append(t)...);
return res;
}

llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef);
llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef);
llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans,
Expand Down
140 changes: 10 additions & 130 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,136 +785,6 @@ std::string get_blas_ret_ty(StringRef dfnc_name) {
return "Builder2.getVoidTy()";
}

/*
void emit_deriv_blas_call(DagInit *ruleDag,
const StringMap<TGPattern> &patternMap,
StringSet<> &handled, raw_ostream &os) {
const auto Def = cast<DefInit>(ruleDag->getOperator())->getDef();
const auto dfnc_name = Def->getValueAsString("s");
if (patternMap.find(dfnc_name) == patternMap.end()) {
PrintFatalError("calling unknown Blas function");
}
TGPattern calledPattern = patternMap.find(dfnc_name)->getValue();
bool derivlv23 = calledPattern.isBLASLevel2or3();
DenseSet<size_t> mutableArgs = calledPattern.getMutableArgs();
if (handled.find(dfnc_name) != handled.end())
return;
else
handled.insert(dfnc_name);
auto retTy = get_blas_ret_ty(dfnc_name);
// insert arg types based on .td file
std::string typeString = "";
bool first = true;
for (size_t i = 0; i < ruleDag->getNumArgs(); i++) {
Init *subArg = ruleDag->getArg(i);
if (DefInit *def = dyn_cast<DefInit>(subArg)) {
const auto Def = def->getDef();
std::string typeToAdd = "";
if (Def->isSubClassOf("DiffeRetIndex")) {
typeToAdd = "byRef ? PointerType::getUnqual(call.getType()) : "
"call.getType()\n";
} else if (Def->isSubClassOf("adj")) {
auto argStr = Def->getValueAsString("name");
// primary and adj have the same type
typeToAdd = (Twine("type_") + argStr).str();
} else if (Def->isSubClassOf("input")) {
auto argStr = Def->getValueAsString("name");
// primary and adj have the same type
typeToAdd = (Twine("type_") + argStr).str();
} else if (Def->isSubClassOf("Constant")) {
typeToAdd = "blasFPType";
} else if (Def->isSubClassOf("Char")) {
typeToAdd = "byRef ? (Type*)PointerType::getUnqual(charType) : "
"(Type*)charType";
} else if (Def->isSubClassOf("ConstantInt")) {
typeToAdd = "byRef ? (Type*)blasIntType : (Type*)intType";
} else if (Def->isSubClassOf("transpose")) {
auto argStr = Def->getValueAsString("name");
// transpose the given trans arg, but type stays
typeToAdd = (Twine("type_") + argStr).str();
} else if (Def->isSubClassOf("use")) {
// we only use tmp matrices, so mat type
typeToAdd = "type_vec_like";
} else {
PrintFatalError(Def->getLoc(), "PANIC! Unsupported Definit");
}
typeString += ((first) ? "" : ", ") + typeToAdd;
} else {
if (auto Dag = dyn_cast<DagInit>(subArg)) {
auto Def = cast<DefInit>(Dag->getOperator())->getDef();
if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") {
if (!first)
typeString += ", ";
if (DefInit *def = dyn_cast<DefInit>(Dag->getArg(1))) {
const auto Def = def->getDef();
assert(Def->isSubClassOf("adj"));
typeString +=
(Twine("type_") + Def->getValueAsString("name")).str();
} else {
assert(Dag->getArgNameStr(1) != "");
typeString += (Twine("type_") + Dag->getArgNameStr(1)).str();
first = false;
}
continue;
} else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") {
if (!first)
typeString += ", ";
//(ld $A, $transa, $lda, $m, $k)
// Either of 2,3,4 would work
typeString += (Twine("type_") + Dag->getArgNameStr(2)).str();
first = false;
continue;
}
}
const auto argStr = ruleDag->getArgNameStr(i);
// skip layout because it is cblas only,
// so not relevant for the byRef Fortran abi.
// Optionally add it later as first arg for byRef.
if (argStr == "layout")
continue;
typeString += (first ? "" : ", ");
typeString += (Twine("type_") + argStr).str();
}
first = false;
}
std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name);
os << " llvm::FunctionType *FT" << dfnc_name << " = nullptr;\n";
if (derivlv23) {
os << " if(byRef) {\n"
<< " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n"
<< " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty
<< ", tys" << dfnc_name << ", false);\n"
<< " } else {\n"
<< " Type* tys" << dfnc_name << "[] = {type_layout, " << typeString
<< "};\n"
<< " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty
<< ", tys" << dfnc_name << ", false);\n"
<< " }\n";
} else {
os << " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n"
<< " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty
<< ", tys" << dfnc_name << ", false);\n";
}
os << " auto derivcall_" << dfnc_name
<< " = gutils->oldFunc->getParent()->getOrInsertFunction(\n"
<< " (blas.prefix + blas.floatType + \"" << dfnc_name
<< "\" + blas.suffix).str(), FT" << dfnc_name << ");\n";
os << " if (auto F = dyn_cast<Function>(derivcall_" << dfnc_name
<< ".getCallee()))\n"
<< " {\n"
<< " attribute_" << dfnc_name << "(blas, F);\n"
<< " }\n\n";
return;
}
*/

void emit_tmp_creation(Record *Def, raw_ostream &os) {
const auto args = Def->getValueAsListOfStrings("args");
// allocating tmp variables is optional, return if not required
Expand Down Expand Up @@ -1047,6 +917,16 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
os << "byRef)";
return;
}
if (Def->getName() == "Concat") {
os << "concat_values(";
for (size_t i = 0; i < Dag->getNumArgs(); i++) {
if (i != 0)
os << ", ";
rev_call_arg(Dag, rule, actArg, i, os);
}
os << ")";
return;
}
if (Def->getName() == "ld") {
assert(Dag->getNumArgs() == 5);
//(ld $A, $transa, $lda, $m, $k)
Expand Down

0 comments on commit 21dcb51

Please sign in to comment.