Skip to content

Commit

Permalink
start fixing gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 20, 2023
1 parent 53a15ad commit 76111f9
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class adj<string _name> {
class Constant<string _value> {
string value = _value;
}
class Char<string _value> {
string value = _value;
class Char<string _name> {
string name = _name;
}

class transpose<string _name> {
Expand Down Expand Up @@ -227,8 +227,27 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
/* alpha */ (Seq<["AB", "product", "m", "n"]>
(b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)),
/* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda),
/* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb),
/* A */ (b<"gemm"> $layout, (Rows $transa, Char<"N">, $transb),
(Rows $transa, transpose<"transb">, Char<"T">),
(Rows $transa, $m, $k),
(Rows $transa, $k, $m),
$n, $alpha,
(Rows $transa, adj<"C">, $B),
(Rows $transa, $ldc, (ld $B, $transb, $ldb, $k, $n)),
(Rows $transa, $B, adj<"C">),
(Rows $transa, (ld $B, $transb, $ldb, $k, $n), $ldc),
$beta, adj<"A">, $lda),

/* B */ (b<"gemm"> $layout, (Rows $transb, transpose<"transa">, Char<"T">),
(Rows $transb, Char<"N">, transpose<"transa">),
(Rows $transb, $k, $n),
(Rows $transb, $n, $k),
$m, $alpha,
(Rows $transb, $A, adj<"C">),
(Rows $transb, (ld $A, $transa, $lda, $m, $k)),
(Rows $transb, adj<"C">, $A),
(Rows $transb, $ldc, $lda),
$beta, adj<"B">, $ldb),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>)
]
Expand Down

0 comments on commit 76111f9

Please sign in to comment.