From 21dcb51d0eb4c297356d5661e414ab5d15d30ae2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Sep 2023 13:24:17 -0500 Subject: [PATCH] Implement concat (#1450) --- enzyme/Enzyme/Utils.h | 13 ++ enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 140 ++------------------- 2 files changed, 23 insertions(+), 130 deletions(-) diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 8d8d47bcff51..793803635556 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1642,6 +1642,19 @@ llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef); + +template static inline void nothing(T...){}; +template +static inline llvm::SmallVector concat_values(T... t) { + llvm::SmallVector res; + auto append = [&](llvm::ArrayRef V) { + res.append(V.begin(), V.end()); + return 0; + }; + nothing(append(t)...); + return res; +} + llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef); llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef); llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans, diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 5123aec615f1..3e903ed98061 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -785,136 +785,6 @@ std::string get_blas_ret_ty(StringRef dfnc_name) { return "Builder2.getVoidTy()"; } -/* -void emit_deriv_blas_call(DagInit *ruleDag, - const StringMap &patternMap, - StringSet<> &handled, raw_ostream &os) { - - const auto Def = cast(ruleDag->getOperator())->getDef(); - const auto dfnc_name = Def->getValueAsString("s"); - if (patternMap.find(dfnc_name) == patternMap.end()) { - PrintFatalError("calling unknown Blas function"); - } - TGPattern calledPattern = patternMap.find(dfnc_name)->getValue(); - bool derivlv23 = calledPattern.isBLASLevel2or3(); - DenseSet mutableArgs = calledPattern.getMutableArgs(); - - if (handled.find(dfnc_name) != handled.end()) - return; - else - handled.insert(dfnc_name); - - auto retTy = get_blas_ret_ty(dfnc_name); - - // insert arg types based on .td file - std::string typeString = ""; - bool first = true; - for (size_t i = 0; i < ruleDag->getNumArgs(); i++) { - Init *subArg = ruleDag->getArg(i); - if (DefInit *def = dyn_cast(subArg)) { - const auto Def = def->getDef(); - std::string typeToAdd = ""; - if (Def->isSubClassOf("DiffeRetIndex")) { - typeToAdd = "byRef ? PointerType::getUnqual(call.getType()) : " - "call.getType()\n"; - } else if (Def->isSubClassOf("adj")) { - auto argStr = Def->getValueAsString("name"); - // primary and adj have the same type - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("input")) { - auto argStr = Def->getValueAsString("name"); - // primary and adj have the same type - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("Constant")) { - typeToAdd = "blasFPType"; - } else if (Def->isSubClassOf("Char")) { - typeToAdd = "byRef ? (Type*)PointerType::getUnqual(charType) : " - "(Type*)charType"; - } else if (Def->isSubClassOf("ConstantInt")) { - typeToAdd = "byRef ? (Type*)blasIntType : (Type*)intType"; - } else if (Def->isSubClassOf("transpose")) { - auto argStr = Def->getValueAsString("name"); - // transpose the given trans arg, but type stays - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("use")) { - // we only use tmp matrices, so mat type - typeToAdd = "type_vec_like"; - } else { - PrintFatalError(Def->getLoc(), "PANIC! Unsupported Definit"); - } - typeString += ((first) ? "" : ", ") + typeToAdd; - } else { - if (auto Dag = dyn_cast(subArg)) { - auto Def = cast(Dag->getOperator())->getDef(); - if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { - if (!first) - typeString += ", "; - if (DefInit *def = dyn_cast(Dag->getArg(1))) { - const auto Def = def->getDef(); - assert(Def->isSubClassOf("adj")); - typeString += - (Twine("type_") + Def->getValueAsString("name")).str(); - } else { - assert(Dag->getArgNameStr(1) != ""); - typeString += (Twine("type_") + Dag->getArgNameStr(1)).str(); - first = false; - } - continue; - } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { - if (!first) - typeString += ", "; - //(ld $A, $transa, $lda, $m, $k) - // Either of 2,3,4 would work - typeString += (Twine("type_") + Dag->getArgNameStr(2)).str(); - first = false; - continue; - } - } - const auto argStr = ruleDag->getArgNameStr(i); - // skip layout because it is cblas only, - // so not relevant for the byRef Fortran abi. - // Optionally add it later as first arg for byRef. - if (argStr == "layout") - continue; - typeString += (first ? "" : ", "); - typeString += (Twine("type_") + argStr).str(); - } - first = false; - } - - std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); - os << " llvm::FunctionType *FT" << dfnc_name << " = nullptr;\n"; - if (derivlv23) { - os << " if(byRef) {\n" - << " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n" - << " } else {\n" - << " Type* tys" << dfnc_name << "[] = {type_layout, " << typeString - << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n" - << " }\n"; - } else { - os << " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n"; - } - - os << " auto derivcall_" << dfnc_name - << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" - << " (blas.prefix + blas.floatType + \"" << dfnc_name - << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; - - os << " if (auto F = dyn_cast(derivcall_" << dfnc_name - << ".getCallee()))\n" - << " {\n" - << " attribute_" << dfnc_name << "(blas, F);\n" - << " }\n\n"; - return; -} -*/ - void emit_tmp_creation(Record *Def, raw_ostream &os) { const auto args = Def->getValueAsListOfStrings("args"); // allocating tmp variables is optional, return if not required @@ -1047,6 +917,16 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, os << "byRef)"; return; } + if (Def->getName() == "Concat") { + os << "concat_values("; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + if (i != 0) + os << ", "; + rev_call_arg(Dag, rule, actArg, i, os); + } + os << ")"; + return; + } if (Def->getName() == "ld") { assert(Dag->getNumArgs() == 5); //(ld $A, $transa, $lda, $m, $k)