Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Blas diffuse on P^3 computation #1526

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,10 @@ Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT,
#if LLVM_VERSION_MAJOR >= 15
if (Mod.getContext().supportsTypedPointers()) {
#endif
assert(PT->getPointerElementType() == elementType);
#if LLVM_VERSION_MAJOR >= 13
if (!PT->isOpaquePointerTy())
#endif
assert(PT->getPointerElementType() == elementType);
#if LLVM_VERSION_MAJOR >= 15
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,10 @@ static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode,
if (!arg->getContext().supportsTypedPointers()) {
return DIFFE_TYPE::DUP_ARG;
}
#elif LLVM_VERSION_MAJOR >= 13
if (arg->isOpaquePointerTy()) {
return DIFFE_TYPE::DUP_ARG;
}
#endif
switch (whatType(arg->getPointerElementType(), mode, integersAreConstant,
seen)) {
Expand Down
184 changes: 184 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -ge 14 ] ; then %opt < %s %loadEnzyme -opaque-pointers -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -jump-threading -adce -S | FileCheck %s; fi
; RUN: if [ %llvmver -ge 14 ]; then %opt < %s %newLoadEnzyme -opaque-pointers -passes="enzyme,function(mem2reg,early-cse,instsimplify,jump-threading,adce)" -enzyme-preopt=false -S | FileCheck %s ; fi

; ModuleID = '../examples/big/big_inlined_correctness.cpp'
source_filename = "../examples/big/big_inlined_correctness.cpp"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.Prod = type { ptr, double }

declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc)

; Differentiated kernel: two chained dgemm_ calls that both read %rhs; the
; second call writes through the pointer loaded from field 0 of %P, and field 1
; of %P is then set to 0.0.
; Function Attrs: mustprogress noinline nounwind uwtable
define dso_local void @_Z3mulR4ProdPd(ptr nocapture noundef nonnull align 8 dereferenceable(16) %P, ptr noalias nocapture noundef readonly %rhs) {
entry:
; By-reference scalar arguments handed to dgemm_.
%N = alloca i8, align 1
%ten = alloca i32, align 4
%one = alloca double, align 8
%zero = alloca double, align 8
; Zero-initialised 32-byte scratch buffer (room for four doubles); it is never
; freed inside this function.
%calloc = call dereferenceable_or_null(32) ptr @calloc(i64 1, i64 32)
; 78 is ASCII 'N'.
store i8 78, ptr %N, align 1
; Despite its name, %ten holds 2 here; it is passed for every integer argument.
store i32 2, ptr %ten, align 4
store double 1.000000e+00, ptr %one, align 8
store double 0.000000e+00, ptr %zero, align 8
; First product: both matrix inputs are %rhs, the scale slot before the output
; is %one, and the output is the scratch buffer.
%call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten)
; Output pointer stored in field 0 of %struct.Prod.
%0 = load ptr, ptr %P, align 8
; Second product: inputs are the scratch buffer and %rhs, the scale slot before
; the output is %zero, and the output is the pointer loaded from %P.
%call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten)
; Field 1 of %struct.Prod (the double member) is overwritten with 0.0.
%alpha = getelementptr inbounds %struct.Prod, ptr %P, i64 0, i32 1
store double 0.000000e+00, ptr %alpha, align 8
ret void
}

declare noalias ptr @malloc(i64)

; Active function passed to the autodiff entry point: builds a %struct.Prod on
; the stack, runs @_Z3mulR4ProdPd on it with %P as the matrix input, and returns
; the first double of the output buffer.
; Function Attrs: mustprogress nounwind uwtable
define dso_local double @_Z8simulatePd(ptr nocapture noundef readonly %P) {
entry:
%M = alloca %struct.Prod, align 8
; 32-byte output buffer (room for four doubles); it is never freed here.
%call = tail call noalias dereferenceable_or_null(32) ptr @malloc(i64 noundef 32)
; Field 0 of %M <- output buffer pointer.
store ptr %call, ptr %M, align 8
; Field 1 of %M <- 1.0.
%alpha = getelementptr inbounds %struct.Prod, ptr %M, i64 0, i32 1
store double 1.000000e+00, ptr %alpha, align 8
call void @_Z3mulR4ProdPd(ptr noundef nonnull align 8 dereferenceable(16) %M, ptr noundef %P)
; Reload the output pointer from %M and return its first element.
%0 = load ptr, ptr %M, align 8
%1 = load double, ptr %0, align 8
ret double %1
}

; Entry point of the test: requests the derivative of @_Z8simulatePd, passing
; %A together with %Adup under the "enzyme_dup" annotation.
define void @caller(ptr %A, ptr %Adup) {
entry:
call void (...) @_Z17__enzyme_autodiffz(ptr noundef nonnull @_Z8simulatePd, metadata !"enzyme_dup", ptr noundef nonnull %A, ptr noundef nonnull %Adup)
ret void
}

declare void @_Z17__enzyme_autodiffz(...)

declare noalias noundef ptr @calloc(i64 noundef, i64 noundef)

; we must actually save or set the matmul
; CHECK: define internal void @diffe_Z3mulR4ProdPd(ptr nocapture align 8 dereferenceable(16) %P, ptr nocapture align 8 %"P'", ptr noalias nocapture readonly %rhs, ptr nocapture %"rhs'", { ptr, ptr, ptr, ptr } %tapeArg)
; CHECK-NEXT: invertentry:
; CHECK-NEXT: %byref.transpose.transb = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8
; CHECK-NEXT: %byref.transpose.transa = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.05 = alloca double, align 8
; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1
; CHECK-NEXT: %byref.constant.int.0 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.int.06 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double, align 8
; CHECK-NEXT: %[[i0:.+]] = alloca i32, align 4
; CHECK-NEXT: %byref.transpose.transb11 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.014 = alloca double, align 8
; CHECK-NEXT: %byref.transpose.transa16 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.019 = alloca double, align 8
; CHECK-NEXT: %byref.constant.char.G20 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.int.021 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.int.022 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.fp.1.023 = alloca double, align 8
; CHECK-NEXT: %[[i1:.+]] = alloca i32, align 4
; CHECK-NEXT: %malloccall3 = alloca i8, i64 8, align 8
; CHECK-NEXT: %malloccall = alloca i8, i64 1, align 1
; CHECK-NEXT: %malloccall2 = alloca i8, i64 8, align 8
; CHECK-NEXT: %malloccall1 = alloca i8, i64 4, align 4
; CHECK-NEXT: %"calloc'mi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 2
; CHECK-NEXT: %calloc = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 3
; CHECK-NEXT: store i8 78, ptr %malloccall, align 1
; CHECK-NEXT: store i32 2, ptr %malloccall1, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %malloccall2, align 8
; CHECK-NEXT: store double 0.000000e+00, ptr %malloccall3, align 8
; CHECK-NEXT: %"'il_phi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 1
; CHECK-NEXT: %[[i2:.+]] = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 0
; CHECK-NEXT: %"alpha'ipg" = getelementptr inbounds %struct.Prod, ptr %"P'", i64 0, i32 1
; CHECK-NEXT: store double 0.000000e+00, ptr %"alpha'ipg", align 8
; CHECK-NEXT: %ld.transb = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i3:.+]] = icmp eq i8 %ld.transb, 110
; CHECK-NEXT: %[[i4:.+]] = select i1 %[[i3]], i8 116, i8 0
; CHECK-NEXT: %[[i5:.+]] = icmp eq i8 %ld.transb, 78
; CHECK-NEXT: %[[i6:.+]] = select i1 %[[i5]], i8 84, i8 %[[i4]]
; CHECK-NEXT: %[[i7:.+]] = icmp eq i8 %ld.transb, 116
; CHECK-NEXT: %[[i8:.+]] = select i1 %[[i7]], i8 110, i8 %[[i6]]
; CHECK-NEXT: %[[i9:.+]] = icmp eq i8 %ld.transb, 84
; CHECK-NEXT: %[[i10:.+]] = select i1 %[[i9]], i8 78, i8 %[[i8]]
; CHECK-NEXT: store i8 %[[i10]], ptr %byref.transpose.transb, align 1
; CHECK-NEXT: %ld.row.trans = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i11:.+]] = icmp eq i8 %ld.row.trans, 110
; CHECK-NEXT: %[[i12:.+]] = icmp eq i8 %ld.row.trans, 78
; CHECK-NEXT: %[[i13:.+]] = or i1 %[[i12]], %[[i11]]
; CHECK-NEXT: %[[i14:.+]] = select i1 %[[i13]], ptr %byref.transpose.transb, ptr %malloccall
; CHECK-NEXT: %[[i15:.+]] = select i1 %[[i13]], ptr %"'il_phi", ptr %rhs
; CHECK-NEXT: %[[i16:.+]] = select i1 %[[i13]], ptr %rhs, ptr %"'il_phi"
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.0, align 8
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i14]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i15]], ptr %malloccall1, ptr %[[i16]], ptr %malloccall1, ptr %byref.constant.fp.1.0, ptr %"calloc'mi", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: %ld.transa = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i17:.+]] = icmp eq i8 %ld.transa, 110
; CHECK-NEXT: %[[i18:.+]] = select i1 %[[i17]], i8 116, i8 0
; CHECK-NEXT: %[[i19:.+]] = icmp eq i8 %ld.transa, 78
; CHECK-NEXT: %[[i20:.+]] = select i1 %[[i19]], i8 84, i8 %[[i18]]
; CHECK-NEXT: %[[i21:.+]] = icmp eq i8 %ld.transa, 116
; CHECK-NEXT: %[[i22:.+]] = select i1 %[[i21]], i8 110, i8 %[[i20]]
; CHECK-NEXT: %[[i23:.+]] = icmp eq i8 %ld.transa, 84
; CHECK-NEXT: %[[i24:.+]] = select i1 %[[i23]], i8 78, i8 %[[i22]]
; CHECK-NEXT: store i8 %[[i24]], ptr %byref.transpose.transa, align 1
; CHECK-NEXT: %ld.row.trans2 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i25:.+]] = icmp eq i8 %ld.row.trans2, 110
; CHECK-NEXT: %[[i26:.+]] = icmp eq i8 %ld.row.trans2, 78
; CHECK-NEXT: %[[i27:.+]] = or i1 %[[i26]], %[[i25]]
; CHECK-NEXT: %[[i28:.+]] = select i1 %[[i27]], ptr %byref.transpose.transa, ptr %malloccall
; CHECK-NEXT: %[[i29:.+]] = select i1 %[[i27]], ptr %[[i2]], ptr %"'il_phi"
; CHECK-NEXT: %[[i30:.+]] = select i1 %[[i27]], ptr %"'il_phi", ptr %[[i2]]
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.05, align 8
; CHECK-NEXT: call void @dgemm_(ptr %[[i28]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i29]], ptr %malloccall1, ptr %[[i30]], ptr %malloccall1, ptr %byref.constant.fp.1.05, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G, align 1
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.0, align 4
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.06, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.07, align 8
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G, ptr %byref.constant.int.0, ptr %byref.constant.int.06, ptr %byref.constant.fp.1.07, ptr %malloccall3, ptr %malloccall1, ptr %malloccall1, ptr %"'il_phi", ptr %malloccall1, ptr %[[i0]], i32 1)
; CHECK-NEXT: tail call void @free(ptr nonnull %[[i2]])
; CHECK-NEXT: %ld.transb10 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i31:.+]] = icmp eq i8 %ld.transb10, 110
; CHECK-NEXT: %[[i32:.+]] = select i1 %[[i31]], i8 116, i8 0
; CHECK-NEXT: %[[i33:.+]] = icmp eq i8 %ld.transb10, 78
; CHECK-NEXT: %[[i34:.+]] = select i1 %[[i33]], i8 84, i8 %[[i32]]
; CHECK-NEXT: %[[i35:.+]] = icmp eq i8 %ld.transb10, 116
; CHECK-NEXT: %[[i36:.+]] = select i1 %[[i35]], i8 110, i8 %[[i34]]
; CHECK-NEXT: %[[i37:.+]] = icmp eq i8 %ld.transb10, 84
; CHECK-NEXT: %[[i38:.+]] = select i1 %[[i37]], i8 78, i8 %[[i36]]
; CHECK-NEXT: store i8 %[[i38]], ptr %byref.transpose.transb11, align 1
; CHECK-NEXT: %ld.row.trans12 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i39:.+]] = icmp eq i8 %ld.row.trans12, 110
; CHECK-NEXT: %[[i40:.+]] = icmp eq i8 %ld.row.trans12, 78
; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]]
; CHECK-NEXT: %[[i42:.+]] = select i1 %[[i41]], ptr %byref.transpose.transb11, ptr %malloccall
; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i41]], ptr %"calloc'mi", ptr %rhs
; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i41]], ptr %rhs, ptr %"calloc'mi"
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.014, align 8
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i42]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i43]], ptr %malloccall1, ptr %[[i44]], ptr %malloccall1, ptr %byref.constant.fp.1.014, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: %ld.transa15 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i45:.+]] = icmp eq i8 %ld.transa15, 110
; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8 116, i8 0
; CHECK-NEXT: %[[i47:.+]] = icmp eq i8 %ld.transa15, 78
; CHECK-NEXT: %[[i48:.+]] = select i1 %[[i47]], i8 84, i8 %[[i46]]
; CHECK-NEXT: %[[i49:.+]] = icmp eq i8 %ld.transa15, 116
; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8 110, i8 %[[i48]]
; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.transa15, 84
; CHECK-NEXT: %[[i52:.+]] = select i1 %[[i51]], i8 78, i8 %[[i50]]
; CHECK-NEXT: store i8 %[[i52]], ptr %byref.transpose.transa16, align 1
; CHECK-NEXT: %ld.row.trans17 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i53:.+]] = icmp eq i8 %ld.row.trans17, 110
; CHECK-NEXT: %[[i54:.+]] = icmp eq i8 %ld.row.trans17, 78
; CHECK-NEXT: %[[i55:.+]] = or i1 %[[i54]], %[[i53]]
; The three selects below must all be driven by the condition captured as i55,
; and the call must consume the values captured as i57/i58, so the captures are
; used here rather than re-bound or replaced by literal value numbers.
; CHECK-NEXT: %[[i56:.+]] = select i1 %[[i55]], ptr %byref.transpose.transa16, ptr %malloccall
; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i55]], ptr %rhs, ptr %"calloc'mi"
; CHECK-NEXT: %[[i58:.+]] = select i1 %[[i55]], ptr %"calloc'mi", ptr %rhs
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.019, align 8
; CHECK-NEXT: call void @dgemm_(ptr %[[i56]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i57]], ptr %malloccall1, ptr %[[i58]], ptr %malloccall1, ptr %byref.constant.fp.1.019, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G20, align 1
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.021, align 4
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.022, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.023, align 8
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G20, ptr %byref.constant.int.021, ptr %byref.constant.int.022, ptr %byref.constant.fp.1.023, ptr %malloccall2, ptr %malloccall1, ptr %malloccall1, ptr %"calloc'mi", ptr %malloccall1, ptr %[[i1]], i32 1)
; CHECK-NEXT: call void @free(ptr nonnull %"calloc'mi")
; CHECK-NEXT: call void @free(ptr %calloc)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
99 changes: 99 additions & 0 deletions enzyme/test/Integration/ReverseMode/blas_gemm2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
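// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...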
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include "test_utils.h"
#include "../blas_inline.h"

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
extern int enzyme_const;

#include <assert.h>

void __enzyme_autodiff(void*, ...);

const size_t n = 20;

#include <string.h>
struct Prod {
double* out;
double alpha;
};

// Differentiated kernel: runs two chained dgemm_ calls that both read `rhs`,
// writes the final result through P->out, and then sets P->alpha to 0.
// Assuming dgemm_ (from blas_inline.h) follows the reference BLAS contract
// C = alpha*op(A)*op(B) + beta*C, this leaves rhs*rhs*rhs in P->out -- TODO
// confirm against blas_inline.h.
// Kept noinline so that it remains a separate call inside simulate().
__attribute__((noinline))
void mul(struct Prod* P, double* __restrict__ rhs) {
// Zero-filled n*n scratch matrix for the intermediate product.
// NOTE(review): tmp is never freed in this function.
double* tmp= (double*)malloc(sizeof(double)*n*n);
memset(tmp, 0, n*n*sizeof(double));
// Scalars are passed to dgemm_ by address.
char N = 'N';
// Despite its name, `ten` holds n and is used for every dimension argument.
int ten = n;
double one = 1.0;
double zero = 0.0;

// First product: inputs rhs and rhs, beta slot = one, output tmp.
dgemm_(&N, &N, &ten, &ten, &ten, &one, rhs, &ten, rhs, &ten, &one, tmp, &ten);
// Second product: inputs tmp and rhs, beta slot = zero, output P->out.
dgemm_(&N, &N, &ten, &ten, &ten, &one, tmp, &ten, rhs, &ten, &zero, P->out, &ten);
P->alpha = 0;
return;
}

// Scalar function under differentiation: allocates an n*n output buffer,
// runs mul() on it with P as the matrix input, and returns element 0 of the
// output.
// NOTE(review): M.out is never freed, so every call leaks the buffer.
double simulate(double* P) {
struct Prod M;
M.out = (double*)malloc(sizeof(double)*n*n);
M.alpha = 1.0;
mul(&M, P);
return M.out[0];
// Leftover sketch of an alternative formulation; unreachable after the return.
// double *out = (double*)malloc(sizeof(double)*n*n);
// dgemm_(&N, &N, &ten, &ten, &ten, &one, P1.data(), &ten, P.data(), &ten, &zero, &out[0], &ten);
// return P1(0, 0);
}

// Test driver: compares the reverse-mode gradient of simulate() produced by
// __enzyme_autodiff against a central finite-difference estimate, element by
// element, using APPROX_EQ (from test_utils.h; presumably fails the test on a
// mismatch -- confirm there).
//
// argc/argv are unused; the signature is kept as-is.
int main(int argc, char **argv) {

  double A[n * n];       // input matrix, indexed row-major as A[n*i + j]
  double Adup[n * n];    // shadow of A; receives the Enzyme gradient
  double Adup_fd[n * n]; // finite-difference gradient for comparison

  // 0.3 on the diagonal, 0.1 elsewhere; both gradient buffers start at zero.
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < n; j++) {
      A[n*i + j] = j == i ? 0.3 : 0.1;
      Adup[n*i + j] = 0.0;
      Adup_fd[n*i + j] = 0.0;
    }
  }

  // Finite-difference step: 0.001 squared.
  double delta = 0.001;
  delta = delta * delta;

  double fx = simulate(A);
  printf("f(A) = %f\n", fx);

  // if (argc == 2) {
  __enzyme_autodiff((void *)simulate, enzyme_dup, &A[0], &Adup[0]);
  // Adup[0], Adup[1] and Adup[2] are elements (0,0), (0,1) and (0,2) under the
  // row-major indexing used above, so the labels name exactly those entries.
  printf("dP(0,0) = %f, dP(0,1) = %f, dP(0,2) = %f\n", Adup[0], Adup[1], Adup[2]);
  //}

  for (int i = 0; i < n*n; i++) {
    // Central difference around A[i]: evaluate at +delta/2 and -delta/2, then
    // restore the original value before moving to the next element.
    A[i] += delta / 2;
    double fx2 = simulate(A);
    A[i] -= delta;
    double fx3 = simulate(A);
    A[i] += delta/2;
    Adup_fd[i] = (fx2 - fx3) / delta;

    printf("dA_fd[%d]=%f\n", i, Adup_fd[i]);

    APPROX_EQ(Adup[i], Adup_fd[i], 1e-6);
  }

  return 0;
}
9 changes: 7 additions & 2 deletions enzyme/test/Integration/blas_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int xerbla_(const char *srname, integer *info, int len)
return 0;
}
__attribute__((noinline))
logical lsame_(char *ca, char *cb, int, int)
logical lsame_(char *ca, char *cb, int ca_size, int cb_size)
{
/* System generated locals */
logical ret_val;
Expand Down Expand Up @@ -764,12 +764,17 @@ doublereal ddot_(integer *n, doublereal *dx, integer *incx, doublereal *dy,
} /* ddot_ */

__attribute__((noinline))
/* Subroutine */ int dgemm_(const char *transa, const char *transb, const integer *m, const integer *
/* Subroutine */ int dgemm_(const char *transa_t, const char *transb_t, const integer *m, const integer *
n, const integer *k, const doublereal *alpha, const doublereal *a, const integer *lda,
const doublereal *b, const integer *ldb, const doublereal *beta, doublereal *c, const integer
*ldc)
{

char transa_v = *transa_t;
char* transa = &transa_v;

char transb_v = *transb_t;
char* transb = &transb_v;

/* System generated locals */
integer a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset, i__1, i__2,
Expand Down
5 changes: 3 additions & 2 deletions enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
}
}

os << " if (!shadow && need_" << argname << " && !cache_" << argname
<< ")\n"
os << " if (!shadow && need_" << argname
<< " && ((cacheMode && overwritten_args_ptr) ? !cache_" << argname
<< " : true ))\n"
<< " return true;\n";
os << " }\n";
}
Expand Down
Loading