From edd0331deebf5e875ac312357285f251e4d35d88 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Sep 2023 04:28:50 -0500 Subject: [PATCH] Add dot blas test (#1446) * Add dot blas test * Now adding caching test * Cleanup and progress * cleanup * Fix autodiff ordering with inlining * Rebase * Fix fortran calling conv --- enzyme/Enzyme/BlasDerivatives.td | 52 +- enzyme/Enzyme/Enzyme.cpp | 144 +++-- enzyme/Enzyme/Utils.cpp | 33 +- enzyme/Enzyme/Utils.h | 9 +- enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll | 10 +- .../test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 10 +- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll | 12 +- .../blas/gemm_f_c_lacpy_runtime_act.ll | 12 +- .../Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 8 +- .../Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 10 +- .../ReverseMode/blas/gemm_f_c_split_lacpy.ll | 10 +- .../blas/gemm_f_c_split_transpose_lacpy.ll | 10 +- .../blas/gemm_f_c_transpose_lacpy.ll | 10 +- .../ReverseMode/blas/gemm_f_change_ld.ll | 8 +- .../Enzyme/ReverseMode/blas/gemm_f_lacpy.ll | 10 +- .../Enzyme/ReverseMode/blas/gemm_f_over.ll | 10 +- .../ReverseMode/blas/gemm_f_over_lacpy.ll | 10 +- .../Enzyme/ReverseMode/blas/gemv_c_loop.ll | 2 +- .../Enzyme/ReverseMode/blas/gemv_c_loop2.ll | 5 +- .../ReverseMode/blas/gemv_c_loop3_matcopy.ll | 12 +- .../blas/gemv_f_c_split_blascpy.ll | 52 +- .../gemv_f_c_split_blascpy_runtime_act.ll | 28 +- .../ReverseMode/blas/gemv_f_c_split_memcpy.ll | 36 +- enzyme/test/Integration/ReverseMode/blas.cpp | 609 ++++++++++++++++-- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 371 ++++++----- enzyme/tools/enzyme-tblgen/caching.cpp | 6 +- enzyme/tools/enzyme-tblgen/datastructures.cpp | 65 +- enzyme/tools/enzyme-tblgen/datastructures.h | 9 +- 28 files changed, 1082 insertions(+), 481 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 43742a32246a..9ff3abf65bbd 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -109,8 +109,8 @@ def scal : 
CallBlasPattern<(Op $n, $alpha, $x, $incx), ["x"],[len, fp, vinc<["n"]>], [ // dot must proceed scal, because scal modifies adj<"x"> - (b<"dot"> $n, $x, $incx, adj<"x">, $incx), - (b<"scal"> $n, $alpha, adj<"x">, $incx) + (b<"dot"> $n, $x, adj<"x">), + (b<"scal"> $n, $alpha, adj<"x">) ] >; @@ -134,8 +134,8 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n, def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), ["y"],[len, fp, vinc<["n"]>, vinc<["n"]>], [ - (b<"dot"> $n, adj<"y">, $incy, $x, $incx), - (b<"axpy"> $n, $alpha, adj<"y">, $incy, adj<"x">, $incx), + (b<"dot"> $n, adj<"y">, $x), + (b<"axpy"> $n, $alpha, adj<"y">, adj<"x">), (noop) // y = alpha*x + y, so nothing to do here ] >; @@ -143,8 +143,8 @@ def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), [],[len, vinc<["n"]>, vinc<["n"]>], [ - (b<"axpy"> $n, DiffeRet, $y, $incy, adj<"x">, $incx), - (b<"axpy"> $n, DiffeRet, $x, $incx, adj<"y">, $incy) + (b<"axpy"> $n, DiffeRet, $y, adj<"x">), + (b<"axpy"> $n, DiffeRet, $x, adj<"y">), ] >; @@ -158,7 +158,7 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), ["y"],[len, vinc<["n"]>, vinc<["n"]>], [ (noop),// copy moves x into y, so x is never modified. 
- (b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx) + (b<"axpy"> $n, Constant<"1.0">, adj<"y">, adj<"x">) ] >; @@ -184,8 +184,8 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> - (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>), - (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)), + (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), + (b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)), //if (is_normal $transa) { // call sger(m, n, alpha, ya, incy, x, incx, Aa, lda) @@ -193,13 +193,11 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x), - (Rows $transa, $incy, $incx), (Rows $transa, $x, adj<"y">), - (Rows $transa, $incx, $incy), - adj<"A">, $lda), - /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx), - /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy), - /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy) + adj<"A">), + /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), + /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), + /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) ] >; // @@ -226,11 +224,11 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A /* 
alpha */ (Seq<["AB", "product", "m", "n"]> (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n - (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)), - /* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda), - /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb), - /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">), - /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>) + (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), + /* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">), + /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $beta, adj<"B">), + /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), + /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>) ] >; @@ -239,14 +237,14 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta [cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, fp, vinc<["n"]>], [ /* alpha */ (Seq<["y0", "triangular", "n"]> - (b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, $incx, Constant<"0.0">, use<"y0">, ConstantInt<1>), - (b<"dot"> $n, adj<"y">, $incy, use<"y0">, ConstantInt<1>)), + (b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, Constant<"0.0">, use<"y0">, ConstantInt<1>), + (b<"dot"> $n, adj<"y">, use<"y0">, ConstantInt<1>)), /* ap */ (Seq<[]> - (b<"spr2"> $layout, $uplo, $n, $alpha, $x, $incx, 
adj<"y">, $incy, adj<"ap">), - (DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">)), - /* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx), - /* beta */ (b<"dot"> $n, adj<"y">, $incy, input<"y">, $incy), - /* y */ (b<"scal"> $n, $beta, adj<"y">, $incy) + (b<"spr2"> $layout, $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">), + (DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">)), + /* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, Constant<"1.0">, adj<"x">), + /* beta */ (b<"dot"> $n, adj<"y">, input<"y">), + /* y */ (b<"scal"> $n, $beta, adj<"y">) ] >; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 793d6dacef34..4772edbedfe5 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1409,7 +1409,8 @@ class EnzymeBase { Type *retElemType, SmallVectorImpl &args, const std::map &byVal, const std::vector &constants, Function *fn, - DerivativeMode mode, Options &options, bool sizeOnly) { + DerivativeMode mode, Options &options, bool sizeOnly, + SmallVectorImpl &calls) { auto &differet = options.differet; auto &tape = options.tape; auto &width = options.width; @@ -1702,63 +1703,13 @@ class EnzymeBase { } ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode); - - if (Logic.PostOpt) { - auto Params = llvm::getInlineParams(); - - llvm::SetVector Q; - Q.insert(diffretc); - while (Q.size()) { - auto cur = *Q.begin(); - Function *outerFunc = cur->getParent()->getParent(); - llvm::OptimizationRemarkEmitter ORE(outerFunc); - Q.erase(Q.begin()); - if (auto F = cur->getCalledFunction()) { - if (!F->empty()) { - // Garbage collect AC's created - SmallVector ACAlloc; - auto getAC = [&](Function &F) -> llvm::AssumptionCache & { - auto AC = new AssumptionCache(F); - ACAlloc.push_back(AC); - return *AC; - }; - auto GetTLI = - [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & { - return Logic.PPC.FAM.getResult(F); - }; - - auto 
GetInlineCost = [&](CallBase &CB) { - TargetTransformInfo TTI(F->getParent()->getDataLayout()); - auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI); - return cst; - }; - if (llvm::shouldInline(*cur, GetInlineCost, ORE)) { - InlineFunctionInfo IFI; - InlineResult IR = InlineFunction(*cur, IFI); - if (IR.isSuccess()) { - LowerSparsification(outerFunc, /*replaceAll*/ false); - for (auto U : outerFunc->users()) { - if (auto CI = dyn_cast(U)) { - if (CI->getCalledFunction() == outerFunc) { - Q.insert(CI); - } - } - } - } - } - for (auto AC : ACAlloc) { - delete AC; - } - } - } - } - } - return true; + calls.push_back(diffretc); + return diffret; } /// Return whether successful - bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, - bool sizeOnly) { + bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly, + SmallVectorImpl &calls) { // determine function to differentiate Function *fn = parseFunctionParameter(CI); @@ -1796,16 +1747,17 @@ class EnzymeBase { #if LLVM_VERSION_MAJOR >= 16 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.value(), - sizeOnly); + byVal, constants, fn, mode, options.value(), sizeOnly, + calls); #else return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, byVal, constants, fn, mode, options.getValue(), - sizeOnly); + sizeOnly, calls); #endif } - bool HandleProbProg(CallInst *CI, ProbProgMode mode) { + bool HandleProbProg(CallInst *CI, ProbProgMode mode, + SmallVectorImpl &calls) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) @@ -1928,13 +1880,15 @@ class EnzymeBase { } #if LLVM_VERSION_MAJOR >= 16 - bool status = HandleAutoDiff( - CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, - newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false); + bool status = + HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, + constants, newFunc, 
DerivativeMode::ReverseModeCombined, + opt.value(), false, calls); #else - bool status = HandleAutoDiff( - CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, - newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false); + bool status = + HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, + constants, newFunc, DerivativeMode::ReverseModeCombined, + opt.getValue(), false, calls); #endif delete interface; @@ -2447,17 +2401,19 @@ class EnzymeBase { Changed = true; } + SmallVector calls; + // Perform all the size replacements first to create constants for (auto pair : toSize) { bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ true); + /*sizeOnly*/ true, calls); Changed = true; if (!successful) break; } for (auto pair : toLower) { bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ false); + /*sizeOnly*/ false, calls); Changed = true; if (!successful) break; @@ -2495,7 +2451,59 @@ class EnzymeBase { } for (auto &&[call, mode] : toProbProg) { - HandleProbProg(call, mode); + HandleProbProg(call, mode, calls); + } + + if (Logic.PostOpt) { + auto Params = llvm::getInlineParams(); + + llvm::SetVector Q; + for (auto call : calls) + Q.insert(call); + while (Q.size()) { + auto cur = *Q.begin(); + Function *outerFunc = cur->getParent()->getParent(); + llvm::OptimizationRemarkEmitter ORE(outerFunc); + Q.erase(Q.begin()); + if (auto F = cur->getCalledFunction()) { + if (!F->empty()) { + // Garbage collect AC's created + SmallVector ACAlloc; + auto getAC = [&](Function &F) -> llvm::AssumptionCache & { + auto AC = new AssumptionCache(F); + ACAlloc.push_back(AC); + return *AC; + }; + auto GetTLI = + [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & { + return Logic.PPC.FAM.getResult(F); + }; + + auto GetInlineCost = [&](CallBase &CB) { + TargetTransformInfo TTI(F->getParent()->getDataLayout()); + auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI); + 
return cst; + }; + if (llvm::shouldInline(*cur, GetInlineCost, ORE)) { + InlineFunctionInfo IFI; + InlineResult IR = InlineFunction(*cur, IFI); + if (IR.isSuccess()) { + LowerSparsification(outerFunc, /*replaceAll*/ false); + for (auto U : outerFunc->users()) { + if (auto CI = dyn_cast(U)) { + if (CI->getCalledFunction() == outerFunc) { + Q.insert(CI); + } + } + } + } + } + for (auto AC : ACAlloc) { + delete AC; + } + } + } + } } if (Changed && EnzymeAttributor) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 94de076a4af1..3445550a1b13 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -645,7 +645,8 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef args, llvm::ArrayRef bundles) { - std::string copy_name = (blas.floatType + "lacpy" + blas.suffix).str(); + std::string copy_name = + (blas.prefix + blas.floatType + "lacpy" + blas.suffix).str(); SmallVector tys; for (auto arg : args) @@ -2554,14 +2555,16 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V) { // } else { // ld_A = arg_lda; // } -llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans, +llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, + llvm::ArrayRef trans, llvm::Value *arg_ld, llvm::Value *dim1, llvm::Value *dim2, bool cacheMat, bool byRef) { if (!cacheMat) return arg_ld; - Value *width = B.CreateSelect(is_normal(B, trans, byRef), dim1, dim2); + assert(trans.size() == 1); + Value *width = CreateSelect(B, is_normal(B, trans[0], byRef), dim1, dim2); return width; } @@ -2593,19 +2596,27 @@ llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::IntegerType *intType, return B.CreateLoad(intType, VP); } -llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans, - llvm::Value *row, llvm::Value *col, bool byRef) { - +SmallVector get_blas_row(llvm::IRBuilder<> &B, + ArrayRef transA, + ArrayRef row, + 
ArrayRef col, + bool byRef) { + assert(transA.size() == 1); + auto trans = transA[0]; if (byRef) { auto charType = IntegerType::get(trans->getContext(), 8); trans = B.CreateLoad(charType, trans, "ld.row.trans"); } - return B.CreateSelect( - B.CreateOr( - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))), - row, col); + auto cond = B.CreateOr( + B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), + B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))); + assert(row.size() == col.size()); + SmallVector toreturn; + for (size_t i = 0; i < row.size(); i++) { + toreturn.push_back(B.CreateSelect(cond, row[i], col[i])); + } + return toreturn; } // return how many Special pointers are in T (count > 0), diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 3ace54029ca3..8d8d47bcff51 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1637,7 +1637,8 @@ llvm::Value *to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V, llvm::IRBuilder<> &entryBuilder, llvm::Twine const & = ""); -llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans, +llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, + llvm::ArrayRef trans, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef); @@ -1651,8 +1652,10 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V); llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, llvm::IntegerType *IT, llvm::IRBuilder<> &entryBuilder, const llvm::Twine &name); -llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans, - llvm::Value *row, llvm::Value *col, bool byRef); +llvm::SmallVector +get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef trans, + llvm::ArrayRef row, + llvm::ArrayRef col, bool byRef); // Parameter attributes from the original function/call that // we should preserve on the primal of the derivative code. 
diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll index 6cd71c23be62..9b514b2be53b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -84,7 +84,7 @@ entry: ; CHECK-NEXT: store i64 8, i64* %ldb, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %beta ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* 
%lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry @@ -110,8 +110,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index 5a3d1793dbbe..991ae30a792c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture 
readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -176,7 +176,7 @@ entry: ; CHECK-NEXT: br i1 %39, label %[[enzyme_memcpy_double_mat_64_exit21]], label %[[init_idx]] ; CHECK: [[enzyme_memcpy_double_mat_64_exit21]]: ; preds = %__enzyme_memcpy_double_mat_64.exit, %[[init_end_i18]] -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !10, !noalias !13 @@ -213,13 +213,13 @@ entry: ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans4, 110 ; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %k_p, 
i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i43]], i8* %[[r21]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i43]], i8* %[[r21]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transa ; CHECK-DAG: %[[r22:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-DAG: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] ; CHECK-DAG: %[[r25:.+]] = select i1 %[[r24]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i42]], i8* %[[r25]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i42]], i8* %[[r25]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll index 9b09574ce220..c9e40aac8e36 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* 
nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -27,7 +27,7 @@ entry: store i64 4, i64* %lda, align 16 store i64 8, i64* %ldb, align 16 store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) ; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* ; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -159,7 +159,7 @@ entry: ; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 
0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) ; CHECK: %[[i45:.+]] = bitcast i64* %byref.constant.one.i to i8* ; CHECK: %[[i46:.+]] = bitcast i64* %byref.mat.size.i to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i @@ -207,13 +207,13 @@ entry: ; CHECK-NEXT: %[[i62:.+]] = load double, double* %[[i61]] ; CHECK-NEXT: %[[i63:.+]] = fadd fast double %[[i62]], %res.i ; CHECK-NEXT: store double %[[i63]], double* %[[i61]] -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans8 = load i8, i8* %transa ; CHECK-DAG: %[[i64:.+]] = icmp eq i8 %loaded.trans8, 78 ; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 ; CHECK-DAG: %[[i66:.+]] = or i1 %[[i65]], %[[i64]] ; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* 
%[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK: %[[i68:.+]] = bitcast i64* %byref.constant.one.i15 to i8* ; CHECK: %[[i69:.+]] = bitcast i64* %byref.mat.size.i18 to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i15 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 1aae37d5b5d7..f8d13293ef97 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -enzyme-runtime-activity=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -enzyme-runtime-activity=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -27,7 +27,7 @@ entry: store i64 4, i64* %lda, align 16 store i64 8, i64* %ldb, align 16 store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* 
%ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -122,7 +122,7 @@ entry: ; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) ; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* ; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -167,7 +167,7 @@ entry: ; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) ; CHECK: %[[i45:.+]] = bitcast i64* %byref.constant.one.i to i8* ; CHECK: %[[i46:.+]] = bitcast i64* %byref.mat.size.i to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i @@ -221,7 +221,7 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.A, label %invertentry.A.done, label %invertentry.A.active ; CHECK: invertentry.A.active: ; preds = %invertentry.alpha.done -; CHECK-NEXT: call void 
@dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.A.done ; CHECK: invertentry.A.done: ; preds = %invertentry.A.active, %invertentry.alpha.done @@ -233,7 +233,7 @@ entry: ; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 ; CHECK-DAG: %[[i66:.+]] = or i1 %[[i65]], %[[i64]] ; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.B.done ; CHECK: invertentry.B.done: ; preds = %invertentry.B.active, %invertentry.A.done diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index ca9e23085eaf..44cc10d80949 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) 
+declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) declare i8* @AData(i64) declare i8* @Aldap(i64) @@ -43,7 +43,7 @@ loop: store i64 4, i64* %ldc, align 16 %A = call i8* @AData(i64 %i) "enzyme_inactive" %lda_p = call i8* @Aldap(i64 %i) "enzyme_inactive" - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) call void @free(i8* %m_p) %cmp = icmp eq i64 %inc, 10 br i1 %cmp, label %exit, label %loop @@ -187,7 +187,7 @@ entry: ; CHECK-NEXT: br i1 %23, label %__enzyme_memcpy_double_mat_64.exit, label %init.idx.i ; CHECK: __enzyme_memcpy_double_mat_64.exit: ; preds = %loop, %init.end.i -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: call void @free(i8* %m_p) ; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, 10 ; CHECK-NEXT: br i1 %cmp, label %exit, label %loop @@ -275,7 +275,7 @@ entry: ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans30, 110 ; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %[[i46]], i8* %cast.k -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %byref.transb, i8* %cast.k, i8* %n_p_unwrap, i8* %[[i46]], i8* %cast.alpha, i8* %[[i44]], i8* 
%[[r21]], i8* %"C'", i8* %cast.ldc, i8* %cast.beta, i8* %"B'", i8* %cast.ldb) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %byref.transb, i8* %cast.k, i8* %n_p_unwrap, i8* %[[i46]], i8* %cast.alpha, i8* %[[i44]], i8* %[[r21]], i8* %"C'", i8* %cast.ldc, i8* %cast.beta, i8* %"B'", i8* %cast.ldb, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index d65e7c4aba61..b3a6f1743944 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* 
%alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -146,7 +146,7 @@ entry: ; CHECK-NEXT: br i1 %36, label %__enzyme_memcpy_double_mat_64.exit, label %init.idx.i ; CHECK: __enzyme_memcpy_double_mat_64.exit: ; preds = %entry, %init.end.i -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %37 = load double*, double** %0 ; CHECK-NEXT: ret double* %37 ; CHECK-NEXT: } @@ -232,13 +232,13 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* 
%n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index 752f85319613..b15204f9b31d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -118,7 +118,7 @@ entry: ; 
CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] ; CHECK-NEXT: } @@ -204,13 +204,13 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* 
%"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 689904cc286b..4561682b3ae9 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -118,7 +118,7 @@ entry: ; CHECK-NEXT: store 
double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] ; CHECK-NEXT: } @@ -204,13 +204,13 @@ entry: ; CHECK-NEXT: store i8 %[[i33]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[i34:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[i35:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-DAG: %36 = or i1 %[[i35]], %[[i34]] ; CHECK-NEXT: %37 = select i1 %36, i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %37, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %37, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* 
%ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index 7241127eb8bb..d6c145ee54ae 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -120,7 +120,7 @@ entry: ; CHECK-NEXT: %cache.B = bitcast i8* 
%[[malloccall2]] to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage4 ; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %13, i8* %14, i8* %B, i8* %ldb_p, double* %cache.B, i8* %13) -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -157,13 +157,13 @@ entry: ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] ; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i25]], i8* %[[i43]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i25]], i8* %[[i43]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %[[cachedtrans2:.+]] = load i8, i8* %transa ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 %[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] ; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i24]], i8* %[[i57]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, 
i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i24]], i8* %[[i57]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %[[intcast0:.+]] = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index 2afc2b06f78a..51879c3537da 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-lapack-copy=1 -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 16, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 8, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %B to double* store double 
0.0000000e+00, double* %ptr, align 8 ret void @@ -104,7 +104,7 @@ entry: ; CHECK-NEXT: %cache.B = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %3, i8* %4, i8* %B, i8* %ldb_p, double* %cache.B, i8* %3) -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %ptr = bitcast i8* %B to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8 ; CHECK-NEXT: br label %invertentry @@ -138,7 +138,7 @@ entry: ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] ; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %10, i8* %[[r19]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %10, i8* %[[r19]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll index c7571f1ea8ef..f07275c6df1e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme 
-enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -84,7 +84,7 @@ entry: ; CHECK-NEXT: store i64 8, i64* %ldb, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %beta ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry @@ -110,8 +110,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; 
CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %[[int00]] ; CHECK-NEXT: %[[intcast00:.+]] = bitcast i64* %[[int00]] to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll index c95f61c138a4..5f5ff35f0659 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture 
readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) store i64 0, i64* %m, align 16 ret void } @@ -88,7 +88,7 @@ entry: ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 ; CHECK-NEXT: %pcld.m = bitcast i8* %m_p to i64* ; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: store i64 0, i64* %m ; CHECK-NEXT: br label %invertentry @@ -117,8 +117,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", 
i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll index 63283c5e271c..1c1c0cea8c2c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* 
%alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) store i64 0, i64* %m, align 16 ret void } @@ -88,7 +88,7 @@ entry: ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 ; CHECK-NEXT: %pcld.m = bitcast i8* %m_p to i64* ; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: store i64 0, i64* %m ; CHECK-NEXT: br label %invertentry @@ -117,8 +117,8 @@ entry: ; CHECK-DAG: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git 
a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll index 21807cc0b4de..8997057fde01 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll @@ -55,7 +55,7 @@ entry: ; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %0, 8 ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* -; CHECK-NEXT: call void @dlacpy(i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) +; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) ; CHECK-NEXT: %1 = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll index 872964e04cf7..8501555f6e06 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll @@ -55,7 +55,7 @@ entry: ; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %0, 8 ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* -; CHECK-NEXT: call void @dlacpy(i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) +; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) ; CHECK-NEXT: %1 = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) @@ -108,8 +108,7 @@ entry: ; CHECK-DAG: %[[r20:.+]] = select i1 false, double* %"v0'", double* %cache.x_unwrap ; CHECK-DAG: %[[r21:.+]] = select i1 false, double* %cache.x_unwrap, double* %"v0'" ; CHECK-NEXT: call void 
@cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r20]], i32 1, double* %[[r21]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i22:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %[[i22]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i23:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i23]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i24:.+]] = bitcast double* %cache.A_unwrap to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll index 3819ed2b6eb9..5133760023d2 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll @@ -201,8 +201,7 @@ entry: ; CHECK-DAG: %[[r42:.+]] = select i1 false, double* %"v0'", double* %cache.x ; CHECK-DAG: %[[r43:.+]] = select i1 false, double* %cache.x, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r42]], i32 1, double* %[[r43]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i48:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %[[i48]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i49:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i49]], 
double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i50:.+]] = bitcast double* %cache.A to i8* @@ -212,8 +211,7 @@ entry: ; CHECK-DAG: %[[r48:.+]] = select i1 false, double* %"v0'", double* %cache.x8 ; CHECK-DAG: %[[r49:.+]] = select i1 false, double* %cache.x8, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r48]], i32 1, double* %[[r49]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i52:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %[[i52]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i53:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i53]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i54:.+]] = bitcast double* %cache.A5 to i8* @@ -223,8 +221,7 @@ entry: ; CHECK-DAG: %[[r54:.+]] = select i1 false, double* %"v0'", double* %cache.x16 ; CHECK-DAG: %[[r55:.+]] = select i1 false, double* %cache.x16, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r54]], i32 1, double* %[[r55]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i56:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %[[i56]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i57:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i57]], double 
1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i58:.+]] = bitcast double* %cache.A13 to i8* @@ -234,8 +231,7 @@ entry: ; CHECK-DAG: %[[r60:.+]] = select i1 false, double* %"v0'", double* %cache.x24 ; CHECK-DAG: %[[r61:.+]] = select i1 false, double* %cache.x24, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r60]], i32 1, double* %[[r61]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i60:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %[[i60]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i61:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i61]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i62:.+]] = bitcast double* %cache.A21 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index 856451959941..7f70c9865b6e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -4,7 +4,7 @@ ; Here we don't transpose the matrix a (78 equals 'N' in ASCII) and we therefore also don't transpose x. ; Therfore the first arg to dcopy is n_p, as opposed to the gemv_transpose test. 
; trans, M, N, alpha, A, lda, x, , incx, beta, y, incy -declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly) +declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i64) define void @f(i8* noalias %y, i8* noalias %A, i8* noalias %x, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -25,7 +25,7 @@ entry: store i64 4, i64* %lda, align 16 store i64 2, i64* %incx, align 16 store i64 1, i64* %incy, align 16 - call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p) + call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p, i64 1) ret void } @@ -104,7 +104,7 @@ entry: ; CHECK-NEXT: %23 = insertvalue { double*, double* } undef, double* %cache.x, 0 ; CHECK-NEXT: %24 = insertvalue { double*, double* } %23, double* %cache.y, 1 ; CHECK-NEXT: store { double*, double* } %24, { double*, double* }* %0 -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p, i64 1) ; CHECK-NEXT: %25 = load { double*, double* }, { double*, double* }* %0 ; CHECK-NEXT: ret { double*, double* } %25 ; CHECK-NEXT: } @@ -114,11 +114,13 @@ entry: ; CHECK-NEXT: %ret = alloca double ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: 
%byref.constant.fp.0.0 = alloca double +; CHECK-DAG: %byref.constant.fp.1.0 = alloca double +; CHECK-DAG: %byref.constant.char.N = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.0.0 = alloca double ; CHECK-NEXT: %byref.constant.int.1 = alloca i64 ; CHECK-NEXT: %byref.constant.int.17 = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.013 = alloca double +; CHECK-NEXT: %byref.constant.char.N11 = alloca i8, align 1 +; CHECK-NEXT: %[[byrefconstantfp1:.+]] = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* ; CHECK-NEXT: %incx = alloca i64, i64 1, align 16 @@ -182,11 +184,12 @@ entry: ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1 ; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1, i64 1) ; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[c1:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[c2:.+]] = icmp eq i8 %ld.row.trans, 78 @@ -204,28 +207,21 @@ entry: ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %ld.row.trans9, 78 ; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]] ; CHECK-NEXT: 
%[[i42:.+]] = select i1 %41, i8* %"y'", i8* %20 +; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i41]], i8* %incy_p, i8* %intcast.int.one ; CHECK-NEXT: %ld.row.trans10 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i43:.+]] = icmp eq i8 %ld.row.trans10, 110 -; CHECK-DAG: %[[i44:.+]] = icmp eq i8 %ld.row.trans10, 78 -; CHECK-NEXT: %[[i45:.+]] = or i1 %[[i44]], %[[i43]] -; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8* %incy_p, i8* %incx_p -; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans11, 110 -; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans11, 78 +; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans10, 110 +; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans10, 78 ; CHECK-NEXT: %[[i49:.+]] = or i1 %[[i48]], %[[i47]] ; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8* %20, i8* %"y'" -; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %malloccall, align 1 -; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.row.trans12, 110 -; CHECK-NEXT: %[[i52:.+]] = icmp eq i8 %ld.row.trans12, 78 -; CHECK-NEXT: %[[i53:.+]] = or i1 %52, %51 -; CHECK-NEXT: %[[i54:.+]] = select i1 %53, i8* %incx_p, i8* %incy_p +; CHECK-NEXT: %[[i54:.+]] = select i1 %[[i49]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[i42]], i8* %[[i46]], i8* %[[i50]], i8* %[[i54]], i8* %"A'", i8* %lda_p) -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.013 -; CHECK-NEXT: %fpcast.constant.fp.1.014 = bitcast double* %byref.constant.fp.1.013 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.014, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans15, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans15, 78 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N11, 
align 1 +; CHECK-NEXT: store double 1.000000e+00, double* %[[byrefconstantfp1]] +; CHECK-NEXT: %[[fpcast14:.+]] = bitcast double* %[[byrefconstantfp1]] to i8* +; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %[[fpcast14]], i8* %"x'", i8* %incx_p, i64 1) +; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans14, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans14, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one) @@ -233,9 +229,9 @@ entry: ; CHECK-NEXT: %[[r45:.+]] = load double, double* %[[r44]] ; CHECK-NEXT: %[[r46:.+]] = fadd fast double %[[r45]], %[[r43]] ; CHECK-NEXT: store double %[[r46]], double* %[[r44]] -; CHECK-NEXT: %ld.row.trans16 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans16, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans16, 78 +; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans15, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans15, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll index 2e2ea7b4dac4..7337b69b7794 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll @@ -176,21 +176,13 @@ entry: ; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: 
%[[r24:.+]] = or i1 %[[r23]], %[[r22]] ; CHECK-NEXT: %[[r25:.+]] = select i1 %[[r24]], i8* %"y'", i8* %11 +; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r24]], i8* %incy_p, i8* %intcast.int.one ; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r26:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r27:.+]] = icmp eq i8 %ld.row.trans2, 78 -; CHECK-NEXT: %[[r28:.+]] = or i1 %[[r27]], %[[r26]] -; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r28]], i8* %incy_p, i8* %incx_p -; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans3, 110 -; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans3, 78 +; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r32:.+]] = or i1 %[[r31]], %[[r30]] ; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r32]], i8* %11, i8* %"y'" -; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r34:.+]] = icmp eq i8 %ld.row.trans4, 110 -; CHECK-DAG: %[[r35:.+]] = icmp eq i8 %ld.row.trans4, 78 -; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] -; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36]], i8* %incx_p, i8* %incy_p +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r32]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[r25]], i8* %[[r29]], i8* %[[r33]], i8* %[[r37]], i8* %"A'", i8* %lda_p) ; CHECK-NEXT: br label %invertentry.A.done @@ -198,9 +190,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans5, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-DAG: 
%[[r40:.+]] = icmp eq i8 %ld.row.trans3, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %[[i12]], i8* %intcast.int.one) @@ -214,9 +206,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.beta.done -; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans6, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll index bef22acf4410..cd853b3f207a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll @@ -2,7 +2,7 @@ ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-blas-copy=0 -enzyme-lapack-copy=1 -S | FileCheck %s ; trans, M, N, alpha, A, lda, x, , incx, beta, y, incy -declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly) +declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i64) define 
void @f(i8* noalias %y, i8* noalias %A, i8* noalias %x) { entry: @@ -29,7 +29,7 @@ entry: store i64 2, i64* %incx, align 16 store double 0.000000e+00, double* %beta store i64 1, i64* %incy, align 16 - call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p) + call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p, i64 1) ret void } @@ -126,7 +126,7 @@ entry: ; CHECK-NEXT: br i1 %25, label %__enzyme_memcpy_double_64_da0sa0stride.exit, label %for.body.i ; CHECK: __enzyme_memcpy_double_64_da0sa0stride.exit: ; preds = %entry, %for.body.i -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p, i64 1) ; CHECK-NEXT: %26 = load double*, double** %0 ; CHECK-NEXT: ret double* %26 ; CHECK-NEXT: } @@ -136,6 +136,7 @@ entry: ; CHECK-NEXT: %ret = alloca double ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 ; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* @@ -195,28 +196,21 @@ entry: ; CHECK-DAG: %[[r25:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-DAG: %[[r26:.+]] = or i1 %[[r25]], %[[r24]] ; CHECK-NEXT: %[[r27:.+]] = select i1 %[[r26]], i8* %"y'", i8* %15 -; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r28:.+]] = icmp eq i8 %ld.row.trans1, 110 -; CHECK-DAG: %[[r29:.+]] = icmp eq i8 %ld.row.trans1, 78 -; CHECK-DAG: %[[r30:.+]] = or i1 %[[r29]], %[[r28]] -; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r30]], i8* %incy_p, 
i8* %incx_p -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r26]], i8* %incy_p, i8* %intcast.int.one +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall +; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans1, 78 ; CHECK-DAG: %[[r34:.+]] = or i1 %[[r33]], %[[r32]] ; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r34]], i8* %15, i8* %"y'" -; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r36:.+]] = icmp eq i8 %ld.row.trans3, 110 -; CHECK-DAG: %[[r37:.+]] = icmp eq i8 %ld.row.trans3, 78 -; CHECK-DAG: %[[r38:.+]] = or i1 %[[r37]], %[[r36]] -; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r38]], i8* %incx_p, i8* %incy_p -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %27, i8* %31, i8* %35, i8* %39, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r34]], i8* %intcast.int.one, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %[[r27]], i8* %[[r31]], i8* %[[r35]], i8* %[[r39]], i8* %"A'", i8* %lda_p) +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans4, 110 -; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* 
%fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p, i64 1) +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r42:.+]] = or i1 %[[r41]], %[[r40]] ; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r42]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r43]], i8* %beta_p, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index a225cbd84013..1561cda793f2 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -1,17 +1,17 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if 
[ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - 
%newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi #include "test_utils.h" @@ -98,8 +98,9 @@ bool inDerivative = false; double alpha = 2.71828; double beta = 47.56; - int M = 105688; - int N = 78412; + int M = 105688; + int N = 78412; + int K = 5013424; int lda = 3416; char UNUSED_TRANS = 'A'; int UNUSED_INT = -1; @@ -110,7 +111,11 @@ enum class CallType { GEMM, SCAL, GER, - COPY + DOT, + AXPY, + LASCL, + COPY, + LACPY, }; struct BlasCall { @@ -168,6 +173,10 @@ void printty(CallType v) { case CallType::SCAL: printf("SCAL"); return; case CallType::GER: printf("GER"); return; case CallType::COPY: printf("COPY"); return; + case CallType::LACPY: printf("LACPY"); return; + case CallType::DOT: printf("DOT"); return; + case CallType::AXPY: printf("AXPY"); return; + case CallType::LASCL: printf("LASCL"); return; default: printf("UNKNOWN CALL (%d)", (int)v); } } @@ -190,6 +199,7 @@ void printty(int v) { else if (v == incC) printf("incC"); else if (v == M) printf("M"); else if (v == N) printf("N"); + else if (v == K) printf("K"); else if (v == lda) printf("lda"); else if (v == UNUSED_INT) printf("UNUSED_INT"); else printf("Unknown int"); @@ -226,6 +236,77 @@ void printty(double v) { void printcall(BlasCall rcall) { switch (rcall.type) { + case CallType::LACPY: + printf("LACPY(layout="); + printty(rcall.layout); + printf(", uplo="); + printty(rcall.targ1); + printf(", M="); + printty(rcall.iarg1); + printf(", N="); + printty(rcall.iarg2); + printf(", A="); + printty(rcall.pin_arg1); + printf(", lda="); + printty(rcall.iarg4); + printf(", B="); + printty(rcall.pout_arg1); + printf(", ldb="); + printty(rcall.iarg5); + printf(")"); + return; + case CallType::LASCL: + printf("LASCL(layout="); + printty(rcall.layout); + printf(", type="); + printty(rcall.targ1); + printf(", KL="); + printty(rcall.iarg5); + 
printf(", KU="); + printty(rcall.iarg6); + printf(", cfrom="); + printty(rcall.farg1); + printf(", cto="); + printty(rcall.farg2); + + printf(", M="); + printty(rcall.iarg1); + printf(", N="); + printty(rcall.iarg2); + printf(", A="); + printty(rcall.pout_arg1); + printf(", lda="); + printty(rcall.iarg4); + printf(")"); + return; + case CallType::AXPY: + printf("AXPY(N="); + printty(rcall.iarg1); + printf(", alpha="); + printty(rcall.farg1); + printf(", X="); + printty(rcall.pin_arg1); + printf(", incx="); + printty(rcall.iarg4); + printf(", Y="); + printty(rcall.pout_arg1); + printf(", incy="); + printty(rcall.iarg5); + printf(")"); + return; + case CallType::DOT: + printf("DOT(N="); + printty(rcall.iarg1); + printf(", X="); + printty(rcall.pin_arg1); + printf(", incx="); + printty(rcall.iarg4); + printf(", Y="); + printty(rcall.pin_arg2); + printf(", incy="); + printty(rcall.iarg5); + printf(")"); + return; case CallType::GEMV: printf("GEMV(layout="); printty(rcall.layout); @@ -388,6 +469,43 @@ vector foundCalls; extern "C" { +// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/lascl.html +// technically LAPACKE_dlascl +__attribute__((noinline)) +void cblas_dlascl(char layout, char type, int KL, int KU, double cfrom, double cto, int M, int N, double* A, int lda) { + BlasCall call = {inDerivative, CallType::LASCL, + A, UNUSED_POINTER, UNUSED_POINTER, + cfrom, cto, + layout, + type, UNUSED_TRANS, + M, N, UNUSED_INT, lda, KL, KU}; + calls.push_back(call); +} + +__attribute__((noinline)) +double cblas_ddot(int N, double* X, int incx, double* Y, int incy) { + BlasCall call = {inDerivative, CallType::DOT, + UNUSED_POINTER, X, Y, + UNUSED_DOUBLE, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); + return 3.15+N; +} + +// Y += alpha * X +__attribute__((noinline)) +void cblas_daxpy(int N, double alpha, double* X, int incx, double* Y, int incy) { +
BlasCall call = {inDerivative, CallType::AXPY, + Y, X, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); +} + // Y = alpha * op(A) * X + beta * Y __attribute__((noinline)) void cblas_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { @@ -443,6 +561,22 @@ void cblas_dcopy(int N, double* X, int incX, double* Y, int incY) { UNUSED_TRANS, UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, incY, UNUSED_INT}); } + +__attribute__((noinline)) +void cblas_dlacpy(char layout, char uplo, int M, int N, double* A, int lda, double* B, int ldb) { + calls.push_back((BlasCall){inDerivative, CallType::LACPY, + B, A, UNUSED_POINTER, + UNUSED_DOUBLE, UNUSED_DOUBLE, + layout, + uplo, UNUSED_TRANS, + M, N, UNUSED_INT, lda, ldb, UNUSED_INT}); +} + +__attribute__((noinline)) +void dlacpy(char *uplo, int *M, int* N, double* A, int *lda, double* B, int* ldb) { + cblas_dlacpy(CblasColMajor, *uplo, *M, *N, A, *lda, B, *ldb); +} + } enum class ValueType { @@ -450,6 +584,7 @@ enum class ValueType { Vector }; struct BlasInfo { + void* ptr; ValueType ty; int vec_length; int vec_increment; @@ -457,7 +592,8 @@ struct BlasInfo { int mat_rows; int mat_cols; int mat_ld; - BlasInfo (int length, int increment) { + BlasInfo (void* v_ptr, int length, int increment) { + ptr = v_ptr; ty = ValueType::Vector; vec_length = length; vec_increment = increment; @@ -466,7 +602,8 @@ struct BlasInfo { mat_cols = -1; mat_ld = -1; } - BlasInfo (char layout, int rows, int cols, int ld) { + BlasInfo (void* v_ptr, char layout, int rows, int cols, int ld) { + ptr = v_ptr; ty = ValueType::Matrix; vec_length = -1; vec_increment = -1; @@ -475,12 +612,25 @@ struct BlasInfo { mat_cols = cols; mat_ld = ld; } + BlasInfo () { + ptr = (void*)(-1); + ty = ValueType::Matrix; + vec_length = -1; + vec_increment = -1; + mat_layout = -1; + 
mat_rows = -1; + mat_cols = -1; + mat_ld = -1; + } }; -int pointer_to_index(void* v) { - if (v == A || v == dA) return 0; - if (v == B || v == dB) return 1; - if (v == C || v == dC) return 2; +BlasInfo pointer_to_index(void* v, BlasInfo inputs[6]) { + if (v == A || v == dA) return inputs[0]; + if (v == B || v == dB) return inputs[1]; + if (v == C || v == dC) return inputs[2]; + for (int i=3; i<6; i++) + if (inputs[i].ptr == v) + return inputs[i]; assert(0 && " illegal pointer to invert"); } @@ -570,17 +720,66 @@ void checkMatrix(BlasInfo info, std::string matname, char layout, int rows, int } } -void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vector & trace) { +void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test, const vector & trace) { switch (rcall.type) { + return; + case CallType::LASCL: { + auto A = pointer_to_index(rcall.pout_arg1, inputs); + + auto layout = rcall.layout; + auto type = rcall.targ1; + auto KL = rcall.iarg5; + auto KU = rcall.iarg6; + auto cfrom = rcall.farg1; + auto cto = rcall.farg2; + + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + + // = 'G': A is a full matrix. 
+ assert(type == 'G'); + + // A is an m-by-n matrix + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + return; + } + case CallType::AXPY: { + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + + auto X = pointer_to_index(rcall.pin_arg1, inputs); + + auto alpha = rcall.farg1; + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::DOT: { + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } case CallType::GEMV:{ // Y = alpha * op(A) * X + beta * Y - auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; - auto A = inputs[pointer_to_index(rcall.pin_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg2)]; + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; auto trans_char = rcall.targ1; - auto trans = (trans_char == 'N' || trans_char == 'n'); + auto trans = !(trans_char == 'N' || trans_char == 'n'); auto M = rcall.iarg1; auto N =rcall.iarg2; auto alpha = rcall.farg1; @@ -600,24 +799,26 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec // ( 1 + ( m - 1 )*abs( INCX ) ) otherwise. // Before entry, the incremented array X must contain the // vector x. - auto Xlen = trans ? N : M; + auto Xlen = trans ? M : N; checkVector(X, "X", /*len=*/Xlen, /*inc=*/incX, test, rcall, trace); // if no trans, Y must be M otherwise must be N - auto Ylen = trans ? M : N; + auto Ylen = trans ? 
N : M; checkVector(Y, "Y", /*len=*/Ylen, /*inc=*/incY, test, rcall, trace); return; } case CallType::GEMM:{ // C = alpha * A^transA * B^transB + beta * C - auto C = inputs[pointer_to_index(rcall.pout_arg1)]; - auto A = inputs[pointer_to_index(rcall.pin_arg1)]; - auto B = inputs[pointer_to_index(rcall.pin_arg2)]; + auto C = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto B = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; - auto transA = rcall.targ1; - auto transB = rcall.targ2; + auto transA_char = rcall.targ1; + auto transA = !(transA_char == 'N' || transA_char == 'n'); + auto transB_char = rcall.targ2; + auto transB = !(transB_char == 'N' || transB_char == 'n'); auto M = rcall.iarg1; auto N = rcall.iarg2; auto K = rcall.iarg3; @@ -655,16 +856,16 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec case CallType::SCAL: { auto N = rcall.iarg1; auto alpha = rcall.farg1; - auto X = inputs[pointer_to_index(rcall.pout_arg1)]; + auto X = pointer_to_index(rcall.pout_arg1, inputs); auto incX = rcall.iarg4; checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); return; } case CallType::GER: { // A = alpha * X * transpose(Y) + A - auto A = inputs[pointer_to_index(rcall.pout_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg1)]; - auto Y = inputs[pointer_to_index(rcall.pin_arg2)]; + auto A = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; auto M = rcall.iarg1; @@ -683,24 +884,44 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec return; } case CallType::COPY: { - auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg1)]; + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); auto N = rcall.iarg1; 
auto incX = rcall.iarg4; + auto incY = rcall.iarg5; checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::LACPY: { + auto B = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + + auto layout = rcall.layout; + auto uplo = rcall.targ1; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + auto ldb = rcall.iarg5; + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + checkMatrix(B, "B", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldb, test, rcall, trace); return; } default: printf("UNKNOWN CALL (%d)", (int)rcall.type); return; } } -void checkMemoryTrace(BlasInfo inputs[3], std::string test, const vector & trace) { +void checkMemoryTrace(BlasInfo inputs[6], std::string test, const vector & trace) { for (size_t i=0; i 2); + auto A_cache = (double*)foundCalls[0].pout_arg1; + cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M); + inputs[4] = BlasInfo(A_cache, layout, M, N, M); + auto B_cache = (double*)foundCalls[1].pout_arg1; + cblas_dcopy(trans ? M : N, B, incB, B_cache, 1); + inputs[5] = BlasInfo(B_cache, trans ? M : N, 1); + + ow_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + // dC = alpha * X * transpose(Y) + A + cblas_dger(layout, M, N, alpha, + trans ? B_cache : dC, + trans ? 1 : incC, + trans ? dC : B_cache, + trans ? incC : 1, dA, + lda); + + // dB = alpha * trans(A) * dC + dB + cblas_dgemv(layout, transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB); + + // dY = beta * dY + cblas_dscal(trans ? 
N : M, beta, dC, incC); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + inputs[4] = BlasInfo(); + inputs[5] = BlasInfo(); + } + + + } + } +} + +static void gemmTests() { + // N means normal matrix, T means transposed + for (char layout : { CblasRowMajor, CblasColMajor }) { + for (char transA : {'N', 'n', 'T', 't'}) { + for (char transB : {'N', 'n', 'T', 't'}) { + + { + + bool transA_bool = !(transA == 'N' || transA == 'n'); + bool transB_bool = !(transA == 'N' || transA == 'n'); + std::string Test = "GEMM"; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda), + /*B*/ BlasInfo(B, layout, transB_bool ? N : K , transA_bool ? K : N, incB), + /*C*/ BlasInfo(C, layout, M, N, incC), + BlasInfo(), + BlasInfo(), + BlasInfo() + }; + init(); + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::GEMM); + assert(calls[0].pout_arg1 == C); + assert(calls[0].pin_arg1 == A); + assert(calls[0].pin_arg2 == B); + assert(calls[0].farg1 == alpha); + assert(calls[0].farg2 == beta); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == transA); + assert(calls[0].targ2 == transB); + assert(calls[0].iarg1 == M); + assert(calls[0].iarg2 == N); + assert(calls[0].iarg3 == K); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == incB); + assert(calls[0].iarg6 == incC); + + // Check memory of primal on own. 
+ checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) my_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? transpose(transB) : transA, + transA_bool ? transA : transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B : dC, + transA_bool ? incB : incC, + transA_bool ? C : dB, + transA_bool ? incC : incB, + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? transB : transpose(transA), + transB_bool ? transA : transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A, + transB_bool ? incC : lda, + transB_bool ? A : dC, + transB_bool ? lda : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC /*, extra 0*/ ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). 
+ checkMemoryTrace(inputs, "Found " + Test, foundCalls); + } } } + } +} + +int main() { + + dotTests(); + + gemvTests(); + + // gemmTests(); } diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index d8a3d7c7fffa..5123aec615f1 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -53,24 +53,6 @@ static void checkBlasCallsInDag(const RecordKeeper &RK, checkBlasCallsInDag(RK, blasPatterns, blasName, arg); } } - - auto Def = cast(toSearch->getOperator())->getDef(); - if (Def->isSubClassOf("b")) { - auto numArgs = toSearch->getNumArgs(); - auto opName = Def->getValueAsString("s"); - auto CalledBlas = RK.getDef(opName); - if (!CalledBlas) - errs() << " opName: " << opName << "\n"; - assert(CalledBlas); - auto expectedNumArgs = - CalledBlas->getValueAsDag("PatternToMatch")->getNumArgs(); - if (expectedNumArgs != numArgs) { - errs() << "failed calling " << opName << " in the derivative of " - << blasName << " incorrect number of params. Expected " - << expectedNumArgs << " but got " << numArgs << "\n"; - assert(expectedNumArgs == numArgs); - } - } } /// Here we check that all the Blas derivatives who call another @@ -803,6 +785,7 @@ std::string get_blas_ret_ty(StringRef dfnc_name) { return "Builder2.getVoidTy()"; } +/* void emit_deriv_blas_call(DagInit *ruleDag, const StringMap &patternMap, StringSet<> &handled, raw_ostream &os) { @@ -834,7 +817,11 @@ void emit_deriv_blas_call(DagInit *ruleDag, if (Def->isSubClassOf("DiffeRetIndex")) { typeToAdd = "byRef ? 
PointerType::getUnqual(call.getType()) : " "call.getType()\n"; - } else if (Def->isSubClassOf("input") || Def->isSubClassOf("adj")) { + } else if (Def->isSubClassOf("adj")) { + auto argStr = Def->getValueAsString("name"); + // primary and adj have the same type + typeToAdd = (Twine("type_") + argStr).str(); + } else if (Def->isSubClassOf("input")) { auto argStr = Def->getValueAsString("name"); // primary and adj have the same type typeToAdd = (Twine("type_") + argStr).str(); @@ -926,6 +913,7 @@ void emit_deriv_blas_call(DagInit *ruleDag, << " }\n\n"; return; } +*/ void emit_tmp_creation(Record *Def, raw_ostream &os) { const auto args = Def->getValueAsListOfStrings("args"); @@ -997,7 +985,7 @@ void emit_deriv_rule(const StringMap &patternMap, Rule &rule, const auto nameMap = rule.getArgNameMap(); const auto Def = cast(ruleDag->getOperator())->getDef(); if (Def->isSubClassOf("b")) { - emit_deriv_blas_call(ruleDag, patternMap, handled, os); + // emit_deriv_blas_call(ruleDag, patternMap, handled, os); } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "noop") { // nothing to prepare } else if (Def->isSubClassOf("DiffeRetIndex")) { @@ -1015,103 +1003,137 @@ void emit_deriv_rule(const StringMap &patternMap, Rule &rule, const auto sub_Def = sub_def->getDef(); if (sub_Def->isSubClassOf("b")) { os << " //handling nested blas: " << std::to_string(i) << "\n"; - emit_deriv_blas_call(sub_Dag, patternMap, handled, os); + // emit_deriv_blas_call(sub_Dag, patternMap, handled, os); os << " //handled nested blas: " << std::to_string(i) << "\n"; } else if (sub_Def->isSubClassOf("FrobInnerProd")) { // nothing to prepare - assert(sub_Dag->getNumArgs() == 5); + assert(sub_Dag->getNumArgs() == 4); } else if (sub_Def->isSubClassOf("DiagUpdateSPMV")) { // nothing to prepare - assert(sub_Dag->getNumArgs() == 8); + assert(sub_Dag->getNumArgs() == 6); } } } } else if (Def->isSubClassOf("FrobInnerProd")) { // nothing to prepare - assert(ruleDag->getNumArgs() == 5); + 
assert(ruleDag->getNumArgs() == 4); } else if (Def->isSubClassOf("DiagUpdateSPMV")) { // nothing to prepare - assert(ruleDag->getNumArgs() == 8); + assert(ruleDag->getNumArgs() == 6); } else { PrintFatalError("Unhandled deriv Rule!"); } } -void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, - size_t actArg, size_t &pos, raw_ostream &os) { +// Emit the corresponding code rom (ruleDag arg # pos), given +// that the arg being differentiated is argAct. +// The map offsetToBaseNames takes vinc, ld, and maps them to +// the arg name of the original vector/matrix +void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, + raw_ostream &os) { const auto nameMap = rule.getArgNameMap(); const auto typeMap = rule.getArgTypeMap(); auto arg = ruleDag->getArg(pos); if (auto Dag = dyn_cast(arg)) { auto Def = cast(Dag->getOperator())->getDef(); - if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { - std::string tname, rname, cname; - tname = (Twine("arg_") + Dag->getArgNameStr(0)).str(); - if (DefInit *Def1 = dyn_cast(Dag->getArg(1))) { - auto Def1Name = Def1->getDef()->getValueAsString("name"); - assert(Def1->getDef()->isSubClassOf("adj")); - rname = (Twine("d_") + Def1Name).str(); - } else { - rname = (Twine("arg_") + Dag->getArgNameStr(1)).str(); + if (Def->isSubClassOf("MagicInst")) { + if (Def->getName() == "Rows") { + os << "get_blas_row(Builder2, "; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + rev_call_arg(Dag, rule, actArg, i, os); + os << ", "; + } + os << "byRef)"; + return; } - if (DefInit *Def2 = dyn_cast(Dag->getArg(2))) { - auto Def2Name = Def2->getDef()->getValueAsString("name"); - assert(Def2->getDef()->isSubClassOf("adj")); - cname = (Twine("d_") + Def2Name).str(); - } else { - cname = (Twine("arg_") + Dag->getArgNameStr(2)).str(); + if (Def->getName() == "ld") { + assert(Dag->getNumArgs() == 5); + //(ld $A, $transa, $lda, $m, $k) + const auto ldName = Dag->getArgNameStr(2); + const auto dim1Name = 
Dag->getArgNameStr(3); + const auto dim2Name = Dag->getArgNameStr(4); + const auto matName = Dag->getArgNameStr(0); + os << "{get_cached_mat_width(Builder2, "; + rev_call_arg(Dag, rule, actArg, 1, os); + os << ", arg_" << ldName << ", arg_" << dim1Name << ", arg_" << dim2Name + << ", cache_" << matName << ", byRef)}"; + return; } - os << "get_blas_row(Builder2, " << tname << ", " << rname << ", " << cname - << ", byRef)"; - } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { - assert(Dag->getNumArgs() == 5); - //(ld $A, $transa, $lda, $m, $k) - const auto transName = Dag->getArgNameStr(1); - const auto ldName = Dag->getArgNameStr(2); - const auto dim1Name = Dag->getArgNameStr(3); - const auto dim2Name = Dag->getArgNameStr(4); - const auto matName = Dag->getArgNameStr(0); - os << "get_cached_mat_width(Builder2, " - << "arg_" << transName << ", arg_" << ldName << ", arg_" << dim1Name - << ", arg_" << dim2Name << ", cache_" << matName << ", byRef)"; - } else { - errs() << Def->getName() << "\n"; - PrintFatalError("Dag/Def that isn't a DiffeRet!!"); } + + errs() << Def->getName() << "\n"; + PrintFatalError("Dag/Def that isn't a DiffeRet!!"); } else if (DefInit *DefArg = dyn_cast(arg)) { auto Def = DefArg->getDef(); if (Def->isSubClassOf("DiffeRetIndex")) { - os << "dif"; + os << "{dif}"; } else if (Def->isSubClassOf("adj")) { auto name = Def->getValueAsString("name"); - os << "d_" << name; + os << "{d_" << name; + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < rule.nameVec.size(); i++) { + if (rule.nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = rule.argTypesFull.lookup(argPosition); + auto incName = rule.nameVec[argPosition + 1]; + if (ty == ArgType::vincData || ty == ArgType::mldData) + os << ", arg_" << incName; + else + assert(ty == ArgType::fp 
|| ty == ArgType::ap); + os << "}"; } else if (Def->isSubClassOf("input")) { auto name = Def->getValueAsString("name"); - os << "input_" << name; + os << "{input_" << name; + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < rule.nameVec.size(); i++) { + if (rule.nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = rule.argTypesFull.lookup(argPosition); + auto incName = rule.nameVec[argPosition + 1]; + if (ty == ArgType::vincData) + os << ", (cache_" << name << " ? const_one : arg_" << incName << ")"; + else + assert(ty == ArgType::fp || ty == ArgType::ap || + ty == ArgType::mldData); + os << "}"; } else if (Def->isSubClassOf("use")) { auto name = Def->getValueAsString("name"); - os << "mat_" << name; - } else if (Def->isSubClassOf("MagicInst")) { - errs() << "MagicInst\n"; + os << "{mat_" << name << "}"; } else if (Def->isSubClassOf("Constant")) { auto val = Def->getValueAsString("value"); - os << "to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val + os << "{to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val << "), byRef, blasFPType, allocationBuilder, \"constant.fp." << val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("Char")) { auto val = Def->getValueAsString("value"); - os << "to_blas_callconv(Builder2, ConstantInt::get(charType, '" << val + os << "{to_blas_callconv(Builder2, ConstantInt::get(charType, '" << val << "'), byRef, nullptr, allocationBuilder, \"constant.char." << val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("ConstantInt")) { auto val = Def->getValueAsInt("value"); - os << "to_blas_callconv(Builder2, ConstantInt::get(intType, " << val + os << "{to_blas_callconv(Builder2, ConstantInt::get(intType, " << val << "), byRef, intType, allocationBuilder, \"constant.int." 
<< val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("transpose")) { auto name = Def->getValueAsString("name"); - os << "arg_transposed_" << name; + os << "{arg_transposed_" << name << "}"; } else { errs() << Def->getName() << "\n"; PrintFatalError("Def that isn't a DiffeRet!"); @@ -1132,74 +1154,42 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, // and based on that get the fp/int + scalar/vector type auto ty = typeMap.lookup(argPosition); - // Now we create the adj call args through concating type and primal name - if (ty == ArgType::len) { - os << "arg_" << name; - } else if (ty == ArgType::fp || ty == ArgType::ap || - ty == ArgType::vincData) { + switch (ty) { + case ArgType::cblas_layout: + case ArgType::len: + case ArgType::fp: + case ArgType::ap: + case ArgType::trans: + case ArgType::diag: + case ArgType::uplo: + case ArgType::side: + case ArgType::vincInc: + case ArgType::vincData: + case ArgType::mldData: { + os << "{"; if (argPosition == actArg) { os << "d_" << name; } else { os << "arg_" << name; } - } else if (ty == ArgType::vincInc) { - auto prevArg = ruleDag->getArg(pos - 1); - if (DefInit *DefArg = dyn_cast(prevArg)) { - auto Def = DefArg->getDef(); - if (Def->isSubClassOf("adj")) { - // all ok, single inc after shadow of vec - // use original inc, since shadow is never cached - os << "arg_" << name; - } else { - auto prevName = Def->getValueAsString("name"); - os << "(cache_" << prevName << " ? const_one : arg_" << name << ")"; - } - } else { - auto prevName = ruleDag->getArgNameStr(pos - 1); - os << "(cache_" << prevName << " ? 
const_one : arg_" << name << ")"; - } - } else if (ty == ArgType::mldData) { - // TODO: update this to use width_ instead of true_, - // similar to the vector inc case - auto nextName = ruleDag->getArgNameStr(pos + 1); - // get the position of the argument in the primary blas call - auto nextArgPosition = nameMap.lookup(nextName); - // and based on that get the fp/int + scalar/vector type - auto nextTy = typeMap.lookup(nextArgPosition); - if (pos == actArg) { - assert(nextTy == ArgType::mldLD); - os << "d_" << name << ", true_" << nextName; - pos++; // extra ++ due to also handling mldLD - } else { - // if this matrix got cached, we need more complex logic - // to determine the next arg. Thus handle it once we reach it - os << "arg_" << name; + if (ty == ArgType::vincData) { + auto incName = rule.nameVec[argPosition + 1]; + os << ", (cache_" << name << " ? const_one : arg_" << incName << ")"; } - } else if (ty == ArgType::mldLD) { - auto prevArg = ruleDag->getArg(pos - 1); - if (DefInit *DefArg = dyn_cast(prevArg)) { - auto Def = DefArg->getDef(); - if (Def->isSubClassOf("adj")) { - // all ok, single LD after shadow of mat - // use original ld, since shadow is never cached - os << "arg_" << name; + if (ty == ArgType::mldData) { + auto ldName = rule.nameVec[argPosition + 1]; + if (argPosition == actArg) { + os << ", true_" << ldName; } else { - errs() << rule.to_string() << "\n"; - PrintFatalError("sholdn't be hit?\n"); + // if this matrix got cached, we need more complex logic + // to determine the next arg. 
Thus handle it once we reach it } - } else { - errs() << rule.to_string() << "\n"; - llvm::errs() << "name: " << name << " typename: " << ty << "\n"; - PrintFatalError("shouldn't be hit??\n"); } - } else if (ty == ArgType::trans || ty == ArgType::diag || - ty == ArgType::uplo || ty == ArgType::side) { - os << "arg_" << name; - // Extra handled in the calling function, so - // if we are here for a layout arg something went wrong (error) - //} else if (ty == ArgType::cblas_layout) { - // os << "arg_" << name; - } else { + + os << "}"; + return; + } + default: errs() << "name: " << name << " typename: " << ty << "\n"; llvm_unreachable("unimplemented input type in reverse mode!\n"); } @@ -1208,7 +1198,7 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, // fill the result string and return the number of added args void rev_call_args(StringRef argName, Rule &rule, size_t actArg, - raw_ostream &os, int subRule = -1) { + raw_ostream &os, int subRule, StringRef func) { const auto nameMap = rule.getArgNameMap(); @@ -1221,42 +1211,36 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg, numArgs = ruleDag->getNumArgs(); } + os << " std::vector" << argName << ";\n"; + // layout exist only under the cBLas ABI and not for all fncs. bool fncHasLayout = (ruleDag->getArgNameStr(0) == "layout"); - if (!fncHasLayout) { - os << " std::vector" << argName << " = {"; - for (size_t pos = 0; pos < numArgs;) { - if (pos > 0) { - os << ", "; - } - - rev_call_arg(argName, ruleDag, rule, actArg, pos, os); - pos++; - } - os << "};\n"; - return; + if (fncHasLayout) { + // Fnc has a layout if cBLAS, that makes it more complex. + // Distinguish later trough byRef if it is cblas (thus has layout) + os << " if (!byRef) " << argName << ".push_back(arg_layout);\n"; } - // Fnc has a layout if cBLAS, that makes it more complex. 
- // Distinguish later trough byRef if it is cblas (thus has layout) - - os << " std::vector" << argName << ";\n"; - os << " if (!byRef) " << argName << ".push_back(arg_layout);\n"; - os << " auto tmp = {\n"; - // just replace argOps with rule - for (size_t pos = 1; pos < numArgs;) { - if (pos > 1) { - os << ", "; - } - rev_call_arg(argName, ruleDag, rule, actArg, pos, os); - pos++; + for (size_t pos = fncHasLayout ? 1 : 0; pos < numArgs; pos++) { + os << " for (auto item : "; + rev_call_arg(ruleDag, rule, actArg, pos, os); + os << ") " << argName << ".push_back(item);\n"; } - os << "};\n"; - os << " for (auto val : tmp) " << argName << ".push_back(val);\n"; + os << " if (byRef) {\n"; + int n = 0; + if (func == "gemv") + n = 1; + if (func == "gemm") + n = 2; + for (int i = 0; i < n; i++) + os << " " << argName + << ".push_back(ConstantInt::get(intType, 1));\n"; + os << " }\n"; } void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, StringRef bb, raw_ostream &os) { + os << "{\n"; if (dfnc_name == "inner_prod") { os << " auto derivcall_inner_prod = \n" " getorInsertInnerProd(" @@ -1268,6 +1252,21 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, << " CallInst *cubcall = " "cast(derivcall_inner_prod);\n"; } else { + os << " SmallVector tys; for (auto arg : " << argName + << ") tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name << " = FunctionType::get(" + << dfnc_ret_ty << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" << dfnc_name + << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " CallInst *cubcall = " "cast(" << bb << 
".CreateCall(derivcall_" << dfnc_name << ", " << argName @@ -1282,6 +1281,7 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, << " addToDiffe(orig_" << name << ", cubcall, " << bb << ", fpType);\n" << " }\n"; + os << "}\n"; } // todo: update rt_active_ to use actual dag requirements, @@ -1500,11 +1500,11 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", (ty == ArgType::fp), os); - rev_call_args("args1", rule, actArg, os); + const auto dfnc_name = Def->getValueAsString("s"); + rev_call_args("args1", rule, actArg, os, -1, dfnc_name); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; - const auto dfnc_name = Def->getValueAsString("s"); if (ty == ArgType::fp) { // extra handling, since we will update only a fp scalar as part of the // return struct it's presumably done by setting it to the value @@ -1512,8 +1512,23 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, emit_fret_call(dfnc_name, "ArrayRef(args1)", name, "Builder2", os); } else { + os << " SmallVector tys; for (auto arg : args1) " + "tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name + << " = FunctionType::get(" << dfnc_ret_ty << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" << dfnc_name + << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " Builder2.CreateCall(derivcall_" << dfnc_name - << ", ArrayRef(args1), Defs);\n"; + << ", args1, Defs);\n"; } emit_runtime_continue(ruleDag, name, " ", "Builder2", (ty == ArgType::fp), 
os); @@ -1525,7 +1540,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, os << " // DiagUpdateSPMV\n"; emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os); - rev_call_args("args1", rule, actArg, os); + rev_call_args("args1", rule, actArg, os, -1, ""); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; // Now that we have the defs, we can create the call @@ -1541,7 +1556,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, os << " // FrobInnerProd\n"; emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os); - rev_call_args("args1", rule, actArg, os); + rev_call_args("args1", rule, actArg, os, -1, ""); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; // Now that we have the defs, we can create the call @@ -1565,32 +1580,52 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, // handle seq rules for (size_t i = 0; i < ruleDag->getNumArgs(); i++) { - std::string argName = "args" + std::to_string(i); - rev_call_args(argName, rule, actArg, os, i); Init *subArg = ruleDag->getArg(i); DagInit *sub_Dag = cast(subArg); if (auto sub_def = dyn_cast(sub_Dag->getOperator())) { const auto sub_Def = sub_def->getDef(); if (sub_Def->isSubClassOf("b")) { const auto dfnc_name = sub_Def->getValueAsString("s"); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, dfnc_name); os << " //handling nested blas: " << std::to_string(i) << "\n"; - emit_deriv_blas_call(sub_Dag, patternMap, handled, os); + // emit_deriv_blas_call(sub_Dag, patternMap, handled, os); if (get_blas_ret_ty(dfnc_name) == "fpType") { // returns, so assume it's the last step of the sequence // and update the diffe accordingly assert(i == ruleDag->getNumArgs() - 1); emit_fret_call(dfnc_name, argName, 
name, "Builder2", os); } else { + os << " SmallVector tys; for (auto arg : " << argName + << ") tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name + << " = FunctionType::get(" << dfnc_ret_ty + << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" + << dfnc_name << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " Builder2.CreateCall(derivcall_" << dfnc_name << ", " << argName << ", Defs);\n"; } os << " //handled nested blas: " << std::to_string(i) << "\n"; } else if (sub_Def->isSubClassOf("FrobInnerProd")) { - assert(sub_Dag->getNumArgs() == 5); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, ""); + assert(sub_Dag->getNumArgs() == 4); assert(ty == ArgType::fp); emit_fret_call("inner_prod", argName, name, "Builder2", os); } else if (sub_Def->isSubClassOf("DiagUpdateSPMV")) { - assert(sub_Dag->getNumArgs() == 8); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, ""); + assert(sub_Dag->getNumArgs() == 6); assert(ty == ArgType::ap); os << "callSPMVDiagUpdate(Builder2, *gutils->oldFunc->getParent(), " "blas, intType, blasCharType, blasFPType, type_vec_like, " diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index cefbec9b0ba7..bede971bc5af 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -281,7 +281,7 @@ os << " if (EnzymeBlasCopy) {\n" << " auto *len2 = load_if_ref(BuilderZ, intType, N, byRef);\n" << " auto *matSize = BuilderZ.CreateMul(len1, len2);\n" << " auto malins = CreateAllocation(BuilderZ, fpType, 
matSize, \"cache." << matName << "\");\n" -<< " ValueType valueTypes[] = {" << valueTypes << "};\n" +<< " SmallVector valueTypes = {" << valueTypes << "};\n" <<" valueTypes[" << argIdx << "] = ValueType::Primal;\n" << " if (byRef) valueTypes[" << argIdx+1 << "] = ValueType::Primal;\n"; for (auto len_pos : dimensions ) { @@ -290,7 +290,9 @@ os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; os << " if (EnzymeLapackCopy) {\n" << " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L << " uplo = to_blas_callconv(BuilderZ, uplo, byRef, nullptr, allocationBuilder, \"copy.garbage\");\n" -<< " Value *args[7] = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" +<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" +<< " if (!byRef) {\n" +<< " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" << " callMemcpyStridedLapack(BuilderZ, *gutils->oldFunc->getParent(), blas, args, gutils->getInvertedBundles(&call, valueTypes, BuilderZ, /*lookup*/false));\n" << " } else {\n" << " auto dmemcpy = getOrInsertMemcpyMat(*gutils->oldFunc->getParent(), fpType, cast(malins->getType()), intType, 0, 0);\n" diff --git a/enzyme/tools/enzyme-tblgen/datastructures.cpp b/enzyme/tools/enzyme-tblgen/datastructures.cpp index d38dddaa456c..2d7a14d26e9f 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.cpp +++ b/enzyme/tools/enzyme-tblgen/datastructures.cpp @@ -60,11 +60,13 @@ bool isVecLikeArg(ArgType ty) { return false; } -bool isArgUsed(StringRef toFind, const DagInit *toSearch) { +bool isArgUsed(StringRef toFind, const DagInit *toSearch, + ArrayRef nameVec, + const DenseMap &argTypesFull) { for (size_t i = 0; i < toSearch->getNumArgs(); i++) { if (DagInit *arg = dyn_cast(toSearch->getArg(i))) { // os << " Recursing. 
Magic!\n"; - if (isArgUsed(toFind, arg)) + if (isArgUsed(toFind, arg, nameVec, argTypesFull)) return true; } else { auto name = toSearch->getArgNameStr(i); @@ -80,30 +82,79 @@ bool isArgUsed(StringRef toFind, const DagInit *toSearch) { if (toFind == transName) { return true; } - } else if (opName == "adj" || Def->isSubClassOf("adj")) { + } else if (opName == "adj" || Def->isSubClassOf("adj") || + opName == "input" || Def->isSubClassOf("input")) { // shadow is unrelated, ignore it + // However, consider the extra added inc. + + auto name = Def->getValueAsString("name"); + + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < nameVec.size(); i++) { + if (nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = argTypesFull.lookup(argPosition); + if (ty == ArgType::vincData || + ((opName == "adj" || Def->isSubClassOf("adj")) && + ty == ArgType::mldData)) { + auto incName = nameVec[argPosition + 1]; + if (incName == toFind) + return true; + } } } else { if (name == toFind) { return true; } + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < nameVec.size(); i++) { + if (nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = argTypesFull.lookup(argPosition); + if (ty == ArgType::vincData || ty == ArgType::mldData) { + auto incName = nameVec[argPosition + 1]; + if (incName == toFind) + return true; + } } } } return false; } -Rule::Rule(DagInit *dag, size_t activeArgIdx, +Rule::Rule(ArrayRef nameVec, DagInit *dag, size_t activeArgIdx, const StringMap &patternArgs, const DenseMap &patternTypes, const DenseSet &patternMutables) - : rewriteRule(dag), activeArg(activeArgIdx) { + : 
rewriteRule(dag), activeArg(activeArgIdx), + nameVec(nameVec.begin(), nameVec.end()) { // For each arg found in the dag: // 1) copy patternArgs to ruleArgs if arg shows up in this rule for (auto argName : patternArgs.keys()) { assert(patternArgs.count(argName) == 1); size_t argPos = patternArgs.lookup(argName); - bool argUsedInRule = isArgUsed(argName, rewriteRule); + argTypesFull.insert(*patternTypes.find(argPos)); + } + for (auto argName : patternArgs.keys()) { + assert(patternArgs.count(argName) == 1); + size_t argPos = patternArgs.lookup(argName); + bool argUsedInRule = isArgUsed(argName, rewriteRule, nameVec, argTypesFull); if (argUsedInRule) { argNameToPos.insert(std::pair(argName, argPos)); // 2) look up and copy the corresponding argType @@ -331,7 +382,7 @@ TGPattern::TGPattern(Record *r) : blasName(r->getNameInitAsString()) { DagInit *derivRule = cast(derivOp.value()); size_t actIdx = posActArgs[derivOp.index()]; rules.push_back( - Rule(derivRule, actIdx, argNameToPos, argTypes, mutables)); + Rule(args, derivRule, actIdx, argNameToPos, argTypes, mutables)); } } diff --git a/enzyme/tools/enzyme-tblgen/datastructures.h b/enzyme/tools/enzyme-tblgen/datastructures.h index 3f18e541e870..3416f3105491 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.h +++ b/enzyme/tools/enzyme-tblgen/datastructures.h @@ -40,7 +40,9 @@ using namespace llvm; const char *TyToString(ArgType ty); bool isVecLikeArg(ArgType ty); -bool isArgUsed(StringRef toFind, const DagInit *toSearch); +bool isArgUsed(StringRef toFind, const DagInit *toSearch, + llvm::ArrayRef nameVec, + const llvm::DenseMap &argTypesFull); /// Subset of the general pattern info, /// but only the part that affects the specific argument being active. 
@@ -55,7 +57,10 @@ class Rule { bool BLASLevel2or3; public: - Rule(DagInit *dag, size_t activeArgIdx, const StringMap &patternArgs, + SmallVector nameVec; + DenseMap argTypesFull; + Rule(ArrayRef nameVec, DagInit *dag, size_t activeArgIdx, + const StringMap &patternArgs, const DenseMap &patternTypes, const DenseSet &patternMutables); bool isBLASLevel2or3() const;