Skip to content

Commit

Permalink
Fix Blas diffuse on P^3 computation
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 6, 2023
1 parent a1d95f5 commit 434803a
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 4 deletions.
184 changes: 184 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -ge 13 ] ; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -jump-threading -adce -S | FileCheck %s; fi
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,early-cse,instsimplify,jump-threading,adce)" -enzyme-preopt=false -S | FileCheck %s ; fi

; ModuleID = '../examples/big/big_inlined_correctness.cpp'
source_filename = "../examples/big/big_inlined_correctness.cpp"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.Prod = type { ptr, double }

declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc)

; Function Attrs: mustprogress noinline nounwind uwtable
define dso_local void @_Z3mulR4ProdPd(ptr nocapture noundef nonnull align 8 dereferenceable(16) %P, ptr noalias nocapture noundef readonly %rhs) {
entry:
%N = alloca i8, align 1
%ten = alloca i32, align 4
%one = alloca double, align 8
%zero = alloca double, align 8
%calloc = call dereferenceable_or_null(32) ptr @calloc(i64 1, i64 32)
store i8 78, ptr %N, align 1
store i32 2, ptr %ten, align 4
store double 1.000000e+00, ptr %one, align 8
store double 0.000000e+00, ptr %zero, align 8
%call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten)
%0 = load ptr, ptr %P, align 8
%call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten)
%alpha = getelementptr inbounds %struct.Prod, ptr %P, i64 0, i32 1
store double 0.000000e+00, ptr %alpha, align 8
ret void
}

declare noalias ptr @malloc(i64)

; Function Attrs: mustprogress nounwind uwtable
define dso_local double @_Z8simulatePd(ptr nocapture noundef readonly %P) {
entry:
%M = alloca %struct.Prod, align 8
%call = tail call noalias dereferenceable_or_null(32) ptr @malloc(i64 noundef 32)
store ptr %call, ptr %M, align 8
%alpha = getelementptr inbounds %struct.Prod, ptr %M, i64 0, i32 1
store double 1.000000e+00, ptr %alpha, align 8
call void @_Z3mulR4ProdPd(ptr noundef nonnull align 8 dereferenceable(16) %M, ptr noundef %P)
%0 = load ptr, ptr %M, align 8
%1 = load double, ptr %0, align 8
ret double %1
}

define void @caller(ptr %A, ptr %Adup) {
entry:
call void (...) @_Z17__enzyme_autodiffz(ptr noundef nonnull @_Z8simulatePd, metadata !"enzyme_dup", ptr noundef nonnull %A, ptr noundef nonnull %Adup)
ret void
}

declare void @_Z17__enzyme_autodiffz(...)

declare noalias noundef ptr @calloc(i64 noundef, i64 noundef)

; we must actually save or set the matmul
; CHECK: define internal void @diffe_Z3mulR4ProdPd(ptr nocapture align 8 dereferenceable(16) %P, ptr nocapture align 8 %"P'", ptr noalias nocapture readonly %rhs, ptr nocapture %"rhs'", { ptr, ptr, ptr, ptr } %tapeArg)
; CHECK-NEXT: invertentry:
; CHECK-NEXT: %byref.transpose.transb = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8
; CHECK-NEXT: %byref.transpose.transa = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.05 = alloca double, align 8
; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1
; CHECK-NEXT: %byref.constant.int.0 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.int.06 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double, align 8
; CHECK-NEXT: %[[i0:.+]] = alloca i32, align 4
; CHECK-NEXT: %byref.transpose.transb11 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.014 = alloca double, align 8
; CHECK-NEXT: %byref.transpose.transa16 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.019 = alloca double, align 8
; CHECK-NEXT: %byref.constant.char.G20 = alloca i8, align 1
; CHECK-NEXT: %byref.constant.int.021 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.int.022 = alloca i32, align 4
; CHECK-NEXT: %byref.constant.fp.1.023 = alloca double, align 8
; CHECK-NEXT: %[[i1:.+]] = alloca i32, align 4
; CHECK-NEXT: %malloccall3 = alloca i8, i64 8, align 8
; CHECK-NEXT: %malloccall = alloca i8, i64 1, align 1
; CHECK-NEXT: %malloccall2 = alloca i8, i64 8, align 8
; CHECK-NEXT: %malloccall1 = alloca i8, i64 4, align 4
; CHECK-NEXT: %"calloc'mi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 2
; CHECK-NEXT: %calloc = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 3
; CHECK-NEXT: store i8 78, ptr %malloccall, align 1
; CHECK-NEXT: store i32 2, ptr %malloccall1, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %malloccall2, align 8
; CHECK-NEXT: store double 0.000000e+00, ptr %malloccall3, align 8
; CHECK-NEXT: %"'il_phi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 1
; CHECK-NEXT: %[[i2:.+]] = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 0
; CHECK-NEXT: %"alpha'ipg" = getelementptr inbounds %struct.Prod, ptr %"P'", i64 0, i32 1
; CHECK-NEXT: store double 0.000000e+00, ptr %"alpha'ipg", align 8
; CHECK-NEXT: %ld.transb = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i3:.+]] = icmp eq i8 %ld.transb, 110
; CHECK-NEXT: %[[i4:.+]] = select i1 %[[i3]], i8 116, i8 0
; CHECK-NEXT: %[[i5:.+]] = icmp eq i8 %ld.transb, 78
; CHECK-NEXT: %[[i6:.+]] = select i1 %[[i5]], i8 84, i8 %[[i4]]
; CHECK-NEXT: %[[i7:.+]] = icmp eq i8 %ld.transb, 116
; CHECK-NEXT: %[[i8:.+]] = select i1 %[[i7]], i8 110, i8 %[[i6]]
; CHECK-NEXT: %[[i9:.+]] = icmp eq i8 %ld.transb, 84
; CHECK-NEXT: %[[i10:.+]] = select i1 %[[i9]], i8 78, i8 %[[i8]]
; CHECK-NEXT: store i8 %[[i10]], ptr %byref.transpose.transb, align 1
; CHECK-NEXT: %ld.row.trans = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i11:.+]] = icmp eq i8 %ld.row.trans, 110
; CHECK-NEXT: %[[i12:.+]] = icmp eq i8 %ld.row.trans, 78
; CHECK-NEXT: %[[i13:.+]] = or i1 %[[i12]], %[[i11]]
; CHECK-NEXT: %[[i14:.+]] = select i1 %[[i13]], ptr %byref.transpose.transb, ptr %malloccall
; CHECK-NEXT: %[[i15:.+]] = select i1 %[[i13]], ptr %"'il_phi", ptr %rhs
; CHECK-NEXT: %[[i16:.+]] = select i1 %[[i13]], ptr %rhs, ptr %"'il_phi"
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.0, align 8
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i14]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i15]], ptr %malloccall1, ptr %[[i16]], ptr %malloccall1, ptr %byref.constant.fp.1.0, ptr %"calloc'mi", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: %ld.transa = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i17:.+]] = icmp eq i8 %ld.transa, 110
; CHECK-NEXT: %[[i18:.+]] = select i1 %[[i17]], i8 116, i8 0
; CHECK-NEXT: %[[i19:.+]] = icmp eq i8 %ld.transa, 78
; CHECK-NEXT: %[[i20:.+]] = select i1 %[[i19]], i8 84, i8 %[[i18]]
; CHECK-NEXT: %[[i21:.+]] = icmp eq i8 %ld.transa, 116
; CHECK-NEXT: %[[i22:.+]] = select i1 %[[i21]], i8 110, i8 %[[i20]]
; CHECK-NEXT: %[[i23:.+]] = icmp eq i8 %ld.transa, 84
; CHECK-NEXT: %[[i24:.+]] = select i1 %[[i23]], i8 78, i8 %[[i22]]
; CHECK-NEXT: store i8 %[[i24]], ptr %byref.transpose.transa, align 1
; CHECK-NEXT: %ld.row.trans2 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i25:.+]] = icmp eq i8 %ld.row.trans2, 110
; CHECK-NEXT: %[[i26:.+]] = icmp eq i8 %ld.row.trans2, 78
; CHECK-NEXT: %[[i27:.+]] = or i1 %[[i26]], %[[i25]]
; CHECK-NEXT: %[[i28:.+]] = select i1 %[[i27]], ptr %byref.transpose.transa, ptr %malloccall
; CHECK-NEXT: %[[i29:.+]] = select i1 %[[i27]], ptr %[[i2]], ptr %"'il_phi"
; CHECK-NEXT: %[[i30:.+]] = select i1 %[[i27]], ptr %"'il_phi", ptr %[[i2]]
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.05, align 8
; CHECK-NEXT: call void @dgemm_(ptr %[[i28]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i29]], ptr %malloccall1, ptr %[[i30]], ptr %malloccall1, ptr %byref.constant.fp.1.05, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G, align 1
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.0, align 4
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.06, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.07, align 8
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G, ptr %byref.constant.int.0, ptr %byref.constant.int.06, ptr %byref.constant.fp.1.07, ptr %malloccall3, ptr %malloccall1, ptr %malloccall1, ptr %"'il_phi", ptr %malloccall1, ptr %[[i0]], i32 1)
; CHECK-NEXT: tail call void @free(ptr nonnull %[[i2]])
; CHECK-NEXT: %ld.transb10 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i31:.+]] = icmp eq i8 %ld.transb10, 110
; CHECK-NEXT: %[[i32:.+]] = select i1 %[[i31]], i8 116, i8 0
; CHECK-NEXT: %[[i33:.+]] = icmp eq i8 %ld.transb10, 78
; CHECK-NEXT: %[[i34:.+]] = select i1 %[[i33]], i8 84, i8 %[[i32]]
; CHECK-NEXT: %[[i35:.+]] = icmp eq i8 %ld.transb10, 116
; CHECK-NEXT: %[[i36:.+]] = select i1 %[[i35]], i8 110, i8 %[[i34]]
; CHECK-NEXT: %[[i37:.+]] = icmp eq i8 %ld.transb10, 84
; CHECK-NEXT: %[[i38:.+]] = select i1 %[[i37]], i8 78, i8 %[[i36]]
; CHECK-NEXT: store i8 %[[i38]], ptr %byref.transpose.transb11, align 1
; CHECK-NEXT: %ld.row.trans12 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i39:.+]] = icmp eq i8 %ld.row.trans12, 110
; CHECK-NEXT: %[[i40:.+]] = icmp eq i8 %ld.row.trans12, 78
; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]]
; CHECK-NEXT: %[[i42:.+]] = select i1 %[[i41]], ptr %byref.transpose.transb11, ptr %malloccall
; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i41]], ptr %"calloc'mi", ptr %rhs
; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i41]], ptr %rhs, ptr %"calloc'mi"
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.014, align 8
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i42]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i43]], ptr %malloccall1, ptr %[[i44]], ptr %malloccall1, ptr %byref.constant.fp.1.014, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: %ld.transa15 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i45:.+]] = icmp eq i8 %ld.transa15, 110
; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8 116, i8 0
; CHECK-NEXT: %[[i47:.+]] = icmp eq i8 %ld.transa15, 78
; CHECK-NEXT: %[[i48:.+]] = select i1 %[[i47]], i8 84, i8 %[[i46]]
; CHECK-NEXT: %[[i49:.+]] = icmp eq i8 %ld.transa15, 116
; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8 110, i8 %[[i48]]
; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.transa15, 84
; CHECK-NEXT: %[[i52:.+]] = select i1 %[[i51]], i8 78, i8 %[[i50]]
; CHECK-NEXT: store i8 %[[i52]], ptr %byref.transpose.transa16, align 1
; CHECK-NEXT: %ld.row.trans17 = load i8, ptr %malloccall, align 1
; CHECK-NEXT: %[[i53:.+]] = icmp eq i8 %ld.row.trans17, 110
; CHECK-NEXT: %[[i54:.+]] = icmp eq i8 %ld.row.trans17, 78
; CHECK-NEXT: %[[i55:.+]] = or i1 %[[i54]], %[[i53]]
; CHECK-NEXT: %[[i56:.+]] = select i1 %[[i55:.+]], ptr %byref.transpose.transa16, ptr %malloccall
; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i55:.+]], ptr %rhs, ptr %"calloc'mi"
; CHECK-NEXT: %[[i58:.+]] = select i1 %[[i55:.+]], ptr %"calloc'mi", ptr %rhs
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.019, align 8
; CHECK-NEXT: call void @dgemm_(ptr %[[i56]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %57, ptr %malloccall1, ptr %58, ptr %malloccall1, ptr %byref.constant.fp.1.019, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G20, align 1
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.021, align 4
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.022, align 4
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.023, align 8
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G20, ptr %byref.constant.int.021, ptr %byref.constant.int.022, ptr %byref.constant.fp.1.023, ptr %malloccall2, ptr %malloccall1, ptr %malloccall1, ptr %"calloc'mi", ptr %malloccall1, ptr %[[i1]], i32 1)
; CHECK-NEXT: call void @free(ptr nonnull %"calloc'mi")
; CHECK-NEXT: call void @free(ptr %calloc)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
99 changes: 99 additions & 0 deletions enzyme/test/Integration/ReverseMode/blas_gemm2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include "test_utils.h"
#include "../blas_inline.h"

extern int enzyme_dup;
extern int enzyme_dupnoneed;
extern int enzyme_out;
extern int enzyme_const;

#include <assert.h>

void __enzyme_autodiff(void*, ...);

const size_t n = 20;

#include <string.h>
struct Prod {
double* out;
double alpha;
};

__attribute__((noinline))
void mul(struct Prod* P, double* __restrict__ rhs) {
double* tmp= (double*)malloc(sizeof(double)*n*n);
memset(tmp, 0, n*n*sizeof(double));
char N = 'N';
int ten = n;
double one = 1.0;
double zero = 0.0;

dgemm_(&N, &N, &ten, &ten, &ten, &one, rhs, &ten, rhs, &ten, &one, tmp, &ten);
dgemm_(&N, &N, &ten, &ten, &ten, &one, tmp, &ten, rhs, &ten, &zero, P->out, &ten);
P->alpha = 0;
return;
}

double simulate(double* P) {
struct Prod M;
M.out = (double*)malloc(sizeof(double)*n*n);
M.alpha = 1.0;
mul(&M, P);
return M.out[0];
// double *out = (double*)malloc(sizeof(double)*n*n);
// dgemm_(&N, &N, &ten, &ten, &ten, &one, P1.data(), &ten, P.data(), &ten, &zero, &out[0], &ten);
// return P1(0, 0);
}

int main(int argc, char **argv) {

double A[n * n];
double Adup[n * n];
double Adup_fd[n * n];

for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
A[n*i + j] = j == i ? 0.3 : 0.1;
Adup[n*i + j] = 0.0;
Adup_fd[n*i + j] = 0.0;
}
}

double delta = 0.001;
delta = delta * delta;

double fx = simulate(A);
printf("f(A) = %f\n", fx);

// if (argc == 2) {
__enzyme_autodiff((void *)simulate, enzyme_dup, &A[0], &Adup[0]);
printf("dP(0,0) = %f, dP(0,1) = %f, dP(1,0) = %f\n", Adup[0], Adup[1], Adup[2]);
//}

for (int i = 0; i < n*n; i++) {
A[i] += delta / 2;
double fx2 = simulate(A);
A[i] -= delta;
double fx3 = simulate(A);
A[i] += delta/2;
Adup_fd[i] = (fx2 - fx3) / delta;

printf("dA_fd[%d]=%f\n", i, Adup_fd[i]);

APPROX_EQ(Adup[i], Adup_fd[i], 1e-6);
}

return 0;
}
9 changes: 7 additions & 2 deletions enzyme/test/Integration/blas_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int xerbla_(const char *srname, integer *info, int len)
return 0;
}
__attribute__((noinline))
logical lsame_(char *ca, char *cb, int, int)
logical lsame_(char *ca, char *cb, int cs, int cb)
{
/* System generated locals */
logical ret_val;
Expand Down Expand Up @@ -764,12 +764,17 @@ doublereal ddot_(integer *n, doublereal *dx, integer *incx, doublereal *dy,
} /* ddot_ */

__attribute__((noinline))
/* Subroutine */ int dgemm_(const char *transa, const char *transb, const integer *m, const integer *
/* Subroutine */ int dgemm_(const char *transa_t, const char *transb_t, const integer *m, const integer *
n, const integer *k, const doublereal *alpha, const doublereal *a, const integer *lda,
const doublereal *b, const integer *ldb, const doublereal *beta, doublereal *c, const integer
*ldc)
{

char transa_v = *transa_t;
char* transa = &transa_v;

char transb_v = *transb_t;
char* transb = &transb_v;

/* System generated locals */
integer a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset, i__1, i__2,
Expand Down
5 changes: 3 additions & 2 deletions enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
}
}

os << " if (!shadow && need_" << argname << " && !cache_" << argname
<< ")\n"
os << " if (!shadow && need_" << argname
<< " && ((cacheMode && overwritten_args_ptr) ? !cache_" << argname
<< " : true ))\n"
<< " return true;\n";
os << " }\n";
}
Expand Down

0 comments on commit 434803a

Please sign in to comment.