Skip to content

Commit

Permalink
Gemm passes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 20, 2023
1 parent 817cd86 commit 5dab863
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
//(Rows $transa, ____, ),
(Rows $transa, (Concat $B, (ld $B, $transb, $ldb, $k, $n)), adj<"C">),
//(Rows $transa, , ____),
$beta, adj<"A">),
Constant<"1.0">, adj<"A">),

/* B */ (b<"gemm"> $layout, (Rows $transb, transpose<"transa">, Char<"T">),
(Rows $transb, Char<"N">, transpose<"transa">),
(Rows $transb, Char<"N">, $transa),
(Rows $transb, $k, $n),
(Rows $transb, $n, $k),
$m, $alpha,
Expand All @@ -247,7 +247,7 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
// check next line
(Rows $transb, adj<"C">, (Concat $A, (ld $A, $transa, $lda, $m, $k))),
//(Rows $transb, ____, $lda),
$beta, adj<"B">),
Constant<"1.0">, adj<"B">),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>)
]
Expand Down
14 changes: 7 additions & 7 deletions enzyme/test/Integration/ReverseMode/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1215,11 +1215,11 @@ static void gemmTests() {
{

bool transA_bool = !(transA == 'N' || transA == 'n');
bool transB_bool = !(transA == 'N' || transA == 'n');
bool transB_bool = !(transB == 'N' || transB == 'n');
std::string Test = "GEMM";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda),
/*B*/ BlasInfo(B, layout, transB_bool ? N : K , transA_bool ? K : N, incB),
/*B*/ BlasInfo(B, layout, transB_bool ? N : K , transB_bool ? K : N, incB),
/*C*/ BlasInfo(C, layout, M, N, incC),
BlasInfo(),
BlasInfo(),
Expand Down Expand Up @@ -1275,22 +1275,22 @@ static void gemmTests() {

// dA =
my_dgemm(layout,
transA_bool ? transpose(transB) : transA,
transA_bool ? transA : transpose(transB),
transA_bool ? transB : 'N',
transA_bool ? 'T' : transpose(transB),
transA_bool ? K : M,
transA_bool ? M : K,
N,
alpha,
transA_bool ? B : dC,
transA_bool ? incB : incC,
transA_bool ? C : dB,
transA_bool ? dC : B,
transA_bool ? incC : incB,
1.0, dA, lda);

// dB =
my_dgemm(layout,
transB_bool ? transB : transpose(transA),
transB_bool ? transA : transB,
transB_bool ? 'T' : transpose(transA),
transB_bool ? transA : 'N', //transB,
transB_bool ? N : K,
transB_bool ? K : N,
M,
Expand Down

0 comments on commit 5dab863

Please sign in to comment.