From 76111f99ed0ad158ae09f783a37284525d70eac7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Sep 2023 21:50:47 -0400 Subject: [PATCH] start fixing gemm --- enzyme/Enzyme/BlasDerivatives.td | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 43742a32246a..08fad3194867 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -82,8 +82,8 @@ class adj { class Constant { string value = _value; } -class Char { - string value = _value; +class Char { + string name = _name; } class transpose { @@ -227,8 +227,27 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A /* alpha */ (Seq<["AB", "product", "m", "n"]> (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)), - /* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda), - /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb), + /* A */ (b<"gemm"> $layout, (Rows $transa, Char<"N">, $transb), + (Rows $transa, transpose<"transb">, Char<"T">), + (Rows $transa, $m, $k), + (Rows $transa, $k, $m), + $n, $alpha, + (Rows $transa, adj<"C">, $B), + (Rows $transa, $ldc, (ld $B, $transb, $ldb, $k, $n)), + (Rows $transa, $B, adj<"C">), + (Rows $transa, (ld $B, $transb, $ldb, $k, $n), $ldc), + $beta, adj<"A">, $lda), + + /* B */ (b<"gemm"> $layout, (Rows $transb, transpose<"transa">, Char<"T">), + (Rows $transb, Char<"N">, transpose<"transa">), + (Rows $transb, $k, $n), + (Rows $transb, $n, $k), + $m, $alpha, + (Rows $transb, $A, adj<"C">), + (Rows $transb, (ld $A, $transa, $lda, $m, $k)), + (Rows $transb, adj<"C">, $A), + (Rows $transb, $ldc, $lda), + $beta, adj<"B">, $ldb), /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">), /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>) ]