Skip to content

Commit

Permalink
fix gemv rule for A
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 19, 2023
1 parent 3870f89 commit 785c24b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
12 changes: 11 additions & 1 deletion enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,17 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $
/* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]>
(b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)),
/* A */ (b<"ger"> $layout, $m, $n, $alpha, adj<"y">, $incy, $x, $incx, adj<"A">, $lda),

//if (is_normal $transa) {
// call sger(m, n, alpha, ya, incy, x, incx, Aa, lda)
//} else {
// call sger(m, n, alpha, x, incx, ya, incy, Aa, lda)
//}
/* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x),
(Rows $transa, $incy, $incx),
(Rows $transa, $x, adj<"y">),
(Rows $transa, $incx, $incy),
adj<"A">, $lda),
/* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx),
/* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy),
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy)
Expand Down
36 changes: 27 additions & 9 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,15 @@ void emit_deriv_blas_call(DagInit *ruleDag,
if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") {
if (!first)
typeString += ", ";
typeString += (Twine("type_") + Dag->getArgNameStr(1)).str();
first = false;
if (DefInit *def = dyn_cast<DefInit>(Dag->getArg(1))) {
const auto Def = def->getDef();
assert(Def->isSubClassOf("adj"));
typeString += (Twine("type_") + Def->getValueAsString("name")).str();
} else {
assert(Dag->getArgNameStr(1) != "");
typeString += (Twine("type_") + Dag->getArgNameStr(1)).str();
first = false;
}
continue;
} else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") {
if (!first)
Expand Down Expand Up @@ -985,7 +992,6 @@ void emit_tmp_creation(Record *Def, raw_ostream &os) {
void emit_deriv_rule(const StringMap<TGPattern> &patternMap, Rule &rule,
StringSet<> &handled, raw_ostream &os) {
const auto ruleDag = rule.getRuleDag();
const auto typeMap = rule.getArgTypeMap();
const auto opName = ruleDag->getOperator()->getAsString();
const auto nameMap = rule.getArgNameMap();
const auto Def = cast<DefInit>(ruleDag->getOperator())->getDef();
Expand Down Expand Up @@ -1039,11 +1045,24 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule,
auto Def = cast<DefInit>(Dag->getOperator())->getDef();

if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") {
auto tname = Dag->getArgNameStr(0);
auto rname = Dag->getArgNameStr(1);
auto cname = Dag->getArgNameStr(2);
os << "get_blas_row(Builder2, arg_transposed_" << tname << ", arg_"
<< rname << ", arg_" << cname << ", byRef)";
std::string tname, rname, cname;
tname = Dag->getArgNameStr(0);
if (DefInit *Def1 = dyn_cast<DefInit>(Dag->getArg(1))) {
auto Def1Name = Def1->getDef()->getValueAsString("name");
assert(Def1->getDef()->isSubClassOf("adj"));
rname = (Twine("d_") + Def1Name).str();
} else {
rname = (Twine("arg_") + Dag->getArgNameStr(1)).str();
}
if (DefInit *Def2 = dyn_cast<DefInit>(Dag->getArg(2))) {
auto Def2Name = Def2->getDef()->getValueAsString("name");
assert(Def2->getDef()->isSubClassOf("adj"));
cname = (Twine("d_") + Def2Name).str();
} else {
cname = (Twine("arg_") + Dag->getArgNameStr(2)).str();
}
os << "get_blas_row(Builder2, arg_transposed_" << tname << ", "
<< rname << ", " << cname << ", byRef)";
} else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") {
assert(Dag->getNumArgs() == 5);
//(ld $A, $transa, $lda, $m, $k)
Expand Down Expand Up @@ -1191,7 +1210,6 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg,
raw_ostream &os, int subRule = -1) {

const auto nameMap = rule.getArgNameMap();
const auto typeMap = rule.getArgTypeMap();

auto ruleDag = rule.getRuleDag();
size_t numArgs = ruleDag->getNumArgs();
Expand Down

0 comments on commit 785c24b

Please sign in to comment.