Skip to content

Commit

Permalink
update rule for gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 20, 2023
1 parent 6cf0df7 commit a3de25e
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def tp : MagicInst; // transpose the trans param.
def noop : MagicInst; // gradient is zero
def inactive : MagicInst; // like noop, but assert it's inactive
def Rows : MagicInst; // given a transpose, normal rows, normal cols get the true rows, aka normal rows if N else normal cols
def Concat : MagicInst;

// if !cache_A, then just use $lda.
// if cache_A, then check $transa.
Expand All @@ -82,8 +83,8 @@ class adj<string _name> {
class Constant<string _value> {
string value = _value;
}
class Char<string _name> {
string name = _name;
class Char<string _value> {
string value = _value;
}

class transpose<string _name> {
Expand Down Expand Up @@ -224,30 +225,31 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A

/* alpha */ (Seq<["AB", "product", "m", "n"]>
(b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)),
(FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)),
/* A */ (b<"gemm"> $layout, (Rows $transa, Char<"N">, $transb),
(Rows $transa, transpose<"transb">, Char<"T">),
(Rows $transa, $m, $k),
(Rows $transa, $k, $m),
$n, $alpha,
(Rows $transa, adj<"C">, $B),
(Rows $transa, $ldc, (ld $B, $transb, $ldb, $k, $n)),
(Rows $transa, $B, adj<"C">),
(Rows $transa, (ld $B, $transb, $ldb, $k, $n), $ldc),
$beta, adj<"A">, $lda),
(Rows $transa, adj<"C">, (Concat $B, (ld $B, $transb, $ldb, $k, $n))),
//(Rows $transa, ____, ),
(Rows $transa, (Concat $B, (ld $B, $transb, $ldb, $k, $n)), adj<"C">),
//(Rows $transa, , ____),
$beta, adj<"A">),

/* B */ (b<"gemm"> $layout, (Rows $transb, transpose<"transa">, Char<"T">),
(Rows $transb, Char<"N">, transpose<"transa">),
(Rows $transb, $k, $n),
(Rows $transb, $n, $k),
$m, $alpha,
(Rows $transb, $A, adj<"C">),
(Rows $transb, (ld $A, $transa, $lda, $m, $k)),
(Rows $transb, adj<"C">, $A),
(Rows $transb, $ldc, $lda),
$beta, adj<"B">, $ldb),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>)
(Rows $transb, (Concat $A, (ld $A, $transa, $lda, $m, $k)), adj<"C">),
//(Rows $transb, ),
// check next line
(Rows $transb, adj<"C">, (Concat $A, (ld $A, $transa, $lda, $m, $k))),
//(Rows $transb, ____, $lda),
$beta, adj<"B">),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>)
]
>;

Expand Down

0 comments on commit a3de25e

Please sign in to comment.