Skip to content

Commit

Permalink
Add dot blas test (#1446)
Browse files Browse the repository at this point in the history
* Add dot blas test

* Now adding caching test

* Cleanup and progress

* cleanup

* Fix autodiff ordering with inlining

* Rebase

* Fix fortran calling conv
  • Loading branch information
wsmoses authored Sep 20, 2023
1 parent 53a15ad commit edd0331
Show file tree
Hide file tree
Showing 28 changed files with 1,082 additions and 481 deletions.
52 changes: 25 additions & 27 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def scal : CallBlasPattern<(Op $n, $alpha, $x, $incx),
["x"],[len, fp, vinc<["n"]>],
[
// dot must proceed scal, because scal modifies adj<"x">
(b<"dot"> $n, $x, $incx, adj<"x">, $incx),
(b<"scal"> $n, $alpha, adj<"x">, $incx)
(b<"dot"> $n, $x, adj<"x">),
(b<"scal"> $n, $alpha, adj<"x">)
]
>;

Expand All @@ -134,17 +134,17 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n,
def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy),
["y"],[len, fp, vinc<["n"]>, vinc<["n"]>],
[
(b<"dot"> $n, adj<"y">, $incy, $x, $incx),
(b<"axpy"> $n, $alpha, adj<"y">, $incy, adj<"x">, $incx),
(b<"dot"> $n, adj<"y">, $x),
(b<"axpy"> $n, $alpha, adj<"y">, adj<"x">),
(noop) // y = alpha*x + y, so nothing to do here
]
>;

def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
[],[len, vinc<["n"]>, vinc<["n"]>],
[
(b<"axpy"> $n, DiffeRet, $y, $incy, adj<"x">, $incx),
(b<"axpy"> $n, DiffeRet, $x, $incx, adj<"y">, $incy)
(b<"axpy"> $n, DiffeRet, $y, adj<"x">),
(b<"axpy"> $n, DiffeRet, $x, adj<"y">),
]
>;

Expand All @@ -158,7 +158,7 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
["y"],[len, vinc<["n"]>, vinc<["n"]>],
[
(noop),// copy moves x into y, so x is never modified.
(b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx)
(b<"axpy"> $n, Constant<"1.0">, adj<"y">, adj<"x">)
]
>;

Expand All @@ -184,22 +184,20 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $
["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>],
[
/* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]>
(b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)),
(b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)),

//if (is_normal $transa) {
// call sger(m, n, alpha, ya, incy, x, incx, Aa, lda)
//} else {
// call sger(m, n, alpha, x, incx, ya, incy, Aa, lda)
//}
/* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x),
(Rows $transa, $incy, $incx),
(Rows $transa, $x, adj<"y">),
(Rows $transa, $incx, $incy),
adj<"A">, $lda),
/* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx),
/* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy),
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy)
adj<"A">),
/* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">),
/* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">),
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">)
]
>;
//
Expand All @@ -226,11 +224,11 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A

/* alpha */ (Seq<["AB", "product", "m", "n"]>
(b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)),
/* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda),
/* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>)
(FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)),
/* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">),
/* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $beta, adj<"B">),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>)
]
>;

Expand All @@ -239,14 +237,14 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta
[cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, fp, vinc<["n"]>],
[
/* alpha */ (Seq<["y0", "triangular", "n"]>
(b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, $incx, Constant<"0.0">, use<"y0">, ConstantInt<1>),
(b<"dot"> $n, adj<"y">, $incy, use<"y0">, ConstantInt<1>)),
(b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, Constant<"0.0">, use<"y0">, ConstantInt<1>),
(b<"dot"> $n, adj<"y">, use<"y0">, ConstantInt<1>)),
/* ap */ (Seq<[]>
(b<"spr2"> $layout, $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">),
(DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">)),
/* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx),
/* beta */ (b<"dot"> $n, adj<"y">, $incy, input<"y">, $incy),
/* y */ (b<"scal"> $n, $beta, adj<"y">, $incy)
(b<"spr2"> $layout, $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">),
(DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">)),
/* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, Constant<"1.0">, adj<"x">),
/* beta */ (b<"dot"> $n, adj<"y">, input<"y">),
/* y */ (b<"scal"> $n, $beta, adj<"y">)
]
>;

Expand Down
144 changes: 76 additions & 68 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,8 @@ class EnzymeBase {
Type *retElemType, SmallVectorImpl<Value *> &args,
const std::map<int, Type *> &byVal,
const std::vector<DIFFE_TYPE> &constants, Function *fn,
DerivativeMode mode, Options &options, bool sizeOnly) {
DerivativeMode mode, Options &options, bool sizeOnly,
SmallVectorImpl<CallInst *> &calls) {
auto &differet = options.differet;
auto &tape = options.tape;
auto &width = options.width;
Expand Down Expand Up @@ -1702,63 +1703,13 @@ class EnzymeBase {
}

ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode);

if (Logic.PostOpt) {
auto Params = llvm::getInlineParams();

llvm::SetVector<CallInst *> Q;
Q.insert(diffretc);
while (Q.size()) {
auto cur = *Q.begin();
Function *outerFunc = cur->getParent()->getParent();
llvm::OptimizationRemarkEmitter ORE(outerFunc);
Q.erase(Q.begin());
if (auto F = cur->getCalledFunction()) {
if (!F->empty()) {
// Garbage collect AC's created
SmallVector<AssumptionCache *, 2> ACAlloc;
auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
auto AC = new AssumptionCache(F);
ACAlloc.push_back(AC);
return *AC;
};
auto GetTLI =
[&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
};

auto GetInlineCost = [&](CallBase &CB) {
TargetTransformInfo TTI(F->getParent()->getDataLayout());
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
return cst;
};
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
InlineFunctionInfo IFI;
InlineResult IR = InlineFunction(*cur, IFI);
if (IR.isSuccess()) {
LowerSparsification(outerFunc, /*replaceAll*/ false);
for (auto U : outerFunc->users()) {
if (auto CI = dyn_cast<CallInst>(U)) {
if (CI->getCalledFunction() == outerFunc) {
Q.insert(CI);
}
}
}
}
}
for (auto AC : ACAlloc) {
delete AC;
}
}
}
}
}
return true;
calls.push_back(diffretc);
return diffret;
}

/// Return whether successful
bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode,
bool sizeOnly) {
bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly,
SmallVectorImpl<CallInst *> &calls) {

// determine function to differentiate
Function *fn = parseFunctionParameter(CI);
Expand Down Expand Up @@ -1796,16 +1747,17 @@ class EnzymeBase {

#if LLVM_VERSION_MAJOR >= 16
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.value(),
sizeOnly);
byVal, constants, fn, mode, options.value(), sizeOnly,
calls);
#else
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.getValue(),
sizeOnly);
sizeOnly, calls);
#endif
}

bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
bool HandleProbProg(CallInst *CI, ProbProgMode mode,
SmallVectorImpl<CallInst *> &calls) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
Expand Down Expand Up @@ -1928,13 +1880,15 @@ class EnzymeBase {
}

#if LLVM_VERSION_MAJOR >= 16
bool status = HandleAutoDiff(
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false);
bool status =
HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal,
constants, newFunc, DerivativeMode::ReverseModeCombined,
opt.value(), false, calls);
#else
bool status = HandleAutoDiff(
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false);
bool status =
HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal,
constants, newFunc, DerivativeMode::ReverseModeCombined,
opt.getValue(), false, calls);
#endif

delete interface;
Expand Down Expand Up @@ -2447,17 +2401,19 @@ class EnzymeBase {
Changed = true;
}

SmallVector<CallInst *, 1> calls;

// Perform all the size replacements first to create constants
for (auto pair : toSize) {
bool successful = HandleAutoDiffArguments(pair.first, pair.second,
/*sizeOnly*/ true);
/*sizeOnly*/ true, calls);
Changed = true;
if (!successful)
break;
}
for (auto pair : toLower) {
bool successful = HandleAutoDiffArguments(pair.first, pair.second,
/*sizeOnly*/ false);
/*sizeOnly*/ false, calls);
Changed = true;
if (!successful)
break;
Expand Down Expand Up @@ -2495,7 +2451,59 @@ class EnzymeBase {
}

for (auto &&[call, mode] : toProbProg) {
HandleProbProg(call, mode);
HandleProbProg(call, mode, calls);
}

if (Logic.PostOpt) {
auto Params = llvm::getInlineParams();

llvm::SetVector<CallInst *> Q;
for (auto call : calls)
Q.insert(call);
while (Q.size()) {
auto cur = *Q.begin();
Function *outerFunc = cur->getParent()->getParent();
llvm::OptimizationRemarkEmitter ORE(outerFunc);
Q.erase(Q.begin());
if (auto F = cur->getCalledFunction()) {
if (!F->empty()) {
// Garbage collect AC's created
SmallVector<AssumptionCache *, 2> ACAlloc;
auto getAC = [&](Function &F) -> llvm::AssumptionCache & {
auto AC = new AssumptionCache(F);
ACAlloc.push_back(AC);
return *AC;
};
auto GetTLI =
[&](llvm::Function &F) -> const llvm::TargetLibraryInfo & {
return Logic.PPC.FAM.getResult<TargetLibraryAnalysis>(F);
};

auto GetInlineCost = [&](CallBase &CB) {
TargetTransformInfo TTI(F->getParent()->getDataLayout());
auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI);
return cst;
};
if (llvm::shouldInline(*cur, GetInlineCost, ORE)) {
InlineFunctionInfo IFI;
InlineResult IR = InlineFunction(*cur, IFI);
if (IR.isSuccess()) {
LowerSparsification(outerFunc, /*replaceAll*/ false);
for (auto U : outerFunc->users()) {
if (auto CI = dyn_cast<CallInst>(U)) {
if (CI->getCalledFunction() == outerFunc) {
Q.insert(CI);
}
}
}
}
}
for (auto AC : ACAlloc) {
delete AC;
}
}
}
}
}

if (Changed && EnzymeAttributor) {
Expand Down
33 changes: 22 additions & 11 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,8 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
std::string copy_name = (blas.floatType + "lacpy" + blas.suffix).str();
std::string copy_name =
(blas.prefix + blas.floatType + "lacpy" + blas.suffix).str();

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand Down Expand Up @@ -2554,14 +2555,16 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V) {
// } else {
// ld_A = arg_lda;
// }
llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans,
llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
llvm::ArrayRef<llvm::Value *> trans,
llvm::Value *arg_ld, llvm::Value *dim1,
llvm::Value *dim2, bool cacheMat,
bool byRef) {
if (!cacheMat)
return arg_ld;

Value *width = B.CreateSelect(is_normal(B, trans, byRef), dim1, dim2);
assert(trans.size() == 1);
Value *width = CreateSelect(B, is_normal(B, trans[0], byRef), dim1, dim2);

return width;
}
Expand Down Expand Up @@ -2593,19 +2596,27 @@ llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::IntegerType *intType,
return B.CreateLoad(intType, VP);
}

llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans,
llvm::Value *row, llvm::Value *col, bool byRef) {

SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
ArrayRef<llvm::Value *> transA,
ArrayRef<llvm::Value *> row,
ArrayRef<llvm::Value *> col,
bool byRef) {
assert(transA.size() == 1);
auto trans = transA[0];
if (byRef) {
auto charType = IntegerType::get(trans->getContext(), 8);
trans = B.CreateLoad(charType, trans, "ld.row.trans");
}

return B.CreateSelect(
B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))),
row, col);
auto cond = B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));
assert(row.size() == col.size());
SmallVector<Value *, 1> toreturn;
for (size_t i = 0; i < row.size(); i++) {
toreturn.push_back(B.CreateSelect(cond, row[i], col[i]));
}
return toreturn;
}

// return how many Special pointers are in T (count > 0),
Expand Down
Loading

0 comments on commit edd0331

Please sign in to comment.