Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 20, 2023
1 parent 10a37c8 commit c2924c5
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n,
def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy),
["y"],[len, fp, vinc<["n"]>, vinc<["n"]>],
[
(b<"dot"> $n, adj<"y">, $incy, $x, $incx),
(b<"axpy"> $n, $alpha, adj<"y">, $incy, adj<"x">, $incx),
(b<"dot"> $n, adj<"y">, $x),
(b<"axpy"> $n, $alpha, adj<"y">, adj<"x">),
(noop) // y = alpha*x + y, so nothing to do here
]
>;
Expand All @@ -158,7 +158,7 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
["y"],[len, vinc<["n"]>, vinc<["n"]>],
[
(noop),// copy moves x into y, so x is never modified.
(b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx)
(b<"axpy"> $n, Constant<"1.0">, adj<"y">, adj<"x">)
]
>;

Expand Down
3 changes: 1 addition & 2 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ entry:
; CHECK-DAG: %[[r20:.+]] = select i1 false, double* %"v0'", double* %cache.x_unwrap
; CHECK-DAG: %[[r21:.+]] = select i1 false, double* %cache.x_unwrap, double* %"v0'"
; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r20]], i32 1, double* %[[r21]], i32 1, double* %"K'", i32 %N)
; CHECK-NEXT: %[[i22:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %[[i22]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: %[[i23:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dscal(i32 %[[i23]], double 1.000000e+00, double* %"v0'", i32 1)
; CHECK-NEXT: %[[i24:.+]] = bitcast double* %cache.A_unwrap to i8*
Expand Down
12 changes: 4 additions & 8 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ entry:
; CHECK-DAG: %[[r42:.+]] = select i1 false, double* %"v0'", double* %cache.x
; CHECK-DAG: %[[r43:.+]] = select i1 false, double* %cache.x, double* %"v0'"
; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r42]], i32 1, double* %[[r43]], i32 1, double* %"K'", i32 %N)
; CHECK-NEXT: %[[i48:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %[[i48]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: %[[i49:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dscal(i32 %[[i49]], double 1.000000e+00, double* %"v0'", i32 1)
; CHECK-NEXT: %[[i50:.+]] = bitcast double* %cache.A to i8*
Expand All @@ -212,8 +211,7 @@ entry:
; CHECK-DAG: %[[r48:.+]] = select i1 false, double* %"v0'", double* %cache.x8
; CHECK-DAG: %[[r49:.+]] = select i1 false, double* %cache.x8, double* %"v0'"
; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r48]], i32 1, double* %[[r49]], i32 1, double* %"K'", i32 %N)
; CHECK-NEXT: %[[i52:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %[[i52]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: %[[i53:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dscal(i32 %[[i53]], double 1.000000e+00, double* %"v0'", i32 1)
; CHECK-NEXT: %[[i54:.+]] = bitcast double* %cache.A5 to i8*
Expand All @@ -223,8 +221,7 @@ entry:
; CHECK-DAG: %[[r54:.+]] = select i1 false, double* %"v0'", double* %cache.x16
; CHECK-DAG: %[[r55:.+]] = select i1 false, double* %cache.x16, double* %"v0'"
; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r54]], i32 1, double* %[[r55]], i32 1, double* %"K'", i32 %N)
; CHECK-NEXT: %[[i56:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %[[i56]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: %[[i57:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dscal(i32 %[[i57]], double 1.000000e+00, double* %"v0'", i32 1)
; CHECK-NEXT: %[[i58:.+]] = bitcast double* %cache.A13 to i8*
Expand All @@ -234,8 +231,7 @@ entry:
; CHECK-DAG: %[[r60:.+]] = select i1 false, double* %"v0'", double* %cache.x24
; CHECK-DAG: %[[r61:.+]] = select i1 false, double* %cache.x24, double* %"v0'"
; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r60]], i32 1, double* %[[r61]], i32 1, double* %"K'", i32 %N)
; CHECK-NEXT: %[[i60:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %[[i60]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1)
; CHECK-NEXT: %[[i61:.+]] = select i1 false, i32 %N, i32 %N
; CHECK-NEXT: call void @cblas_dscal(i32 %[[i61]], double 1.000000e+00, double* %"v0'", i32 1)
; CHECK-NEXT: %[[i62:.+]] = bitcast double* %cache.A21 to i8*
Expand Down
20 changes: 12 additions & 8 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ entry:
; CHECK-NEXT: %ret = alloca double
; CHECK-NEXT: %byref.transpose.transa = alloca i8
; CHECK-NEXT: %byref.int.one = alloca i64
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double
; CHECK-NEXT: %byref.constant.fp.0.0 = alloca double
; CHECK-DAG: %byref.constant.fp.1.0 = alloca double
; CHECK-DAG: %byref.constant.char.N = alloca i8, align 1
; CHECK-DAG: %byref.constant.fp.0.0 = alloca double
; CHECK-NEXT: %byref.constant.int.1 = alloca i64
; CHECK-NEXT: %byref.constant.int.17 = alloca i64
; CHECK-NEXT: %byref.constant.char.N11 = alloca i8, align 1
; CHECK-NEXT: %[[byrefconstantfp1:.+]] = alloca double
; CHECK-NEXT: %incy = alloca i64, i64 1, align 16
; CHECK-NEXT: %1 = bitcast i64* %incy to i8*
Expand Down Expand Up @@ -182,6 +184,7 @@ entry:
; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8*
; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0
; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8*
; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1
; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0
; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8*
; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1
Expand Down Expand Up @@ -212,22 +215,23 @@ entry:
; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8* %20, i8* %"y'"
; CHECK-NEXT: %[[i54:.+]] = select i1 %[[i49]], i8* %intcast.int.one, i8* %incy_p
; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[i42]], i8* %[[i46]], i8* %[[i50]], i8* %[[i54]], i8* %"A'", i8* %lda_p)
; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N11, align 1
; CHECK-NEXT: store double 1.000000e+00, double* %[[byrefconstantfp1]]
; CHECK-NEXT: %[[fpcast14:.+]] = bitcast double* %[[byrefconstantfp1]] to i8*
; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %[[fpcast14]], i8* %"x'", i8* %incx_p)
; CHECK-NEXT: %ld.row.trans13 = load i8, i8* %malloccall
; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans13, 110
; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans13, 78
; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall
; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans14, 110
; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans14, 78
; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]]
; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p
; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one)
; CHECK-NEXT: %[[r44:.+]] = bitcast i8* %"beta'" to double*
; CHECK-NEXT: %[[r45:.+]] = load double, double* %[[r44]]
; CHECK-NEXT: %[[r46:.+]] = fadd fast double %[[r45]], %[[r43]]
; CHECK-NEXT: store double %[[r46]], double* %[[r44]]
; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall
; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans14, 110
; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans14, 78
; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall
; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans15, 110
; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans15, 78
; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]]
; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p
; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p)
Expand Down
2 changes: 2 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ entry:
; CHECK-NEXT: %ret = alloca double
; CHECK-NEXT: %byref.transpose.transa = alloca i8
; CHECK-NEXT: %byref.int.one = alloca i64
; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double
; CHECK-NEXT: %incy = alloca i64, i64 1, align 16
; CHECK-NEXT: %1 = bitcast i64* %incy to i8*
Expand Down Expand Up @@ -203,6 +204,7 @@ entry:
; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r34]], i8* %15, i8* %"y'"
; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r34]], i8* %intcast.int.one, i8* %incy_p
; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %[[r27]], i8* %[[r31]], i8* %[[r35]], i8* %[[r39]], i8* %"A'", i8* %lda_p)
; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1
; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0
; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8*
; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p)
Expand Down

0 comments on commit c2924c5

Please sign in to comment.