From c808878cb95c6c4bba270c821993b61c3b7cbfa4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 Nov 2023 00:56:50 -0500 Subject: [PATCH 1/6] Fix dlascl api (#1522) --- enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 4 ++-- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll | 4 ++-- .../ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll | 4 ++-- .../test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll | 4 ++-- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 5 ++--- 12 files changed, 24 insertions(+), 25 deletions(-) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll index ecb99d346ebd..af8a12114419 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll @@ -57,7 +57,7 @@ entry: ; CHECK-DAG: %byref.constant.int.0 = alloca i64 ; CHECK-DAG: %[[byrefconstantint1:.+]] = alloca i64 ; CHECK-DAG: %[[byref_fp_1_00:.+]] = alloca double -; CHECK-DAG: %[[tmp:.+]] = alloca i8 +; CHECK-DAG: %[[tmp:.+]] = alloca i64 ; CHECK-DAG: %transa = alloca i8, align 1 ; CHECK-DAG: %transb = alloca i8, align 1 ; CHECK-DAG: %m = alloca i64, align 16 @@ -157,6 +157,6 @@ entry: ; CHECK-NEXT: %[[int02:.+]] = bitcast i64* %[[byrefconstantint1]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_1_00]] ; CHECK-NEXT: %[[fp11:.+]] = bitcast double* %[[byref_fp_1_00]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index b36324ec40fb..e84c12e14701 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -59,7 +59,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 ; CHECK-NEXT: %[[byrefconstantint4:.+]] = alloca i64, align 8 ; CHECK-NEXT: %[[byref_fp_1_017:.+]] = alloca double, align 8 -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -275,7 +275,7 @@ entry: ; CHECK-NEXT: %[[intcast07:.+]] = bitcast i64* %[[byrefconstantint4]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_1_017]] ; CHECK-NEXT: %[[fpcast_1_018:.+]] = bitcast double* %[[byref_fp_1_017]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast07]], i8* %[[fpcast_1_018]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast07]], i8* %[[fpcast_1_018]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[ret1:.+]] = bitcast double* %cache.A to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[ret1]]) ; CHECK-NEXT: %[[ret2:.+]] = bitcast double* %cache.B to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index c47db6a5656e..db74b9fd4484 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -61,7 +61,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 ; CHECK-NEXT: %[[byref_int_019:.+]] = alloca i64, align 8 ; CHECK-NEXT: %[[byref_fp_021:.+]] = alloca double, align 8 -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -348,7 +348,7 @@ entry: ; CHECK-NEXT: %[[intcast_020:.+]] = bitcast i64* %[[byref_int_019]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_021]] ; CHECK-NEXT: %[[fpcast_1_022:.+]] = bitcast double* %[[byref_fp_021]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_020]], i8* %[[fpcast_1_022]], i8* %beta, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_020]], i8* %[[fpcast_1_022]], i8* %beta, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: br label %invertentry.C.done ; CHECK: invertentry.C.done: ; preds = %invertentry.C.active, %invertentry.beta.done diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index 4b00e5d5e51f..c66d4ff88d63 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -81,7 +81,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 ; CHECK-NEXT: %byref.constant.int.033 = alloca i64, align 8 ; CHECK-NEXT: %byref.constant.fp.1.035 = alloca double, align 8 -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %n = alloca i64, align 16 @@ -295,7 +295,7 @@ entry: ; CHECK-NEXT: %intcast.constant.int.034 = bitcast i64* %byref.constant.int.033 to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.035, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.036 = bitcast double* %byref.constant.fp.1.035 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.034, i8* %fpcast.constant.fp.1.036, i8* %cast.beta, i8* %[[r37]], i8* %n_p_unwrap, i8* %"C'", i8* %cast.ldc, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.034, i8* %fpcast.constant.fp.1.036, i8* %cast.beta, i8* %[[r37]], i8* %n_p_unwrap, i8* %"C'", i8* %cast.ldc, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[r77:.+]] = bitcast double* %cache.A_unwrap to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[r77]]) ; CHECK-NEXT: call void @free(i8* %[[r37]]) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index 1f7f69824d79..0e30a9fdc7da 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -163,7 +163,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[byref_int_0:.+]] = alloca i64 ; CHECK-NEXT: %[[byref_fp_1_011:.+]] = alloca double -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %[[i1:.+]] = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 @@ -289,7 +289,7 @@ entry: ; CHECK-NEXT: %[[intcast_010:.+]] = bitcast i64* %[[byref_int_0]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_1_011]], align 8 ; CHECK-NEXT: %[[fpcast_1_0:.+]] = bitcast double* %[[byref_fp_1_011]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_010]], i8* %[[fpcast_1_0]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_010]], i8* %[[fpcast_1_0]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[r70:.+]] = bitcast double* %0 to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[r70]]) ; CHECK-NEXT: ret void diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index 821e44755ac3..f58a4b2e209a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -135,7 +135,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[int09:.+]] = alloca i64 ; CHECK-NEXT: %[[fp11:.+]] = alloca double -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %[[i1:.+]] = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 @@ -263,7 +263,7 @@ entry: ; CHECK-NEXT: %[[int010:.+]] = bitcast i64* %[[int09:.+]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[fp11]], align 8 ; CHECK-NEXT: %[[fp12:.+]] = bitcast double* %[[fp11]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int010]], i8* %[[fp12]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int010]], i8* %[[fp12]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[r70:.+]] = bitcast double* %0 to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[r70]]) ; CHECK-NEXT: ret void diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 39af9fd38448..2a7bbed40c24 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -135,7 +135,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 ; CHECK-NEXT: %[[byref_int_0:.+]] = alloca i64, align 8 ; CHECK-NEXT: %[[byref_fp_1_011:.+]] = alloca double, align 8 -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %[[i1:.+]] = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 @@ -261,7 +261,7 @@ entry: ; CHECK-NEXT: %[[intcast_010:.+]] = bitcast i64* %[[byref_int_0]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_1_011]], align 8 ; CHECK-NEXT: %[[fpcast_1_012:.+]] = bitcast double* %[[byref_fp_1_011]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_010]], i8* %[[fpcast_1_012]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_010]], i8* %[[fpcast_1_012]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[r70:.+]] = bitcast double* %0 to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[r70]]) ; CHECK-NEXT: ret void diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index e4315fc7ce91..964b53b1b925 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -61,7 +61,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[int04:.+]] = alloca i64 ; CHECK-NEXT: %[[byref_fp_1_018:.+]] = alloca double -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -219,7 +219,7 @@ entry: ; CHECK-NEXT: %[[intcast08:.+]] = bitcast i64* %[[int04]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %[[fp19:.+]] = bitcast double* %[[byref_fp_1_018]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %[[intcast0]], i8* %[[intcast08]], i8* %[[fp19]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %[[intcast0]], i8* %[[intcast08]], i8* %[[fp19]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[free1:.+]] = bitcast double* %cache.A to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[free1]]) ; CHECK-NEXT: %[[free2:.+]] = bitcast double* %cache.B to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index ee3455d83c32..a4c00d61bbc4 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -59,7 +59,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[byrefint03:.+]] = alloca i64 ; CHECK-NEXT: %byref.constant.fp.1.06 = alloca double -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -159,7 +159,7 @@ entry: ; CHECK-NEXT: %intcast.constant.int.05 = bitcast i64* %byref.constant.int.04 to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.07 = bitcast double* %byref.constant.fp.1.06 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.05, i8* %fpcast.constant.fp.1.07, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.05, i8* %fpcast.constant.fp.1.07, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: %[[ret:.+]] = bitcast double* %cache.B to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[ret]]) ; CHECK-NEXT: ret void diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll index 872e1849b777..f913be5728d1 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll @@ -59,7 +59,7 @@ entry: ; CHECK-DAG: %byref.constant.int.0 = alloca i64 ; CHECK-DAG: %[[byref_int_08:.+]] = alloca i64, align 8 ; CHECK-DAG: %[[byref_fp_1_010:.+]] = alloca double -; CHECK-DAG: %[[tmp:.+]] = alloca i8 +; CHECK-DAG: %[[tmp:.+]] = alloca i64 ; CHECK-DAG: %transa = alloca i8, align 1 ; CHECK-DAG: %transb = alloca i8, align 1 ; CHECK-DAG: %m = alloca i64, align 16 @@ -164,6 +164,6 @@ entry: ; CHECK-NEXT: %[[int02:.+]] = bitcast i64* %[[byref_int_08]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %[[fp11:.+]] = bitcast double* %[[byref_fp_1_010]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll index 460b9b9a48a1..a8967283f49c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll @@ -59,7 +59,7 @@ entry: ; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 ; CHECK-NEXT: %[[byref_int_0:.+]] = alloca i64, align 8 ; CHECK-NEXT: %[[byref_fp_1_0:.+]] = alloca double, align 8 -; CHECK-NEXT: %[[tmp:.+]] = alloca i8 +; CHECK-NEXT: %[[tmp:.+]] = alloca i64 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -164,7 +164,7 @@ entry: ; CHECK-NEXT: %[[intcast_0:.+]] = bitcast i64* %[[byref_int_0]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %[[byref_fp_1_0]], align 8 ; CHECK-NEXT: %[[fpcast_1:.+]] = bitcast double* %[[byref_fp_1_0]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_0]], i8* %[[fpcast_1]], i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[tmp]], i64 1) +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast_0]], i8* %[[fpcast_1]], i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %[[tmp]], i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 843a09053ff9..6fced28311f1 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -1195,9 +1195,8 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, } } else if (Def->isSubClassOf("Alloca")) { auto val = Def->getValueAsInt("value"); - os << "{allocationBuilder.CreateAlloca(Type::getIntNTy(allocationBuilder." - "getContext(), " - << (8 * val) << "))}"; + assert(val == 1); + os << "{allocationBuilder.CreateAlloca(intType)}"; } else if (Def->isSubClassOf("ConstantInt")) { auto val = Def->getValueAsInt("value"); os << "{to_blas_callconv(Builder2, ConstantInt::get(intType, " << val From dfb8b55b8c3bc069a6fe3e4a86190b749d0d2d8b Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 Nov 2023 13:21:59 -0500 Subject: [PATCH 2/6] Attempt std optional fix for older macos (#1523) --- enzyme/Enzyme/FunctionUtils.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 3ba92f96c5c2..5abc0a3729f3 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -4624,8 +4624,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (fneg->getOperand(1) == PN) legal = false; if (cmpPredicate) { - if (cmpPredicate.value() != - cast(fneg)->getPredicate()) + if (*cmpPredicate != cast(fneg)->getPredicate()) legal = false; } else { cmpPredicate = cast(fneg)->getPredicate(); @@ -4727,7 +4726,7 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, break; case Instruction::FCmp: case Instruction::ICmp: - fneg = B.CreateCmp(cmpPredicate.value(), lhsPN, rhsPN); + fneg = B.CreateCmp(*cmpPredicate, lhsPN, rhsPN); break; case Instruction::UIToFP: fneg = B.CreateUIToFP(lhsPN, PN->getType()); From 7508ad4fec3336047b4dfb34bf8c74d3f017325c Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 Nov 2023 22:35:09 -0500 Subject: [PATCH 3/6] Fix SCEV memory error (#1524) * Fix SCEV memory error * Add smalltypeof to inactive --- enzyme/Enzyme/ActivityAnalysis.cpp | 1 + enzyme/Enzyme/GradientUtils.cpp | 114 ++++++++++++------ enzyme/Enzyme/Utils.cpp | 15 ++- enzyme/test/Enzyme/ReverseMode/addrbug.ll | 4 +- enzyme/test/Enzyme/ReverseMode/makememcpy1.ll | 4 +- .../test/Enzyme/ReverseMode/metacachelicm.ll | 4 +- .../test/Enzyme/ReverseMode/metacachelicm2.ll | 4 +- enzyme/test/Enzyme/ReverseMode/rwrloop.ll | 2 +- enzyme/test/Enzyme/ReverseMode/sploop2.ll | 4 +- 9 files changed, 96 insertions(+), 56 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 4db0f8342f69..6cf2a22ad294 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -109,6 +109,7 @@ const char *KnownInactiveFunctionsContains[] = { "__enzyme_pointer"}; const StringSet<> InactiveGlobals = { + "small_typeof", "ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index e5929a1fd763..f08fea3f44ba 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -6094,6 +6094,11 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, mode == DerivativeMode::ReverseModeCombined); assert(val->getName() != ""); + { + auto found = incoming_available.find(val); + if (found != incoming_available.end()) + return found->second; + } if (isa(val)) { return val; } @@ -6121,7 +6126,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } auto inst = cast(val); - assert(inst->getName() != ""); if (inversionAllocs && inst->getParent() == inversionAllocs) { return val; } @@ -6418,7 +6422,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto li2obj = getBaseObject(li2->getPointerOperand()); if (liobj == li2obj && DT.dominates(li2, li)) { - auto orig2 = isOriginal(li2); + auto orig2 = dyn_cast_or_null(isOriginal(li2)); if (!orig2) continue; @@ -6427,8 +6431,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // llvm::errs() << "found potential candidate loads: oli:" // << *origInst << " oli2: " << *orig2 << "\n"; - auto scev1 = SE.getSCEV(li->getPointerOperand()); - auto scev2 = SE.getSCEV(li2->getPointerOperand()); + auto scev1 = SE.getSCEV(origInst->getPointerOperand()); + auto scev2 = SE.getSCEV(orig2->getPointerOperand()); // llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2 // << "\n"; @@ -6449,11 +6453,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (auto ar1 = dyn_cast(scev1)) { if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != SE.getCouldNotCompute() && + if (ar1->getStart() != OrigSE.getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() && - ar1->getStepRecurrence(SE) == - ar2->getStepRecurrence(SE)) { + ar1->getStepRecurrence(OrigSE) != + OrigSE.getCouldNotCompute() && + ar1->getStepRecurrence(OrigSE) == + ar2->getStepRecurrence(OrigSE)) { LoopContext l1; getContext(ar1->getLoop()->getHeader(), l1); @@ -6848,7 +6853,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = SE.getSCEV(li->getPointerOperand()); + auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); // Store in memcpy opt Value *lim = nullptr; BasicBlock *ctx = nullptr; @@ -6856,12 +6861,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Value *offset = nullptr; if (auto ar1 = dyn_cast(scev1)) { if (auto step = - dyn_cast(ar1->getStepRecurrence(SE))) { + dyn_cast(ar1->getStepRecurrence(OrigSE))) { if (step->getAPInt() != loadSize) goto noSpeedCache; LoopContext l1; - getContext(ar1->getLoop()->getHeader(), l1); + getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1); if (l1.dynamic) goto noSpeedCache; @@ -6886,40 +6891,69 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "", true, true); - SmallVector toErase; { #if LLVM_VERSION_MAJOR >= 12 - SCEVExpander Exp(SE, - ctx->getParent()->getParent()->getDataLayout(), - "enzyme"); -#else - fake::SCEVExpander Exp( - SE, ctx->getParent()->getParent()->getDataLayout(), - "enzyme"); -#endif - Exp.setInsertPoint(l1.header->getTerminator()); - Value *start0 = Exp.expandCodeFor( - ar1->getStart(), li->getPointerOperand()->getType()); - start = unwrapM(start0, v, - /*available*/ ValueToValueMapTy(), - UnwrapMode::AttemptFullUnwrapWithLookup); - std::set todo = {start0}; - while (todo.size()) { - Value *now = *todo.begin(); - todo.erase(now); - if (Instruction *inst = dyn_cast(now)) { - if (inst != start && inst->getNumUses() == 0 && - Exp.isInsertedInstruction(inst)) { - for (auto &op : inst->operands()) { - todo.insert(op); - } - toErase.push_back(inst); - } + Value *start0; + SmallVector InsertedInstructions; + { + SCEVExpander OrigExp( + OrigSE, ctx->getParent()->getParent()->getDataLayout(), + "enzyme"); + + OrigExp.setInsertPoint( + isOriginal(l1.header)->getTerminator()); + + start0 = OrigExp.expandCodeFor( + ar1->getStart(), li->getPointerOperand()->getType()); + InsertedInstructions = OrigExp.getAllInsertedInstructions(); + } + + ValueToValueMapTy available; + for (const auto &pair : originalToNewFn) { + if (pair.first->getType() == pair.second->getType()) + available[pair.first] = pair.second; + } + + // Sort so that later instructions do not dominate earlier + // instructions. + llvm::stable_sort(InsertedInstructions, + [this](Instruction *A, Instruction *B) { + return OrigDT.dominates(A, B); + }); + for (auto a : InsertedInstructions) { + assert(!isa(a)); + auto uw = cast( + unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, + /*scope*/ nullptr, /*cache*/ false)); + assert(uw->getType() == a->getType()); + for (size_t i = 0; i < uw->getNumOperands(); i++) { + auto op = uw->getOperand(i); + if (auto arg = dyn_cast(op)) + assert(arg->getParent() == newFunc); + else if (auto inst = dyn_cast(op)) + assert(inst->getParent()->getParent() == newFunc); } + available[a] = uw; + unwrappedLoads.erase(cast(uw)); } + + start = + isa(start0) ? start0 : (Value *)available[start0]; + if (!start) { + llvm::errs() << "old: " << *oldFunc << "\n"; + llvm::errs() << "new: " << *newFunc << "\n"; + llvm::errs() << "start0: " << *start0 << "\n"; + } + assert(start); + + available.clear(); + for (auto I : llvm::reverse(InsertedInstructions)) { + assert(I->getNumUses() == 0); + OrigSE.forgetValue(I); + I->eraseFromParent(); + } +#endif } - for (auto a : toErase) - erase(a); if (!start) goto noSpeedCache; diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 1a4dc9afd87e..3d5ee679064c 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -1914,7 +1914,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, if (auto LI = dyn_cast(maybeReader)) { LoadBegin = SE.getSCEV(LI->getPointerOperand()); - if (LoadBegin != SE.getCouldNotCompute()) { + if (LoadBegin != SE.getCouldNotCompute() && + !LoadBegin->getType()->isIntegerTy()) { auto &DL = maybeWriter->getModule()->getDataLayout(); auto width = cast(DL.getIndexType(LoadBegin->getType())) ->getBitWidth(); @@ -1930,7 +1931,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, } if (auto SI = dyn_cast(maybeWriter)) { StoreBegin = SE.getSCEV(SI->getPointerOperand()); - if (StoreBegin != SE.getCouldNotCompute()) { + if (StoreBegin != SE.getCouldNotCompute() && + !StoreBegin->getType()->isIntegerTy()) { auto &DL = maybeWriter->getModule()->getDataLayout(); auto width = cast(DL.getIndexType(StoreBegin->getType())) ->getBitWidth(); @@ -1948,7 +1950,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, } if (auto MS = dyn_cast(maybeWriter)) { StoreBegin = SE.getSCEV(MS->getArgOperand(0)); - if (StoreBegin != SE.getCouldNotCompute()) { + if (StoreBegin != SE.getCouldNotCompute() && + !StoreBegin->getType()->isIntegerTy()) { if (auto Len = dyn_cast(MS->getArgOperand(2))) { auto &DL = MS->getModule()->getDataLayout(); auto width = cast(DL.getIndexType(StoreBegin->getType())) @@ -1961,7 +1964,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, } if (auto MS = dyn_cast(maybeWriter)) { StoreBegin = SE.getSCEV(MS->getArgOperand(0)); - if (StoreBegin != SE.getCouldNotCompute()) { + if (StoreBegin != SE.getCouldNotCompute() && + !StoreBegin->getType()->isIntegerTy()) { if (auto Len = dyn_cast(MS->getArgOperand(2))) { auto &DL = MS->getModule()->getDataLayout(); auto width = cast(DL.getIndexType(StoreBegin->getType())) @@ -1974,7 +1978,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, } if (auto MS = dyn_cast(maybeReader)) { LoadBegin = SE.getSCEV(MS->getArgOperand(1)); - if (LoadBegin != SE.getCouldNotCompute()) { + if (LoadBegin != SE.getCouldNotCompute() && + !LoadBegin->getType()->isIntegerTy()) { if (auto Len = dyn_cast(MS->getArgOperand(2))) { auto &DL = MS->getModule()->getDataLayout(); auto width = cast(DL.getIndexType(LoadBegin->getType())) diff --git a/enzyme/test/Enzyme/ReverseMode/addrbug.ll b/enzyme/test/Enzyme/ReverseMode/addrbug.ll index d69793906e3a..0ac9b1313fb5 100644 --- a/enzyme/test/Enzyme/ReverseMode/addrbug.ll +++ b/enzyme/test/Enzyme/ReverseMode/addrbug.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s; fi ; Function Attrs: nounwind declare void @__enzyme_autodiff(i8*, ...) diff --git a/enzyme/test/Enzyme/ReverseMode/makememcpy1.ll b/enzyme/test/Enzyme/ReverseMode/makememcpy1.ll index ee87604a9bfd..78b7d7cd4bfd 100644 --- a/enzyme/test/Enzyme/ReverseMode/makememcpy1.ll +++ b/enzyme/test/Enzyme/ReverseMode/makememcpy1.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s; fi ; This requires the additional optimization to create memcpy's diff --git a/enzyme/test/Enzyme/ReverseMode/metacachelicm.ll b/enzyme/test/Enzyme/ReverseMode/metacachelicm.ll index 38b79cc5527c..5b65639dbfe0 100644 --- a/enzyme/test/Enzyme/ReverseMode/metacachelicm.ll +++ b/enzyme/test/Enzyme/ReverseMode/metacachelicm.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi ; Function Attrs: nounwind uwtable define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readnone %array, double* noalias nocapture %out) #0 { diff --git a/enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll b/enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll index e7256baa9f78..f72b15d3c995 100644 --- a/enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll +++ b/enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi ; Function Attrs: nounwind uwtable define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readonly %array, double* noalias nocapture %out) #0 { diff --git a/enzyme/test/Enzyme/ReverseMode/rwrloop.ll b/enzyme/test/Enzyme/ReverseMode/rwrloop.ll index 35fc44070550..a902119b1b56 100644 --- a/enzyme/test/Enzyme/ReverseMode/rwrloop.ll +++ b/enzyme/test/Enzyme/ReverseMode/rwrloop.ll @@ -133,11 +133,11 @@ attributes #9 = { noreturn nounwind } ; CHECK: for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry ; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup3 ], [ 0, %entry ] -; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10 ; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 ; CHECK-NEXT: br i1 %cmp233, label %for.body4.lr.ph, label %for.cond.cleanup3 ; CHECK: for.body4.lr.ph: ; preds = %for.cond1.preheader +; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10 ; CHECK-NEXT: %[[a3:.+]] = load i32, i32* %N, align 4, !tbaa !2, !alias.scope !8, !noalias !11, !invariant.group ![[INVG:[0-9]]] ; CHECK-NEXT: %[[a4:.+]] = getelementptr inbounds i32, i32* %[[malloccache12]], i64 %iv ; CHECK-NEXT: store i32 %[[a3]], i32* %[[a4]], align 4, !tbaa !2, !invariant.group ![[INVG]] diff --git a/enzyme/test/Enzyme/ReverseMode/sploop2.ll b/enzyme/test/Enzyme/ReverseMode/sploop2.ll index 1ef2c1c3211b..0a32b12719e8 100644 --- a/enzyme/test/Enzyme/ReverseMode/sploop2.ll +++ b/enzyme/test/Enzyme/ReverseMode/sploop2.ll @@ -1,5 +1,5 @@ -; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi -; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s; fi ; This requires the memcpy optimization to run From a1d95f5ec0eef83a75975dcd76aece632609b18f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 5 Nov 2023 23:57:47 -0600 Subject: [PATCH 4/6] [Blas] fix temporary allocation (and deallocation) (#1525) * [Blas] fix temporary allocation (and deallocation) * fixups * fix * fix * fix test --- enzyme/Enzyme/AdjointGenerator.h | 7 +- enzyme/Enzyme/DiffeGradientUtils.cpp | 15 +- enzyme/Enzyme/Utils.cpp | 23 +- .../blas/gemm_f_c_lacpy_runtime_act.ll | 23 +- .../blas/gemv_f_c_split_blascpy.ll | 44 +- .../Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll | 22 +- .../test/Integration/ReverseMode/blas_gemm.c | 75 ++ enzyme/test/Integration/blas_inline.h | 1159 +++++++++++++++++ enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 57 +- 9 files changed, 1349 insertions(+), 76 deletions(-) create mode 100644 enzyme/test/Integration/ReverseMode/blas_gemm.c create mode 100644 enzyme/test/Integration/blas_inline.h diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index fbe42322944a..c1be8be8ab6c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1673,11 +1673,10 @@ class AdjointGenerator .getTypeSizeInBits(SVI.getOperand(opnum)->getType()) + 7) / 8; + Value *toadd = Builder2.CreateExtractElement(loaded, instidx); ((DiffeGradientUtils *)gutils) - ->addToDiffe(SVI.getOperand(opnum), - Builder2.CreateExtractElement(loaded, instidx), - Builder2, TR.addingType(size, SVI.getOperand(opnum)), - sv); + ->addToDiffe(SVI.getOperand(opnum), toadd, Builder2, + TR.addingType(size, SVI.getOperand(opnum)), sv); } ++instidx; } diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index ce40131377eb..7e240129abb6 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -321,6 +321,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, Value *ptr = getDifferential(val); + Value *old; if (idxs.size() != 0) { SmallVector sv = { ConstantInt::get(Type::getInt32Ty(val->getContext()), 0)}; @@ -328,8 +329,12 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, sv.push_back(i); ptr = BuilderM.CreateGEP(getShadowType(val->getType()), ptr, sv); cast(ptr)->setIsInBounds(true); + old = BuilderM.CreateLoad( + GetElementPtrInst::getIndexedType(getShadowType(val->getType()), sv), + ptr); + } else { + old = BuilderM.CreateLoad(getShadowType(val->getType()), ptr); } - Value *old = BuilderM.CreateLoad(dif->getType(), ptr); assert(dif->getType() == old->getType()); Value *res = nullptr; @@ -352,7 +357,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, &TR.analyzer, nullptr, wrap(&BuilderM)); return addedSelects; } else { - TR.dump(); + TR.dump(ss); DebugLoc loc; if (auto inst = dyn_cast(val)) EmitFailure("CannotDeduceType", inst->getDebugLoc(), inst, ss.str()); @@ -381,13 +386,15 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, llvm::raw_string_ostream ss(s); ss << "oldFunc: " << *oldFunc << "\n"; ss << "Illegal intermediate when adding to: " << *val - << " with addingType: " << *addingType << "\n"; + << " with addingType: " << *addingType << "\n" + << " old: " << *old << " dif: " << *dif << "\n" + << " oldBitSize: " << oldBitSize << " newBitSize: " << newBitSize + << "\n"; if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType, &TR.analyzer, nullptr, wrap(&BuilderM)); return addedSelects; } else { - TR.dump(); DebugLoc loc; if (auto inst = dyn_cast(val)) EmitFailure("CannotDeduceType", inst->getDebugLoc(), inst, ss.str()); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 3d5ee679064c..f028ec906f53 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -667,8 +667,8 @@ void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false); auto fn = M.getOrInsertFunction(copy_name, FT); - Function *F = cast(fn.getCallee()); - attributeKnownFunctions(*F); + if (auto F = GetFunctionFromValue(fn.getCallee())) + attributeKnownFunctions(*F); B.CreateCall(fn, args, bundles); } @@ -884,9 +884,9 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, std::string dot_name = blas.prefix + blas.floatType + "dot" + blas.suffix; auto FDotT = FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false); - Function *FDot = - cast(M.getOrInsertFunction(dot_name, FDotT).getCallee()); - attributeKnownFunctions(*F); + auto FDot = M.getOrInsertFunction(dot_name, FDotT); + if (auto F = GetFunctionFromValue(FDot.getCallee())) + attributeKnownFunctions(*F); // now add the implementation for the inner_prod call F->setLinkage(Function::LinkageTypes::InternalLinkage); @@ -933,12 +933,19 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, { IRBuilder<> B1(entry); Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef, - cublas, IT, B1, "constant.one"); + cublas, nullptr, B1, "constant.one"); + + if (blasOne->getType() != BlasIT) + blasOne = B1.CreatePointerCast(blasOne, BlasIT, "intcast.constant.one"); + Value *m = load_if_ref(B1, IT, blasm, byRef); Value *n = load_if_ref(B1, IT, blasn, byRef); Value *size = B1.CreateNUWMul(m, n, "mat.size"); - Value *blasSize = - to_blas_callconv(B1, size, byRef, cublas, IT, B1, "mat.size"); + Value *blasSize = to_blas_callconv( + B1, size, byRef, cublas, julia_decl ? IT : nullptr, B1, "mat.size"); + + if (blasSize->getType() != BlasIT) + blasSize = B1.CreatePointerCast(blasSize, BlasIT, "intcast.mat.size"); B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init); IRBuilder<> B2(init); diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index db74b9fd4484..0918ca43f83f 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -120,15 +120,6 @@ entry: ; CHECK-NEXT: %cache.C = bitcast i8* %malloccall2 to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage3 ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %n_p) -; CHECK-NEXT: %[[i17:.+]] = bitcast i8* %m_p to i64* -; CHECK-NEXT: %[[i18:.+]] = load i64, i64* %[[i17]] -; CHECK-NEXT: %[[i19:.+]] = bitcast i8* %n_p to i64* -; CHECK-NEXT: %[[i20:.+]] = load i64, i64* %[[i19]] -; CHECK-NEXT: %size_AB = mul nuw i64 %[[i18]], %[[i20]] -; CHECK-NEXT: %mallocsize5 = mul nuw nsw i64 %size_AB, 8 -; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) -; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* -; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* @@ -145,6 +136,16 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.alpha, label %invertentry.alpha.done, label %invertentry.alpha.active ; CHECK: invertentry.alpha.active: ; preds = %invertentry +; CHECK-NEXT: %[[i17:.+]] = bitcast i8* %m_p to i64* +; CHECK-NEXT: %[[i18:.+]] = load i64, i64* %[[i17]] +; CHECK-NEXT: %[[i19:.+]] = bitcast i8* %n_p to i64* +; CHECK-NEXT: %[[i20:.+]] = load i64, i64* %[[i19]] +; CHECK-NEXT: %size_AB = mul nuw i64 %[[i18]], %[[i20]] +; CHECK-NEXT: %mallocsize5 = mul nuw nsw i64 %size_AB, 8 +; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) +; CHECK-NEXT: %[[matAB:.+]] = bitcast i8* %malloccall6 to double* +; CHECK-NEXT: %[[i21:.+]] = bitcast double* %[[matAB:.+]] to i8* + ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: %loaded.trans7 = load i8, i8* %transa @@ -183,7 +184,7 @@ entry: ; CHECK-NEXT: %iteration.i = phi i64 [ 0, %init.idx.i ], [ %iter.next.i, %for.body.i ] ; CHECK-NEXT: %sum.i = phi{{( fast)?}} double [ 0.000000e+00, %init.idx.i ], [ %[[i57:.+]], %for.body.i ] ; CHECK-NEXT: %A.i.i = getelementptr inbounds double, double* %[[i51]], i64 %Aidx.i -; CHECK-NEXT: %B.i.i = getelementptr inbounds double, double* %mat_AB, i64 %Bidx.i +; CHECK-NEXT: %B.i.i = getelementptr inbounds double, double* %[[matAB]], i64 %Bidx.i ; CHECK-NEXT: %[[i54:.+]] = bitcast double* %A.i.i to i8* ; CHECK-NEXT: %[[i55:.+]] = bitcast double* %B.i.i to i8* ; CHECK-NEXT: %[[i56:.+]] = call fast double @ddot_64_(i8* %m_p, i8* %[[i54]], i8* %intcast.constant.one.i, i8* %[[i55]], i8* %intcast.constant.one.i) @@ -202,6 +203,8 @@ entry: ; CHECK-NEXT: %[[i62:.+]] = load double, double* %[[i61]] ; CHECK-NEXT: %[[i63:.+]] = fadd fast double %[[i62]], %res.i ; CHECK-NEXT: store double %[[i63]], double* %[[i61]] +; CHECK-NEXT: %[[forfree:.+]] = bitcast double* %[[matAB]] to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[forfree:.+]]) ; CHECK-NEXT: br label %invertentry.alpha.done ; CHECK: invertentry.alpha.done: ; preds = %__enzyme_inner_prodd_64_.exit, %invertentry diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index 466a2d16be5f..4f543b580bdc 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -100,7 +100,7 @@ entry: ; CHECK-NEXT: %malloccall9 = tail call noalias nonnull i8* @malloc(i64 %mallocsize8) ; CHECK-NEXT: %cache.y = bitcast i8* %malloccall9 to double* ; CHECK-NEXT: store i64 1, i64* %byref.10 -; CHECK-NEXT: call void @dcopy_64_(i8* %20, i8* %y, i8* %incy_p, double* %cache.y, i64* %byref.10) +; CHECK-NEXT: call void @dcopy_64_(i8* %[[r20]], i8* %y, i8* %incy_p, double* %cache.y, i64* %byref.10) ; CHECK-NEXT: %23 = insertvalue { double*, double* } undef, double* %cache.x, 0 ; CHECK-NEXT: %24 = insertvalue { double*, double* } %23, double* %cache.y, 1 ; CHECK-NEXT: store { double*, double* } %24, { double*, double* }* %0 @@ -148,30 +148,32 @@ entry: ; CHECK-NEXT: store i64 4, i64* %8, align 16 ; CHECK-NEXT: store i64 2, i64* %9, align 16 ; CHECK-NEXT: store i64 1, i64* %10, align 16 -; CHECK-NEXT: %11 = bitcast i8* %m_p to i64* -; CHECK-NEXT: %12 = load i64, i64* %11 -; CHECK-NEXT: %13 = bitcast i8* %n_p to i64* -; CHECK-NEXT: %14 = load i64, i64* %13 -; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[r15:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %[[r17:.+]] = or i1 %[[r16]], %[[r15]] -; CHECK-NEXT: %18 = select i1 %[[r17]], i64 %12, i64 %14 -; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %18, 8 -; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize) -; CHECK-NEXT: %mat_Ax = bitcast i8* %malloccall6 to double* -; CHECK-NEXT: %19 = bitcast double* %mat_Ax to i8* ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry ; CHECK-NEXT: %tape.ext.x = extractvalue { double*, double* } %0, 0 -; CHECK-NEXT: %20 = bitcast double* %tape.ext.x to i8* +; CHECK-NEXT: %[[r20:.+]] = bitcast double* %tape.ext.x to i8* ; CHECK-NEXT: %tape.ext.y = extractvalue { double*, double* } %0, 1 -; CHECK-NEXT: %21 = bitcast double* %tape.ext.y to i8* +; CHECK-NEXT: %[[r21:.+]] = bitcast double* %tape.ext.y to i8* ; CHECK-NEXT: %tape.ext.y1 = extractvalue { double*, double* } %0, 1 -; CHECK-NEXT: %22 = bitcast double* %tape.ext.y1 to i8* +; CHECK-NEXT: %[[r22:.+]] = bitcast double* %tape.ext.y1 to i8* ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* + +; CHECK-NEXT: %[[r11:.+]] = bitcast i8* %m_p to i64* +; CHECK-NEXT: %[[r12:.+]] = load i64, i64* %[[r11]] +; CHECK-NEXT: %[[r13:.+]] = bitcast i8* %n_p to i64* +; CHECK-NEXT: %[[r14:.+]] = load i64, i64* %[[r13]] +; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall +; CHECK-DAG: %[[r15:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %[[r17:.+]] = or i1 %[[r16]], %[[r15]] +; CHECK-NEXT: %[[r18:.+]] = select i1 %[[r17]], i64 %[[r12]], i64 %[[r14]] +; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %[[r18]], 8 +; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize) +; CHECK-NEXT: %[[matAx:.+]] = bitcast i8* %malloccall6 to double* +; CHECK-NEXT: %[[r19:.+]] = bitcast double* %[[matAx]] to i8* + ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 @@ -179,7 +181,7 @@ entry: ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1 ; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1, i64 1) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %[[r20]], i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %[[r19]], i8* %intcast.constant.int.1, i64 1) ; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[c1:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[c2:.+]] = icmp eq i8 %ld.row.trans, 78 @@ -187,11 +189,15 @@ entry: ; CHECK-NEXT: %[[r34:.+]] = select i1 %[[c3]], i8* %m_p, i8* %n_p ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.17 ; CHECK-NEXT: %intcast.constant.int.18 = bitcast i64* %byref.constant.int.17 to i8* -; CHECK-NEXT: %[[r35:.+]] = call fast double @ddot_64_(i8* %[[r34]], i8* %"y'", i8* %incy_p, i8* %19, i8* %intcast.constant.int.18) +; CHECK-NEXT: %[[r35:.+]] = call fast double @ddot_64_(i8* %[[r34]], i8* %"y'", i8* %incy_p, i8* %[[r19]], i8* %intcast.constant.int.18) ; CHECK-NEXT: %[[r36:.+]] = bitcast i8* %"alpha'" to double* ; CHECK-NEXT: %[[r37:.+]] = load double, double* %[[r36]] ; CHECK-NEXT: %[[r38:.+]] = fadd fast double %[[r37]], %[[r35]] ; CHECK-NEXT: store double %[[r38]], double* %[[r36]] + +; CHECK-NEXT: %[[forfree:.+]] = bitcast double* %22 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[forfree]]) + ; CHECK-NEXT: %ld.row.trans9 = load i8, i8* %malloccall, align 1 ; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans9, 110 ; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans9, 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll index c61b1fbe8d6b..ea99f479c8d8 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll @@ -76,15 +76,6 @@ entry: ; CHECK-NEXT: %cache.y = bitcast i8* %malloccall2 to double* ; CHECK-NEXT: store i64 1, i64* %byref. ; CHECK-NEXT: call void @dcopy_64_(i8* %n_p, i8* %Y, i8* %incy_p, double* %cache.y, i64* %byref.) -; CHECK-NEXT: %[[i6:.+]] = bitcast i8* %n_p to i64* -; CHECK-NEXT: %[[i7:.+]] = load i64, i64* %[[i6]] -; CHECK-NEXT: %[[i8:.+]] = add i64 %[[i7]], 1 -; CHECK-NEXT: %square_mat_size_y0 = mul i64 %[[i7]], %[[i8]] -; CHECK-NEXT: %size_y0 = udiv i64 %square_mat_size_y0, 2 -; CHECK-NEXT: %mallocsize4 = mul nuw nsw i64 %size_y0, 8 -; CHECK-NEXT: %malloccall5 = tail call noalias nonnull i8* @malloc(i64 %mallocsize4) -; CHECK-NEXT: %mat_y0 = bitcast i8* %malloccall5 to double* -; CHECK-NEXT: %[[i9:.+]] = bitcast double* %mat_y0 to i8* ; CHECK-NEXT: call void @dspmv_64_(i8* %uplo, i8* %n_p, i8* %alpha, i8* %AP, i8* %X, i8* %incx_p, i8* %beta, i8* %Y, i8* %incy_p) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"AP'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %AP to double* @@ -98,6 +89,17 @@ entry: ; CHECK-NEXT: %[[i12:.+]] = bitcast double* %cache.y to i8* ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* + +; CHECK-NEXT: %[[i6:.+]] = bitcast i8* %n_p to i64* +; CHECK-NEXT: %[[i7:.+]] = load i64, i64* %[[i6]] +; CHECK-NEXT: %[[i8:.+]] = add i64 %[[i7]], 1 +; CHECK-NEXT: %square_mat_size_y0 = mul i64 %[[i7]], %[[i8]] +; CHECK-NEXT: %size_y0 = udiv i64 %square_mat_size_y0, 2 +; CHECK-NEXT: %mallocsize4 = mul nuw nsw i64 %size_y0, 8 +; CHECK-NEXT: %malloccall5 = tail call noalias nonnull i8* @malloc(i64 %mallocsize4) +; CHECK-NEXT: %[[mat_y0:.+]] = bitcast i8* %malloccall5 to double* +; CHECK-NEXT: %[[i9:.+]] = bitcast double* %[[mat_y0]] to i8* + ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 @@ -112,6 +114,8 @@ entry: ; CHECK-NEXT: %[[i15:.+]] = load double, double* %[[i14]] ; CHECK-NEXT: %[[i16:.+]] = fadd fast double %[[i15]], %[[i13]] ; CHECK-NEXT: store double %[[i16]], double* %[[i14]] +; CHECK-NEXT: %[[forfree:.+]] = bitcast double* %[[mat_y0]] to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[forfree]]) ; CHECK-NEXT: call void @dspr2_64_(i8* %uplo, i8* %n_p, i8* %alpha, i8* %X, i8* %incx_p, i8* %"Y'", i8* %incy_p, i8* %"AP'") ; CHECK: %[[i17:.+]] = load i64, i64* %n ; CHECK-NEXT: %[[i18:.+]] = load i64, i64* %incx diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm.c b/enzyme/test/Integration/ReverseMode/blas_gemm.c new file mode 100644 index 000000000000..195b7efea775 --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/blas_gemm.c @@ -0,0 +1,75 @@ +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi + +#include "test_utils.h" +#include "../blas_inline.h" + +#include + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +void __enzyme_autodiff(void*, ...); + +const size_t n = 10; + +static char N = 'N'; +static int ten = 10; +static double one = 1.0; +static double zero = 0.0; +double simulate(double* A) { + double *out = (double*)malloc(sizeof(double)*n*n); + dgemm_(&N, &N, &ten, &ten, &ten, &one, A, &ten, A, &ten, &zero, &out[0], &ten); + return out[0];//P1(0, 0); +} + +int main(int argc, char **argv) { + + double A[n * n]; + double Adup[n * n]; + double Adup_fd[n * n]; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + A[n*i + j] = j == i ? 0.3 : 0.1; + Adup[n*i + j] = 0.0; + Adup_fd[n*i + j] = 0.0; + } + } + + double delta = 0.001; + delta = delta * delta; + + double fx = simulate(A); + printf("f(A) = %f\n", fx); + + __enzyme_autodiff((void *)simulate, enzyme_dup, &A[0], &Adup[0]); + + for (int i = 0; i < n*n; i++) { + printf("dA[%d]=%f\n", i, Adup[i]); + } + for (int i = 0; i < n*n; i++) { + A[i] += delta / 2; + double fx2 = simulate(A); + A[i] -= delta; + double fx3 = simulate(A); + A[i] += delta/2; + + Adup_fd[i] = (fx2 - fx3) / delta; + + printf("dA_fd[%d]=%f\n", i, Adup_fd[i]); + + APPROX_EQ(Adup[i], Adup_fd[i], 1e-6); + } + + return 0; +} diff --git a/enzyme/test/Integration/blas_inline.h b/enzyme/test/Integration/blas_inline.h new file mode 100644 index 000000000000..46507a65c20f --- /dev/null +++ b/enzyme/test/Integration/blas_inline.h @@ -0,0 +1,1159 @@ +#include +#include + +typedef int32_t integer; +typedef double doublereal; +typedef bool logical; + +#define abs(x) ((x) >= 0 ? (x) : -(x)) +#define min(a,b) ((a) <= (b) ? (a) : (b)) +#define max(a,b) ((a) >= (b) ? (a) : (b)) +#define TRUE_ true +#define FALSE_ false + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((noinline)) +int xerbla_(const char *srname, integer *info, int len) +{ + static char fmt_9999[] = "(\002 ** On entry to \002,a,\002 parameter num" + "ber \002,i2,\002 had \002,\002an illegal value\002)"; + + printf("** On entry to %6s, parameter number %2i had an illegal value\n", + srname, *info); + assert(0 &&" error"); + exit(1); + return 0; +} +__attribute__((noinline)) +logical lsame_(char *ca, char *cb, int, int) +{ + /* System generated locals */ + logical ret_val; + + /* Local variables */ + integer inta, intb, zcode; + + +/* -- LAPACK auxiliary routine (version 3.1) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* LSAME returns .TRUE. if CA is the same letter as CB regardless of */ +/* case. */ + +/* Arguments */ +/* ========= */ + +/* CA (input) CHARACTER*1 */ + +/* CB (input) CHARACTER*1 */ +/* CA and CB specify the single characters to be compared. */ + +/* ===================================================================== */ + +/* .. Intrinsic Functions .. */ +/* .. */ +/* .. Local Scalars .. */ +/* .. */ + +/* Test if the characters are equal */ + + ret_val = *(unsigned char *)ca == *(unsigned char *)cb; + if (ret_val) { + return ret_val; + } + +/* Now test for equivalence if both characters are alphabetic. */ + + zcode = 'Z'; + +/* Use 'Z' rather than 'A' so that ASCII can be detected on Prime */ +/* machines, on which ICHAR returns a value with bit 8 set. */ +/* ICHAR('A') on Prime machines returns 193 which is the same as */ +/* ICHAR('A') on an EBCDIC machine. */ + + inta = *(unsigned char *)ca; + intb = *(unsigned char *)cb; + + if (zcode == 90 || zcode == 122) { + +/* ASCII is assumed - ZCODE is the ASCII code of either lower or */ +/* upper case 'Z'. */ + + if (inta >= 97 && inta <= 122) { + inta += -32; + } + if (intb >= 97 && intb <= 122) { + intb += -32; + } + + } else if (zcode == 233 || zcode == 169) { + +/* EBCDIC is assumed - ZCODE is the EBCDIC code of either lower or */ +/* upper case 'Z'. */ + + if (inta >= 129 && inta <= 137 || inta >= 145 && inta <= 153 || inta + >= 162 && inta <= 169) { + inta += 64; + } + if (intb >= 129 && intb <= 137 || intb >= 145 && intb <= 153 || intb + >= 162 && intb <= 169) { + intb += 64; + } + + } else if (zcode == 218 || zcode == 250) { + +/* ASCII is assumed, on Prime machines - ZCODE is the ASCII code */ +/* plus 128 of either lower or upper case 'Z'. */ + + if (inta >= 225 && inta <= 250) { + inta += -32; + } + if (intb >= 225 && intb <= 250) { + intb += -32; + } + } + ret_val = inta == intb; + +/* RETURN */ + +/* End of LSAME */ + + return ret_val; +} /* lsame_ */ + +__attribute__((noinline)) + +logical dlaisnan_(doublereal *din1, doublereal *din2) +{ + /* System generated locals */ + logical ret_val; + + +/* -- LAPACK auxiliary routine (version 3.2) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* This routine is not for general use. It exists solely to avoid */ +/* over-optimization in DISNAN. */ + +/* DLAISNAN checks for NaNs by comparing its two arguments for */ +/* inequality. NaN is the only floating-point value where NaN != NaN */ +/* returns .TRUE. To check for NaNs, pass the same variable as both */ +/* arguments. */ + +/* A compiler must assume that the two arguments are */ +/* not the same variable, and the test will not be optimized away. */ +/* Interprocedural or whole-program optimization may delete this */ +/* test. The ISNAN functions will be replaced by the correct */ +/* Fortran 03 intrinsic once the intrinsic is widely available. */ + +/* Arguments */ +/* ========= */ + +/* DIN1 (input) DOUBLE PRECISION */ +/* DIN2 (input) DOUBLE PRECISION */ +/* Two numbers to compare for inequality. */ + +/* ===================================================================== */ + +/* .. Executable Statements .. */ + ret_val = *din1 != *din2; + return ret_val; +} /* dlaisnan_ */ + +__attribute__((noinline)) +logical disnan_(doublereal *din) +{ + /* System generated locals */ + logical ret_val; + + /* Local variables */ + +/* -- LAPACK auxiliary routine (version 3.2) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DISNAN returns .TRUE. if its argument is NaN, and .FALSE. */ +/* otherwise. To be replaced by the Fortran 2003 intrinsic in the */ +/* future. */ + +/* Arguments */ +/* ========= */ + +/* DIN (input) DOUBLE PRECISION */ +/* Input to test for NaN. */ + +/* ===================================================================== */ + +/* .. External Functions .. */ +/* .. */ +/* .. Executable Statements .. */ + ret_val = dlaisnan_(din, din); + return ret_val; +} /* disnan_ */ + +__attribute__((noinline)) + +/* Subroutine */ void dlacpy_(char *uplo, integer *m, integer *n, const doublereal * + a, integer *lda, doublereal *b, integer *ldb) +{ + /* System generated locals */ + integer a_dim1, a_offset, b_dim1, b_offset, i__1, i__2; + + /* Local variables */ + integer i__, j; + + +/* -- LAPACK auxiliary routine (version 3.2) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DLACPY copies all or part of a two-dimensional matrix A to another */ +/* matrix B. */ + +/* Arguments */ +/* ========= */ + +/* UPLO (input) CHARACTER*1 */ +/* Specifies the part of the matrix A to be copied to B. */ +/* = 'U': Upper triangular part */ +/* = 'L': Lower triangular part */ +/* Otherwise: All of the matrix A */ + +/* M (input) INTEGER */ +/* The number of rows of the matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the matrix A. N >= 0. */ + +/* A (input) DOUBLE PRECISION array, dimension (LDA,N) */ +/* The m by n matrix A. If UPLO = 'U', only the upper triangle */ +/* or trapezoid is accessed; if UPLO = 'L', only the lower */ +/* triangle or trapezoid is accessed. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* B (output) DOUBLE PRECISION array, dimension (LDB,N) */ +/* On exit, B = A in the locations specified by UPLO. */ + +/* LDB (input) INTEGER */ +/* The leading dimension of the array B. LDB >= max(1,M). */ + +/* ===================================================================== */ + +/* .. Local Scalars .. */ +/* .. */ +/* .. External Functions .. */ +/* .. */ +/* .. Intrinsic Functions .. */ +/* .. */ +/* .. Executable Statements .. */ + + /* Parameter adjustments */ + a_dim1 = *lda; + a_offset = 1 + a_dim1; + a -= a_offset; + b_dim1 = *ldb; + b_offset = 1 + b_dim1; + b -= b_offset; + + /* Function Body */ + if (lsame_(uplo, (char*)"U", 1, 1)) { + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = min(j,*m); + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = a[i__ + j * a_dim1]; +/* L10: */ + } +/* L20: */ + } + } else if (lsame_(uplo, (char*)"L", 1, 1)) { + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = *m; + for (i__ = j; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = a[i__ + j * a_dim1]; +/* L30: */ + } +/* L40: */ + } + } else { + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = *m; + for (i__ = 1; i__ <= i__2; ++i__) { + b[i__ + j * b_dim1] = a[i__ + j * a_dim1]; +/* L50: */ + } +/* L60: */ + } + } + return; + +/* End of DLACPY */ + +} /* dlacpy_ */ + +__attribute__((noinline)) + +/* Subroutine */ int dlascl_(char *type__, integer *kl, integer *ku, + doublereal *cfrom, doublereal *cto, integer *m, integer *n, + doublereal *a, integer *lda, integer *info) +{ + /* System generated locals */ + integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; + + /* Local variables */ + integer i__, j, k1, k2, k3, k4; + doublereal mul, cto1; + logical done; + doublereal ctoc; + integer itype; + doublereal cfrom1; + // extern doublereal dlamch_(char *); + doublereal cfromc; + doublereal bignum, smlnum; + + +/* -- LAPACK auxiliary routine (version 3.2) -- */ +/* Univ. of Tennessee, Univ. of California Berkeley and NAG Ltd.. */ +/* November 2006 */ + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* DLASCL multiplies the M by N real matrix A by the real scalar */ +/* CTO/CFROM. This is done without over/underflow as long as the final */ +/* result CTO*A(I,J)/CFROM does not over/underflow. TYPE specifies that */ +/* A may be full, upper triangular, lower triangular, upper Hessenberg, */ +/* or banded. */ + +/* Arguments */ +/* ========= */ + +/* TYPE (input) CHARACTER*1 */ +/* TYPE indices the storage type of the input matrix. */ +/* = 'G': A is a full matrix. */ +/* = 'L': A is a lower triangular matrix. */ +/* = 'U': A is an upper triangular matrix. */ +/* = 'H': A is an upper Hessenberg matrix. */ +/* = 'B': A is a symmetric band matrix with lower bandwidth KL */ +/* and upper bandwidth KU and with the only the lower */ +/* half stored. */ +/* = 'Q': A is a symmetric band matrix with lower bandwidth KL */ +/* and upper bandwidth KU and with the only the upper */ +/* half stored. */ +/* = 'Z': A is a band matrix with lower bandwidth KL and upper */ +/* bandwidth KU. */ + +/* KL (input) INTEGER */ +/* The lower bandwidth of A. Referenced only if TYPE = 'B', */ +/* 'Q' or 'Z'. */ + +/* KU (input) INTEGER */ +/* The upper bandwidth of A. Referenced only if TYPE = 'B', */ +/* 'Q' or 'Z'. */ + +/* CFROM (input) DOUBLE PRECISION */ +/* CTO (input) DOUBLE PRECISION */ +/* The matrix A is multiplied by CTO/CFROM. A(I,J) is computed */ +/* without over/underflow if the final result CTO*A(I,J)/CFROM */ +/* can be represented without over/underflow. CFROM must be */ +/* nonzero. */ + +/* M (input) INTEGER */ +/* The number of rows of the matrix A. M >= 0. */ + +/* N (input) INTEGER */ +/* The number of columns of the matrix A. N >= 0. */ + +/* A (input/output) DOUBLE PRECISION array, dimension (LDA,N) */ +/* The matrix to be multiplied by CTO/CFROM. See TYPE for the */ +/* storage type. */ + +/* LDA (input) INTEGER */ +/* The leading dimension of the array A. LDA >= max(1,M). */ + +/* INFO (output) INTEGER */ +/* 0 - successful exit */ +/* <0 - if INFO = -i, the i-th argument had an illegal value. */ + +/* ===================================================================== */ + +/* .. Parameters .. */ +/* .. */ +/* .. Local Scalars .. */ +/* .. */ +/* .. External Functions .. */ +/* .. */ +/* .. Intrinsic Functions .. */ +/* .. */ +/* .. External Subroutines .. */ +/* .. */ +/* .. Executable Statements .. */ + +/* Test the input arguments */ + + /* Parameter adjustments */ + a_dim1 = *lda; + a_offset = 1 + a_dim1; + a -= a_offset; + + /* Function Body */ + *info = 0; + + if (lsame_(type__, (char*)"G", 1, 1)) { + itype = 0; + } else if (lsame_(type__, (char*)"L", 1, 1)) { + itype = 1; + } else if (lsame_(type__, (char*)"U", 1, 1)) { + itype = 2; + } else if (lsame_(type__, (char*)"H", 1, 1)) { + itype = 3; + } else if (lsame_(type__, (char*)"B", 1, 1)) { + itype = 4; + } else if (lsame_(type__, (char*)"Q", 1, 1)) { + itype = 5; + } else if (lsame_(type__, (char*)"Z", 1, 1)) { + itype = 6; + } else { + itype = -1; + } + + if (itype == -1) { + *info = -1; + } else if (*cfrom == 0. || disnan_(cfrom)) { + *info = -4; + } else if (disnan_(cto)) { + *info = -5; + } else if (*m < 0) { + *info = -6; + } else if (*n < 0 || itype == 4 && *n != *m || itype == 5 && *n != *m) { + *info = -7; + } else if (itype <= 3 && *lda < max(1,*m)) { + *info = -9; + } else if (itype >= 4) { +/* Computing MAX */ + i__1 = *m - 1; + if (*kl < 0 || *kl > max(i__1,0)) { + *info = -2; + } else /* if(complicated condition) */ { +/* Computing MAX */ + i__1 = *n - 1; + if (*ku < 0 || *ku > max(i__1,0) || (itype == 4 || itype == 5) && + *kl != *ku) { + *info = -3; + } else if (itype == 4 && *lda < *kl + 1 || itype == 5 && *lda < * + ku + 1 || itype == 6 && *lda < (*kl << 1) + *ku + 1) { + *info = -9; + } + } + } + + if (*info != 0) { + i__1 = -(*info); + xerbla_("DLASCL", &i__1, 0); + return 0; + } + +/* Quick return if possible */ + + if (*n == 0 || *m == 0) { + return 0; + } + +/* Get machine parameters */ + + smlnum = 0.0001; //dlamch_("S"); + bignum = 1. / smlnum; + + cfromc = *cfrom; + ctoc = *cto; + +L10: + cfrom1 = cfromc * smlnum; + if (cfrom1 == cfromc) { +/* CFROMC is an inf. Multiply by a correctly signed zero for */ +/* finite CTOC, or a NaN if CTOC is infinite. */ + mul = ctoc / cfromc; + done = TRUE_; + cto1 = ctoc; + } else { + cto1 = ctoc / bignum; + if (cto1 == ctoc) { +/* CTOC is either 0 or an inf. In both cases, CTOC itself */ +/* serves as the correct multiplication factor. */ + mul = ctoc; + done = TRUE_; + cfromc = 1.; + } else if (abs(cfrom1) > abs(ctoc) && ctoc != 0.) { + mul = smlnum; + done = FALSE_; + cfromc = cfrom1; + } else if (abs(cto1) > abs(cfromc)) { + mul = bignum; + done = FALSE_; + ctoc = cto1; + } else { + mul = ctoc / cfromc; + done = TRUE_; + } + } + + if (itype == 0) { + +/* Full matrix */ + + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = *m; + for (i__ = 1; i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L20: */ + } +/* L30: */ + } + + } else if (itype == 1) { + +/* Lower triangular matrix */ + + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = *m; + for (i__ = j; i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L40: */ + } +/* L50: */ + } + + } else if (itype == 2) { + +/* Upper triangular matrix */ + + i__1 = *n; + for (j = 1; j <= i__1; ++j) { + i__2 = min(j,*m); + for (i__ = 1; i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L60: */ + } +/* L70: */ + } + + } else if (itype == 3) { + +/* Upper Hessenberg matrix */ + + i__1 = *n; + for (j = 1; j <= i__1; ++j) { +/* Computing MIN */ + i__3 = j + 1; + i__2 = min(i__3,*m); + for (i__ = 1; i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L80: */ + } +/* L90: */ + } + + } else if (itype == 4) { + +/* Lower half of a symmetric band matrix */ + + k3 = *kl + 1; + k4 = *n + 1; + i__1 = *n; + for (j = 1; j <= i__1; ++j) { +/* Computing MIN */ + i__3 = k3, i__4 = k4 - j; + i__2 = min(i__3,i__4); + for (i__ = 1; i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L100: */ + } +/* L110: */ + } + + } else if (itype == 5) { + +/* Upper half of a symmetric band matrix */ + + k1 = *ku + 2; + k3 = *ku + 1; + i__1 = *n; + for (j = 1; j <= i__1; ++j) { +/* Computing MAX */ + i__2 = k1 - j; + i__3 = k3; + for (i__ = max(i__2,1); i__ <= i__3; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L120: */ + } +/* L130: */ + } + + } else if (itype == 6) { + +/* Band matrix */ + + k1 = *kl + *ku + 2; + k2 = *kl + 1; + k3 = (*kl << 1) + *ku + 1; + k4 = *kl + *ku + 1 + *m; + i__1 = *n; + for (j = 1; j <= i__1; ++j) { +/* Computing MAX */ + i__3 = k1 - j; +/* Computing MIN */ + i__4 = k3, i__5 = k4 - j; + i__2 = min(i__4,i__5); + for (i__ = max(i__3,k2); i__ <= i__2; ++i__) { + a[i__ + j * a_dim1] *= mul; +/* L140: */ + } +/* L150: */ + } + + } + + if (! done) { + goto L10; + } + + return 0; + +/* End of DLASCL */ + +} /* dlascl_ */ + +__attribute__((noinline)) + +doublereal ddot_(integer *n, doublereal *dx, integer *incx, doublereal *dy, + integer *incy) +{ + /* System generated locals */ + integer i__1; + doublereal ret_val; + + /* Local variables */ + integer i__, m, ix, iy, mp1; + doublereal dtemp; + +/* .. Scalar Arguments .. */ +/* .. */ +/* .. Array Arguments .. */ +/* .. */ + +/* Purpose */ +/* ======= */ + +/* forms the dot product of two vectors. */ +/* uses unrolled loops for increments equal to one. */ +/* jack dongarra, linpack, 3/11/78. */ +/* modified 12/3/93, array(1) declarations changed to array(*) */ + + +/* .. Local Scalars .. */ +/* .. */ +/* .. Intrinsic Functions .. */ +/* .. */ + /* Parameter adjustments */ + --dy; + --dx; + + /* Function Body */ + ret_val = 0.; + dtemp = 0.; + if (*n <= 0) { + return ret_val; + } + if (*incx == 1 && *incy == 1) { + goto L20; + } + +/* code for unequal increments or equal increments */ +/* not equal to 1 */ + + ix = 1; + iy = 1; + if (*incx < 0) { + ix = (-(*n) + 1) * *incx + 1; + } + if (*incy < 0) { + iy = (-(*n) + 1) * *incy + 1; + } + i__1 = *n; + for (i__ = 1; i__ <= i__1; ++i__) { + dtemp += dx[ix] * dy[iy]; + ix += *incx; + iy += *incy; +/* L10: */ + } + ret_val = dtemp; + return ret_val; + +/* code for both increments equal to 1 */ + + +/* clean-up loop */ + +L20: + m = *n % 5; + if (m == 0) { + goto L40; + } + i__1 = m; + for (i__ = 1; i__ <= i__1; ++i__) { + dtemp += dx[i__] * dy[i__]; +/* L30: */ + } + if (*n < 5) { + goto L60; + } +L40: + mp1 = m + 1; + i__1 = *n; + for (i__ = mp1; i__ <= i__1; i__ += 5) { + dtemp = dtemp + dx[i__] * dy[i__] + dx[i__ + 1] * dy[i__ + 1] + dx[ + i__ + 2] * dy[i__ + 2] + dx[i__ + 3] * dy[i__ + 3] + dx[i__ + + 4] * dy[i__ + 4]; +/* L50: */ + } +L60: + ret_val = dtemp; + return ret_val; +} /* ddot_ */ + +__attribute__((noinline)) +/* Subroutine */ int dgemm_(const char *transa, const char *transb, const integer *m, const integer * + n, const integer *k, const doublereal *alpha, const doublereal *a, const integer *lda, + const doublereal *b, const integer *ldb, const doublereal *beta, doublereal *c, const integer + *ldc) +{ + + + /* System generated locals */ + integer a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset, i__1, i__2, + i__3; + + /* Local variables */ + integer info; + logical nota, notb; + doublereal temp; + integer i, j, l, ncola; + integer nrowa, nrowb; + + +/* Purpose + ======= + + DGEMM performs one of the matrix-matrix operations + + C := alpha*op( A )*op( B ) + beta*C, + + where op( X ) is one of + + op( X ) = X or op( X ) = X', + + alpha and beta are scalars, and A, B and C are matrices, with op( A ) + + an m by k matrix, op( B ) a k by n matrix and C an m by n matrix. + + + Parameters + ========== + + TRANSA - CHARACTER*1. + On entry, TRANSA specifies the form of op( A ) to be used in + + the matrix multiplication as follows: + + TRANSA = 'N' or 'n', op( A ) = A. + + TRANSA = 'T' or 't', op( A ) = A'. + + TRANSA = 'C' or 'c', op( A ) = A'. + + Unchanged on exit. + + TRANSB - CHARACTER*1. + On entry, TRANSB specifies the form of op( B ) to be used in + + the matrix multiplication as follows: + + TRANSB = 'N' or 'n', op( B ) = B. + + TRANSB = 'T' or 't', op( B ) = B'. + + TRANSB = 'C' or 'c', op( B ) = B'. + + Unchanged on exit. + + M - INTEGER. + On entry, M specifies the number of rows of the matrix + + op( A ) and of the matrix C. M must be at least zero. + + Unchanged on exit. + + N - INTEGER. + On entry, N specifies the number of columns of the matrix + + op( B ) and the number of columns of the matrix C. N must be + + at least zero. + Unchanged on exit. + + K - INTEGER. + On entry, K specifies the number of columns of the matrix + + op( A ) and the number of rows of the matrix op( B ). K must + + be at least zero. + Unchanged on exit. + + ALPHA - DOUBLE PRECISION. + On entry, ALPHA specifies the scalar alpha. + Unchanged on exit. + + A - DOUBLE PRECISION array of DIMENSION ( LDA, ka ), where ka is + + k when TRANSA = 'N' or 'n', and is m otherwise. + Before entry with TRANSA = 'N' or 'n', the leading m by k + + part of the array A must contain the matrix A, otherwise + + the leading k by m part of the array A must contain the + + matrix A. + Unchanged on exit. + + LDA - INTEGER. + On entry, LDA specifies the first dimension of A as declared + + in the calling (sub) program. When TRANSA = 'N' or 'n' then + + LDA must be at least max( 1, m ), otherwise LDA must be at + + least max( 1, k ). + Unchanged on exit. + + B - DOUBLE PRECISION array of DIMENSION ( LDB, kb ), where kb is + + n when TRANSB = 'N' or 'n', and is k otherwise. + Before entry with TRANSB = 'N' or 'n', the leading k by n + + part of the array B must contain the matrix B, otherwise + + the leading n by k part of the array B must contain the + + matrix B. + Unchanged on exit. + + LDB - INTEGER. + On entry, LDB specifies the first dimension of B as declared + + in the calling (sub) program. When TRANSB = 'N' or 'n' then + + LDB must be at least max( 1, k ), otherwise LDB must be at + + least max( 1, n ). + Unchanged on exit. + + BETA - DOUBLE PRECISION. + On entry, BETA specifies the scalar beta. When BETA is + + supplied as zero then C need not be set on input. + Unchanged on exit. + + C - DOUBLE PRECISION array of DIMENSION ( LDC, n ). + Before entry, the leading m by n part of the array C must + + contain the matrix C, except when beta is zero, in which + + case C need not be set on entry. + On exit, the array C is overwritten by the m by n matrix + + ( alpha*op( A )*op( B ) + beta*C ). + + LDC - INTEGER. + On entry, LDC specifies the first dimension of C as declared + + in the calling (sub) program. LDC must be at least + + max( 1, m ). + Unchanged on exit. + + + Level 3 Blas routine. + + -- Written on 8-February-1989. + Jack Dongarra, Argonne National Laboratory. + Iain Duff, AERE Harwell. + Jeremy Du Croz, Numerical Algorithms Group Ltd. + Sven Hammarling, Numerical Algorithms Group Ltd. + + + + Set NOTA and NOTB as true if A and B respectively are not + + transposed and set NROWA, NCOLA and NROWB as the number of rows + + and columns of A and the number of rows of B respectively. + + + + Parameter adjustments + Function Body */ + +#define A(I,J) a[(I)-1 + ((J)-1)* ( *lda)] +#define B(I,J) b[(I)-1 + ((J)-1)* ( *ldb)] +#define C(I,J) c[(I)-1 + ((J)-1)* ( *ldc)] + + nota = lsame_((char*)transa, (char*)"N", 1, 1); + notb = lsame_((char*)transb, (char*)"N", 1, 1); + if (nota) { + nrowa = *m; + ncola = *k; + } else { + nrowa = *k; + ncola = *m; + } + if (notb) { + nrowb = *k; + } else { + nrowb = *n; + } + +/* Test the input parameters. */ + + info = 0; + if (! nota && ! lsame_((char*)transa, (char*)"C", 1, 1) && ! lsame_((char*)transa, (char*)"T", 1, 1)) { + info = 1; + } else if (! notb && ! lsame_((char*)transb, (char*)"C", 1, 1) && ! lsame_((char*)transb, + (char*)"T", 1, 1)) { + info = 2; + } else if (*m < 0) { + info = 3; + } else if (*n < 0) { + info = 4; + } else if (*k < 0) { + info = 5; + } else if (*lda < max(1,nrowa)) { + info = 8; + } else if (*ldb < max(1,nrowb)) { + info = 10; + } else if (*ldc < max(1,*m)) { + info = 13; + } + if (info != 0) { + xerbla_("DGEMM ", &info, 0); + return 0; + } + +/* Quick return if possible. */ + + if (*m == 0 || *n == 0 || (*alpha == 0. || *k == 0) && *beta == 1.) { + return 0; + } + +/* And if alpha.eq.zero. */ + + if (*alpha == 0.) { + if (*beta == 0.) { + i__1 = *n; + for (j = 1; j <= *n; ++j) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = 0.; +/* L10: */ + } +/* L20: */ + } + } else { + i__1 = *n; + for (j = 1; j <= *n; ++j) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = *beta * C(i,j); +/* L30: */ + } +/* L40: */ + } + } + return 0; + } + +/* Start the operations. */ + + if (notb) { + if (nota) { + +/* Form C := alpha*A*B + beta*C. */ + + i__1 = *n; + for (j = 1; j <= *n; ++j) { + if (*beta == 0.) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = 0.; +/* L50: */ + } + } else if (*beta != 1.) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = *beta * C(i,j); +/* L60: */ + } + } + i__2 = *k; + for (l = 1; l <= *k; ++l) { + if (B(l,j) != 0.) { + temp = *alpha * B(l,j); + i__3 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) += temp * A(i,l); +/* L70: */ + } + } +/* L80: */ + } +/* L90: */ + } + } else { + +/* Form C := alpha*A'*B + beta*C */ + + i__1 = *n; + for (j = 1; j <= *n; ++j) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + temp = 0.; + i__3 = *k; + for (l = 1; l <= *k; ++l) { + temp += A(l,i) * B(l,j); +/* L100: */ + } + if (*beta == 0.) { + C(i,j) = *alpha * temp; + } else { + C(i,j) = *alpha * temp + *beta * C(i,j); + } +/* L110: */ + } +/* L120: */ + } + } + } else { + if (nota) { + +/* Form C := alpha*A*B' + beta*C */ + + i__1 = *n; + for (j = 1; j <= *n; ++j) { + if (*beta == 0.) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = 0.; +/* L130: */ + } + } else if (*beta != 1.) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) = *beta * C(i,j); +/* L140: */ + } + } + i__2 = *k; + for (l = 1; l <= *k; ++l) { + if (B(j,l) != 0.) { + temp = *alpha * B(j,l); + i__3 = *m; + for (i = 1; i <= *m; ++i) { + C(i,j) += temp * A(i,l); +/* L150: */ + } + } +/* L160: */ + } +/* L170: */ + } + } else { + +/* Form C := alpha*A'*B' + beta*C */ + + i__1 = *n; + for (j = 1; j <= *n; ++j) { + i__2 = *m; + for (i = 1; i <= *m; ++i) { + temp = 0.; + i__3 = *k; + for (l = 1; l <= *k; ++l) { + temp += A(l,i) * B(j,l); +/* L180: */ + } + if (*beta == 0.) { + C(i,j) = *alpha * temp; + } else { + C(i,j) = *alpha * temp + *beta * C(i,j); + } +/* L190: */ + } +/* L200: */ + } + } + } + + return 0; +#undef A +#undef B +#undef C +/* End of DGEMM . */ + +} /* dgemm_ */ + +#undef max +#undef min +#undef abs +#ifdef __cplusplus +} +#endif diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 6fced28311f1..1d3fcd39bba6 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -940,7 +940,17 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { os << " }\n"; } -void emit_tmp_creation(Record *Def, raw_ostream &os) { +void emit_tmp_free(Record *Def, raw_ostream &os, StringRef builder) { + const auto args = Def->getValueAsListOfStrings("args"); + // allocating tmp variables is optional, return if not required + if (args.size() == 0) + return; + const auto matName = args[0]; + const auto allocName = "mat_" + matName; + os << " CreateDealloc(" << builder << ", true_" << allocName << ");\n"; +} + +void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { const auto args = Def->getValueAsListOfStrings("args"); // allocating tmp variables is optional, return if not required if (args.size() == 0) @@ -955,51 +965,53 @@ void emit_tmp_creation(Record *Def, raw_ostream &os) { const auto matName = args[0]; const auto dim1 = "arg_" + args[2]; const auto dim2 = "arg_" + args[3]; - os << " Value *len1 = load_if_ref(BuilderZ, intType," << dim1 + os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1 << ", byRef);\n" - << " Value *len2 = load_if_ref(BuilderZ, intType," << dim2 + << " Value *len2 = load_if_ref(" << builder << ", intType," << dim2 << ", byRef);\n" - << " Value *size_" << matName - << " = BuilderZ.CreateNUWMul(len1, len2, \"size_" << matName << "\");\n"; + << " Value *size_" << matName << " = " << builder + << ".CreateNUWMul(len1, len2, \"size_" << matName << "\");\n"; } else if (action == "is_normal") { assert(args.size() == 5); const auto vecName = args[0]; const auto trans = "arg_" + args[2]; const auto dim1 = "arg_" + args[3]; const auto dim2 = "arg_" + args[4]; - os << " Value *len1 = load_if_ref(BuilderZ, intType," << dim1 + os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1 << ", byRef);\n" - << " Value *len2 = load_if_ref(BuilderZ, intType," << dim2 + << " Value *len2 = load_if_ref(" << builder << ", intType," << dim2 << ", byRef);\n"; - os << " Value *size_" << vecName - << " = BuilderZ.CreateSelect(is_normal(BuilderZ, " << trans + os << " Value *size_" << vecName << " = " << builder + << ".CreateSelect(is_normal(" << builder << ", " << trans << ", byRef, cublas), len1, len2);\n"; } else if (action == "triangular") { assert(args.size() == 3); const auto vecName = args[0]; const auto dim1 = "arg_" + args[2]; - os << " Value *len = load_if_ref(BuilderZ, intType," << dim1 + os << " Value *len = load_if_ref(" << builder << ", intType," << dim1 << ", byRef);\n"; // Size has to be (at least) // ( ( n*( n + 1 ) )/2 ) - os << " Value *size_" << vecName - << " = BuilderZ.CreateMul(len, BuilderZ.CreateAdd(len, " + os << " Value *size_" << vecName << " = " << builder + << ".CreateMul(len, " << builder + << ".CreateAdd(len, " "ConstantInt::get(intType, 1)), \"square_mat_size_" << vecName << "\");\n" - << " size_" << vecName << " = BuilderZ.CreateUDiv(size_" << vecName - << ", ConstantInt::get(intType, 2), \"size_" << vecName << "\");\n"; + << " size_" << vecName << " = " << builder << ".CreateUDiv(size_" + << vecName << ", ConstantInt::get(intType, 2), \"size_" << vecName + << "\");\n"; } const auto matName = args[0]; const auto allocName = "mat_" + matName; - os << " Value *" << allocName - << " = CreateAllocation(BuilderZ, fpType, size_" << matName << ", \"" - << allocName << "\");\n" + os << " Value * true_" << allocName << " = CreateAllocation(" << builder + << ", fpType, size_" << matName << ", \"" << allocName << "\");\n" + << " Value * " << allocName << " = true_" << allocName << ";\n" << " if (type_vec_like->isIntegerTy()) {\n" - << " " << allocName << " = BuilderZ.CreatePtrToInt(" << allocName - << ", type_vec_like);\n" + << " " << allocName << " = " << builder << ".CreatePtrToInt(" + << allocName << ", type_vec_like);\n" << " } else if (" << allocName << "->getType() != type_vec_like){\n" - << " " << allocName << " = BuilderZ.CreatePointerCast(" << allocName - << ", type_vec_like);\n" + << " " << allocName << " = " << builder << ".CreatePointerCast(" + << allocName << ", type_vec_like);\n" << " }\n"; } @@ -1644,7 +1656,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os); // We might need to create a tmp vec or matrix - emit_tmp_creation(Def, os); + emit_tmp_creation(Def, os, "Builder2"); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; @@ -1708,6 +1720,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, } } } + emit_tmp_free(Def, os, "Builder2"); emit_runtime_continue(ruleDag, name, " ", "Builder2", true, os); os << " }\n"; } else { From 3339d749947fe719e5b708afcdcac83c2e3e6245 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 6 Nov 2023 17:08:03 -0600 Subject: [PATCH 5/6] Fix Blas diffuse on P^3 computation (#1526) * Fix Blas diffuse on P^3 computation * fixup --- enzyme/Enzyme/Utils.cpp | 5 +- enzyme/Enzyme/Utils.h | 4 + .../test/Enzyme/ReverseMode/blas_diffuse.ll | 184 ++++++++++++++++++ .../test/Integration/ReverseMode/blas_gemm2.c | 99 ++++++++++ enzyme/test/Integration/blas_inline.h | 9 +- .../tools/enzyme-tblgen/blasDiffUseUpdater.h | 5 +- 6 files changed, 301 insertions(+), 5 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll create mode 100644 enzyme/test/Integration/ReverseMode/blas_gemm2.c diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index f028ec906f53..f72d805bc01a 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -1121,7 +1121,10 @@ Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT, #if LLVM_VERSION_MAJOR >= 15 if (Mod.getContext().supportsTypedPointers()) { #endif - assert(PT->getPointerElementType() == elementType); +#if LLVM_VERSION_MAJOR >= 13 + if (!PT->isOpaquePointerTy()) +#endif + assert(PT->getPointerElementType() == elementType); #if LLVM_VERSION_MAJOR >= 15 } #endif diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index ab9d246f907d..c00c2c149df3 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -441,6 +441,10 @@ static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode, if (!arg->getContext().supportsTypedPointers()) { return DIFFE_TYPE::DUP_ARG; } +#elif LLVM_VERSION_MAJOR >= 13 + if (arg->isOpaquePointerTy()) { + return DIFFE_TYPE::DUP_ARG; + } #endif switch (whatType(arg->getPointerElementType(), mode, integersAreConstant, seen)) { diff --git a/enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll b/enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll new file mode 100644 index 000000000000..cc582757de5d --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas_diffuse.ll @@ -0,0 +1,184 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -ge 14 ] ; then %opt < %s %loadEnzyme -opaque-pointers -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -jump-threading -adce -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 14 ]; then %opt < %s %newLoadEnzyme -opaque-pointers -passes="enzyme,function(mem2reg,early-cse,instsimplify,jump-threading,adce)" -enzyme-preopt=false -S | FileCheck %s ; fi + +; ModuleID = '../examples/big/big_inlined_correctness.cpp' +source_filename = "../examples/big/big_inlined_correctness.cpp" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +%struct.Prod = type { ptr, double } + +declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc) + +; Function Attrs: mustprogress noinline nounwind uwtable +define dso_local void @_Z3mulR4ProdPd(ptr nocapture noundef nonnull align 8 dereferenceable(16) %P, ptr noalias nocapture noundef readonly %rhs) { +entry: + %N = alloca i8, align 1 + %ten = alloca i32, align 4 + %one = alloca double, align 8 + %zero = alloca double, align 8 + %calloc = call dereferenceable_or_null(32) ptr @calloc(i64 1, i64 32) + store i8 78, ptr %N, align 1 + store i32 2, ptr %ten, align 4 + store double 1.000000e+00, ptr %one, align 8 + store double 0.000000e+00, ptr %zero, align 8 + %call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten) + %0 = load ptr, ptr %P, align 8 + %call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten) + %alpha = getelementptr inbounds %struct.Prod, ptr %P, i64 0, i32 1 + store double 0.000000e+00, ptr %alpha, align 8 + ret void +} + +declare noalias ptr @malloc(i64) + +; Function Attrs: mustprogress nounwind uwtable +define dso_local double @_Z8simulatePd(ptr nocapture noundef readonly %P) { +entry: + %M = alloca %struct.Prod, align 8 + %call = tail call noalias dereferenceable_or_null(32) ptr @malloc(i64 noundef 32) + store ptr %call, ptr %M, align 8 + %alpha = getelementptr inbounds %struct.Prod, ptr %M, i64 0, i32 1 + store double 1.000000e+00, ptr %alpha, align 8 + call void @_Z3mulR4ProdPd(ptr noundef nonnull align 8 dereferenceable(16) %M, ptr noundef %P) + %0 = load ptr, ptr %M, align 8 + %1 = load double, ptr %0, align 8 + ret double %1 +} + +define void @caller(ptr %A, ptr %Adup) { +entry: + call void (...) @_Z17__enzyme_autodiffz(ptr noundef nonnull @_Z8simulatePd, metadata !"enzyme_dup", ptr noundef nonnull %A, ptr noundef nonnull %Adup) + ret void +} + +declare void @_Z17__enzyme_autodiffz(...) + +declare noalias noundef ptr @calloc(i64 noundef, i64 noundef) + +; we must actually save or set the matmul +; CHECK: define internal void @diffe_Z3mulR4ProdPd(ptr nocapture align 8 dereferenceable(16) %P, ptr nocapture align 8 %"P'", ptr noalias nocapture readonly %rhs, ptr nocapture %"rhs'", { ptr, ptr, ptr, ptr } %tapeArg) +; CHECK-NEXT: invertentry: +; CHECK-NEXT: %byref.transpose.transb = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.transpose.transa = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.05 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i32, align 4 +; CHECK-NEXT: %byref.constant.int.06 = alloca i32, align 4 +; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double, align 8 +; CHECK-NEXT: %[[i0:.+]] = alloca i32, align 4 +; CHECK-NEXT: %byref.transpose.transb11 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.014 = alloca double, align 8 +; CHECK-NEXT: %byref.transpose.transa16 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.019 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G20 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.021 = alloca i32, align 4 +; CHECK-NEXT: %byref.constant.int.022 = alloca i32, align 4 +; CHECK-NEXT: %byref.constant.fp.1.023 = alloca double, align 8 +; CHECK-NEXT: %[[i1:.+]] = alloca i32, align 4 +; CHECK-NEXT: %malloccall3 = alloca i8, i64 8, align 8 +; CHECK-NEXT: %malloccall = alloca i8, i64 1, align 1 +; CHECK-NEXT: %malloccall2 = alloca i8, i64 8, align 8 +; CHECK-NEXT: %malloccall1 = alloca i8, i64 4, align 4 +; CHECK-NEXT: %"calloc'mi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 2 +; CHECK-NEXT: %calloc = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 3 +; CHECK-NEXT: store i8 78, ptr %malloccall, align 1 +; CHECK-NEXT: store i32 2, ptr %malloccall1, align 4 +; CHECK-NEXT: store double 1.000000e+00, ptr %malloccall2, align 8 +; CHECK-NEXT: store double 0.000000e+00, ptr %malloccall3, align 8 +; CHECK-NEXT: %"'il_phi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 1 +; CHECK-NEXT: %[[i2:.+]] = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 0 +; CHECK-NEXT: %"alpha'ipg" = getelementptr inbounds %struct.Prod, ptr %"P'", i64 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %"alpha'ipg", align 8 +; CHECK-NEXT: %ld.transb = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i3:.+]] = icmp eq i8 %ld.transb, 110 +; CHECK-NEXT: %[[i4:.+]] = select i1 %[[i3]], i8 116, i8 0 +; CHECK-NEXT: %[[i5:.+]] = icmp eq i8 %ld.transb, 78 +; CHECK-NEXT: %[[i6:.+]] = select i1 %[[i5]], i8 84, i8 %[[i4]] +; CHECK-NEXT: %[[i7:.+]] = icmp eq i8 %ld.transb, 116 +; CHECK-NEXT: %[[i8:.+]] = select i1 %[[i7]], i8 110, i8 %[[i6]] +; CHECK-NEXT: %[[i9:.+]] = icmp eq i8 %ld.transb, 84 +; CHECK-NEXT: %[[i10:.+]] = select i1 %[[i9]], i8 78, i8 %[[i8]] +; CHECK-NEXT: store i8 %[[i10]], ptr %byref.transpose.transb, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i11:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[i12:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[i13:.+]] = or i1 %[[i12]], %[[i11]] +; CHECK-NEXT: %[[i14:.+]] = select i1 %[[i13]], ptr %byref.transpose.transb, ptr %malloccall +; CHECK-NEXT: %[[i15:.+]] = select i1 %[[i13]], ptr %"'il_phi", ptr %rhs +; CHECK-NEXT: %[[i16:.+]] = select i1 %[[i13]], ptr %rhs, ptr %"'il_phi" +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i14]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i15]], ptr %malloccall1, ptr %[[i16]], ptr %malloccall1, ptr %byref.constant.fp.1.0, ptr %"calloc'mi", ptr %malloccall1, i32 1, i32 1) +; CHECK-NEXT: %ld.transa = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i17:.+]] = icmp eq i8 %ld.transa, 110 +; CHECK-NEXT: %[[i18:.+]] = select i1 %[[i17]], i8 116, i8 0 +; CHECK-NEXT: %[[i19:.+]] = icmp eq i8 %ld.transa, 78 +; CHECK-NEXT: %[[i20:.+]] = select i1 %[[i19]], i8 84, i8 %[[i18]] +; CHECK-NEXT: %[[i21:.+]] = icmp eq i8 %ld.transa, 116 +; CHECK-NEXT: %[[i22:.+]] = select i1 %[[i21]], i8 110, i8 %[[i20]] +; CHECK-NEXT: %[[i23:.+]] = icmp eq i8 %ld.transa, 84 +; CHECK-NEXT: %[[i24:.+]] = select i1 %[[i23]], i8 78, i8 %[[i22]] +; CHECK-NEXT: store i8 %[[i24]], ptr %byref.transpose.transa, align 1 +; CHECK-NEXT: %ld.row.trans2 = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i25:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-NEXT: %[[i26:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %[[i27:.+]] = or i1 %[[i26]], %[[i25]] +; CHECK-NEXT: %[[i28:.+]] = select i1 %[[i27]], ptr %byref.transpose.transa, ptr %malloccall +; CHECK-NEXT: %[[i29:.+]] = select i1 %[[i27]], ptr %[[i2]], ptr %"'il_phi" +; CHECK-NEXT: %[[i30:.+]] = select i1 %[[i27]], ptr %"'il_phi", ptr %[[i2]] +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.05, align 8 +; CHECK-NEXT: call void @dgemm_(ptr %[[i28]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i29]], ptr %malloccall1, ptr %[[i30]], ptr %malloccall1, ptr %byref.constant.fp.1.05, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1) +; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G, align 1 +; CHECK-NEXT: store i32 0, ptr %byref.constant.int.0, align 4 +; CHECK-NEXT: store i32 0, ptr %byref.constant.int.06, align 4 +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.07, align 8 +; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G, ptr %byref.constant.int.0, ptr %byref.constant.int.06, ptr %byref.constant.fp.1.07, ptr %malloccall3, ptr %malloccall1, ptr %malloccall1, ptr %"'il_phi", ptr %malloccall1, ptr %[[i0]], i32 1) +; CHECK-NEXT: tail call void @free(ptr nonnull %[[i2]]) +; CHECK-NEXT: %ld.transb10 = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i31:.+]] = icmp eq i8 %ld.transb10, 110 +; CHECK-NEXT: %[[i32:.+]] = select i1 %[[i31]], i8 116, i8 0 +; CHECK-NEXT: %[[i33:.+]] = icmp eq i8 %ld.transb10, 78 +; CHECK-NEXT: %[[i34:.+]] = select i1 %[[i33]], i8 84, i8 %[[i32]] +; CHECK-NEXT: %[[i35:.+]] = icmp eq i8 %ld.transb10, 116 +; CHECK-NEXT: %[[i36:.+]] = select i1 %[[i35]], i8 110, i8 %[[i34]] +; CHECK-NEXT: %[[i37:.+]] = icmp eq i8 %ld.transb10, 84 +; CHECK-NEXT: %[[i38:.+]] = select i1 %[[i37]], i8 78, i8 %[[i36]] +; CHECK-NEXT: store i8 %[[i38]], ptr %byref.transpose.transb11, align 1 +; CHECK-NEXT: %ld.row.trans12 = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i39:.+]] = icmp eq i8 %ld.row.trans12, 110 +; CHECK-NEXT: %[[i40:.+]] = icmp eq i8 %ld.row.trans12, 78 +; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]] +; CHECK-NEXT: %[[i42:.+]] = select i1 %[[i41]], ptr %byref.transpose.transb11, ptr %malloccall +; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i41]], ptr %"calloc'mi", ptr %rhs +; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i41]], ptr %rhs, ptr %"calloc'mi" +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.014, align 8 +; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i42]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i43]], ptr %malloccall1, ptr %[[i44]], ptr %malloccall1, ptr %byref.constant.fp.1.014, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1) +; CHECK-NEXT: %ld.transa15 = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i45:.+]] = icmp eq i8 %ld.transa15, 110 +; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8 116, i8 0 +; CHECK-NEXT: %[[i47:.+]] = icmp eq i8 %ld.transa15, 78 +; CHECK-NEXT: %[[i48:.+]] = select i1 %[[i47]], i8 84, i8 %[[i46]] +; CHECK-NEXT: %[[i49:.+]] = icmp eq i8 %ld.transa15, 116 +; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8 110, i8 %[[i48]] +; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.transa15, 84 +; CHECK-NEXT: %[[i52:.+]] = select i1 %[[i51]], i8 78, i8 %[[i50]] +; CHECK-NEXT: store i8 %[[i52]], ptr %byref.transpose.transa16, align 1 +; CHECK-NEXT: %ld.row.trans17 = load i8, ptr %malloccall, align 1 +; CHECK-NEXT: %[[i53:.+]] = icmp eq i8 %ld.row.trans17, 110 +; CHECK-NEXT: %[[i54:.+]] = icmp eq i8 %ld.row.trans17, 78 +; CHECK-NEXT: %[[i55:.+]] = or i1 %[[i54]], %[[i53]] +; CHECK-NEXT: %[[i56:.+]] = select i1 %[[i55:.+]], ptr %byref.transpose.transa16, ptr %malloccall +; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i55:.+]], ptr %rhs, ptr %"calloc'mi" +; CHECK-NEXT: %[[i58:.+]] = select i1 %[[i55:.+]], ptr %"calloc'mi", ptr %rhs +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.019, align 8 +; CHECK-NEXT: call void @dgemm_(ptr %[[i56]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %57, ptr %malloccall1, ptr %58, ptr %malloccall1, ptr %byref.constant.fp.1.019, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1) +; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G20, align 1 +; CHECK-NEXT: store i32 0, ptr %byref.constant.int.021, align 4 +; CHECK-NEXT: store i32 0, ptr %byref.constant.int.022, align 4 +; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.023, align 8 +; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G20, ptr %byref.constant.int.021, ptr %byref.constant.int.022, ptr %byref.constant.fp.1.023, ptr %malloccall2, ptr %malloccall1, ptr %malloccall1, ptr %"calloc'mi", ptr %malloccall1, ptr %[[i1]], i32 1) +; CHECK-NEXT: call void @free(ptr nonnull %"calloc'mi") +; CHECK-NEXT: call void @free(ptr %calloc) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm2.c b/enzyme/test/Integration/ReverseMode/blas_gemm2.c new file mode 100644 index 000000000000..e1fd7947308e --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/blas_gemm2.c @@ -0,0 +1,99 @@ +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi + +#include +#include +#include +#include "test_utils.h" +#include "../blas_inline.h" + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +#include + +void __enzyme_autodiff(void*, ...); + +const size_t n = 20; + +#include +struct Prod { + double* out; + double alpha; +}; + +__attribute__((noinline)) +void mul(struct Prod* P, double* __restrict__ rhs) { + double* tmp= (double*)malloc(sizeof(double)*n*n); + memset(tmp, 0, n*n*sizeof(double)); + char N = 'N'; + int ten = n; + double one = 1.0; + double zero = 0.0; + + dgemm_(&N, &N, &ten, &ten, &ten, &one, rhs, &ten, rhs, &ten, &one, tmp, &ten); + dgemm_(&N, &N, &ten, &ten, &ten, &one, tmp, &ten, rhs, &ten, &zero, P->out, &ten); + P->alpha = 0; + return; +} + +double simulate(double* P) { + struct Prod M; + M.out = (double*)malloc(sizeof(double)*n*n); + M.alpha = 1.0; + mul(&M, P); + return M.out[0]; + // double *out = (double*)malloc(sizeof(double)*n*n); + // dgemm_(&N, &N, &ten, &ten, &ten, &one, P1.data(), &ten, P.data(), &ten, &zero, &out[0], &ten); + // return P1(0, 0); +} + +int main(int argc, char **argv) { + + double A[n * n]; + double Adup[n * n]; + double Adup_fd[n * n]; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + A[n*i + j] = j == i ? 0.3 : 0.1; + Adup[n*i + j] = 0.0; + Adup_fd[n*i + j] = 0.0; + } + } + + double delta = 0.001; + delta = delta * delta; + + double fx = simulate(A); + printf("f(A) = %f\n", fx); + + // if (argc == 2) { + __enzyme_autodiff((void *)simulate, enzyme_dup, &A[0], &Adup[0]); + printf("dP(0,0) = %f, dP(0,1) = %f, dP(1,0) = %f\n", Adup[0], Adup[1], Adup[2]); + //} + + for (int i = 0; i < n*n; i++) { + A[i] += delta / 2; + double fx2 = simulate(A); + A[i] -= delta; + double fx3 = simulate(A); + A[i] += delta/2; + Adup_fd[i] = (fx2 - fx3) / delta; + + printf("dA_fd[%d]=%f\n", i, Adup_fd[i]); + + APPROX_EQ(Adup[i], Adup_fd[i], 1e-6); + } + + return 0; +} diff --git a/enzyme/test/Integration/blas_inline.h b/enzyme/test/Integration/blas_inline.h index 46507a65c20f..8ac0c16f51b3 100644 --- a/enzyme/test/Integration/blas_inline.h +++ b/enzyme/test/Integration/blas_inline.h @@ -28,7 +28,7 @@ int xerbla_(const char *srname, integer *info, int len) return 0; } __attribute__((noinline)) -logical lsame_(char *ca, char *cb, int, int) +logical lsame_(char *ca, char *cb, int ca_size, int cb_size) { /* System generated locals */ logical ret_val; @@ -764,12 +764,17 @@ doublereal ddot_(integer *n, doublereal *dx, integer *incx, doublereal *dy, } /* ddot_ */ __attribute__((noinline)) -/* Subroutine */ int dgemm_(const char *transa, const char *transb, const integer *m, const integer * +/* Subroutine */ int dgemm_(const char *transa_t, const char *transb_t, const integer *m, const integer * n, const integer *k, const doublereal *alpha, const doublereal *a, const integer *lda, const doublereal *b, const integer *ldb, const doublereal *beta, doublereal *c, const integer *ldc) { + char transa_v = *transa_t; + char* transa = &transa_v; + + char transb_v = *transb_t; + char* transb = &transb_v; /* System generated locals */ integer a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset, i__1, i__2, diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index 89f8a8962b76..c5a35a5089d5 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -114,8 +114,9 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { } } - os << " if (!shadow && need_" << argname << " && !cache_" << argname - << ")\n" + os << " if (!shadow && need_" << argname + << " && ((cacheMode && overwritten_args_ptr) ? !cache_" << argname + << " : true ))\n" << " return true;\n"; os << " }\n"; } From 7d2b6303902c5cda918623f3e7ae1b3714eddf78 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 Nov 2023 08:56:15 -0600 Subject: [PATCH 6/6] Fix stability of minCut (#1527) --- enzyme/Enzyme/AdjointGenerator.h | 11 +++++++++-- enzyme/Enzyme/CallDerivatives.cpp | 11 +++++++++-- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 14 ++++++++------ enzyme/Enzyme/DifferentialUseAnalysis.h | 10 +++++----- enzyme/Enzyme/GradientUtils.cpp | 15 ++++++++------- 5 files changed, 39 insertions(+), 22 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c1be8be8ab6c..e11522ad1d4c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -915,9 +915,16 @@ class AdjointGenerator bool forceErase = false; if (Mode == DerivativeMode::ReverseModeGradient) { + // Since we won't redo the store in the reverse pass, do not + // force the write barrier. + forceErase = true; for (const auto &pair : gutils->rematerializableAllocations) { - if (pair.second.stores.count(&SI) && pair.second.LI) { - forceErase = true; + // However, if we are rematerailizing the allocationa and not + // inside the loop level rematerialization, we do still need the + // reverse passes ``fake primal'' store and therefore write barrier + if (pair.second.stores.count(&SI) && + (!pair.second.LI || !pair.second.LI->contains(&SI))) { + forceErase = false; } } } diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 68af35acd314..0dff4eac5f6b 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -2667,9 +2667,16 @@ bool AdjointGenerator::handleKnownCallDerivatives( bool forceErase = false; if (Mode == DerivativeMode::ReverseModeGradient) { + // Since we won't redo the store in the reverse pass, do not + // force the write barrier. + forceErase = true; for (const auto &pair : gutils->rematerializableAllocations) { - if (pair.second.stores.count(&call) && pair.second.LI) { - forceErase = true; + // However, if we are rematerailizing the allocationa and not + // inside the loop level rematerialization, we do still need the + // reverse passes ``fake primal'' store and therefore write barrier + if (pair.second.stores.count(&call) && + (!pair.second.LI || !pair.second.LI->contains(&call))) { + forceErase = false; } } } diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 3d3435f98fe0..7edab4fb53ad 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -682,7 +682,7 @@ void DifferentialUseAnalysis::dump(Graph &G) { /* Returns true if there is a path from source 's' to sink 't' in residual graph. Also fills parent[] to store the path */ void DifferentialUseAnalysis::bfs(const Graph &G, - const SmallPtrSetImpl &Recompute, + const SetVector &Recompute, std::map &parent) { std::deque q; for (auto V : Recompute) { @@ -726,9 +726,9 @@ int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) { void DifferentialUseAnalysis::minCut( const DataLayout &DL, LoopInfo &OrigLI, - const SmallPtrSetImpl &Recomputes, - const SmallPtrSetImpl &Intermediates, - SmallPtrSetImpl &Required, SmallPtrSetImpl &MinReq, + const SetVector &Recomputes, + const SetVector &Intermediates, SetVector &Required, + SetVector &MinReq, const ValueMap &rematerializableAllocations, llvm::TargetLibraryInfo &TLI) { @@ -810,6 +810,8 @@ void DifferentialUseAnalysis::minCut( std::map parent; bfs(G, Recomputes, parent); + std::deque todo; + // Print all edges that are from a reachable vertex to // non-reachable vertex in the original graph for (auto &pair : Orig) { @@ -819,13 +821,13 @@ void DifferentialUseAnalysis::minCut( assert(pair.first.outgoing == 0 && N.outgoing == 1); assert(pair.first.V == N.V); MinReq.insert(N.V); + todo.push_back(N.V); } } } // When ambiguous, push to cache the last value in a computation chain // This should be considered in a cost for the max flow - std::deque todo(MinReq.begin(), MinReq.end()); while (todo.size()) { auto V = todo.front(); todo.pop_front(); @@ -889,7 +891,7 @@ void DifferentialUseAnalysis::minCut( (moreOuterLoop == 0 && DL.getTypeSizeInBits(V->getType()) >= DL.getTypeSizeInBits((*found->second.begin()).V->getType()))) { - MinReq.erase(V); + MinReq.remove(V); MinReq.insert((*found->second.begin()).V); todo.push_back((*found->second.begin()).V); } diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 4c9796af5928..cefde2726691 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -397,7 +397,7 @@ void dump(std::map> &G); /* Returns true if there is a path from source 's' to sink 't' in residual graph. Also fills parent[] to store the path */ void bfs(const std::map> &G, - const llvm::SmallPtrSetImpl &Recompute, + const llvm::SetVector &Recompute, std::map &parent); // Return 1 if next is better @@ -406,10 +406,10 @@ void bfs(const std::map> &G, int cmpLoopNest(llvm::Loop *prev, llvm::Loop *next); void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, - const llvm::SmallPtrSetImpl &Recomputes, - const llvm::SmallPtrSetImpl &Intermediates, - llvm::SmallPtrSetImpl &Required, - llvm::SmallPtrSetImpl &MinReq, + const llvm::SetVector &Recomputes, + const llvm::SetVector &Intermediates, + llvm::SetVector &Required, + llvm::SetVector &MinReq, const llvm::ValueMap &rematerializableAllocations, llvm::TargetLibraryInfo &TLI); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index f08fea3f44ba..f7dc982741a9 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -7774,7 +7774,7 @@ nofast:; void GradientUtils::computeMinCache() { if (EnzymeMinCutCache) { - SmallPtrSet Recomputes; + SetVector Recomputes; std::map FullSeen; std::map OneLevelSeen; @@ -7919,11 +7919,8 @@ void GradientUtils::computeMinCache() { } } - SmallPtrSet Intermediates; - SmallPtrSet Required; - - Intermediates.clear(); - Required.clear(); + SetVector Intermediates; + SetVector Required; std::deque todo(Recomputes.begin(), Recomputes.end()); while (todo.size()) { @@ -7970,7 +7967,7 @@ void GradientUtils::computeMinCache() { } } - SmallPtrSet MinReq; + SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), OrigLI, Recomputes, Intermediates, Required, MinReq, rematerializableAllocations, TLI); @@ -9200,6 +9197,10 @@ bool GradientUtils::needsCacheWholeAllocation( // If caching this user, it cannot be a gep/cast of original if (!found->second) { + llvm::errs() << " oldFunc: " << *oldFunc << "\n"; + for (auto &pair : knownRecomputeHeuristic) + llvm::errs() << " krc[" << *pair.first << "] = " << pair.second << "\n"; + llvm::errs() << " cur: " << *cur << "\n"; assert(false && "caching potentially capturing/offset of allocation"); } else { // if not caching this user, it is legal to recompute, consider its users