From 53735ae086d43985eaa50a58203efc7b7dfbcee4 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Tue, 30 Jul 2024 21:57:35 +0000 Subject: [PATCH] Add fallback search for extra-underscore-suffixed symbols MKL v2024 has ILP64-suffixed symbols, but the suffix applied to FORTRAN symbols (`64`, mapping `dgemm_` to `dgemm_64`) is different from the suffix applied to non-FORTRAN symbols (`_64`, mapping `cblas_dgemm` to `cblas_dgemm_64`). This wreaks havoc with LBT, which assumes that all symbols undergo the same name transformation. To work around this, we add a fallback path to our symbol forwarding routine; if a symbol does not exist in a library, we look for the same symbol but with an underscore in front of the symbol name. Because this all happens only after we have already found `isamax_` and `dpotrf_` in `autodetect_symbol_suffix()`, we have some measure of assurance that we will not be blindly guessing ridiculous symbol names. --- ext/gensymbol/generate_func_list.sh | 15 +- src/autodetection.c | 9 ++ src/exported_funcs.inc | 208 ++++++++++++++++++++++++++++ src/libblastrampoline.c | 19 ++- src/libblastrampoline_cblasdata.h | 10 +- src/mkl_v2022_adapters.c | 25 ---- test/direct.jl | 4 +- 7 files changed, 255 insertions(+), 35 deletions(-) delete mode 100644 src/mkl_v2022_adapters.c diff --git a/ext/gensymbol/generate_func_list.sh b/ext/gensymbol/generate_func_list.sh index 6d47e3c..feebb3d 100755 --- a/ext/gensymbol/generate_func_list.sh +++ b/ext/gensymbol/generate_func_list.sh @@ -106,6 +106,18 @@ echo >> "${OUTPUT_FILE}" echo "#endif" >> "${OUTPUT_FILE}" NUM_COMPLEX128_SYMBOLS="${NUM_SYMBOLS}" +NUM_SYMBOLS=0 +CBLAS_FUNCS="$(grep -e '^cblas_.*$' <<< "${EXPORTED_FUNCS}")" +echo >> "${OUTPUT_FILE}" +echo "#ifndef CBLAS_FUNCS" >> "${OUTPUT_FILE}" +echo "#define CBLAS_FUNCS(XX) \\" >> "${OUTPUT_FILE}" +for func_name in ${CBLAS_FUNCS}; do + output_func "${func_name}" +done +echo >> "${OUTPUT_FILE}" +echo "#endif" >> "${OUTPUT_FILE}" +NUM_CBLAS_SYMBOLS="${NUM_SYMBOLS}" + NUM_SYMBOLS=0 # We manually curate a list of cblas functions that we have defined adapters for # in `src/cblas_adapters.c`. This is our compromise between the crushing workload @@ -123,9 +135,10 @@ echo >> "${OUTPUT_FILE}" echo "#endif" >> "${OUTPUT_FILE}" NUM_CBLAS_WORKAROUND_SYMBOLS="${NUM_SYMBOLS}" + # Report to the user and cleanup echo NUM_F2C_SYMBOLS="$((NUM_FLOAT32_SYMBOLS + NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))" NUM_CMPLX_SYMBOLS="$((NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))" -echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)." +echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_SYMBOLS} cblas, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)." rm -f tempsymbols.def diff --git a/src/autodetection.c b/src/autodetection.c index 1d6f59a..f96bc06 100644 --- a/src/autodetection.c +++ b/src/autodetection.c @@ -362,6 +362,8 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) { #ifdef CBLAS_DIVERGENCE_AUTODETECTION char symbol_name[MAX_SYMBOL_LEN]; + char extra_underscore_suffix[MAX_SYMBOL_LEN]; + snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", suffix); build_symbol_name(symbol_name, "zdotc_", suffix); if (lookup_symbol(handle, symbol_name) != NULL ) { @@ -371,6 +373,13 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) { return LBT_CBLAS_CONFORMANT; } + // Do the fallback extra-underscore suffix search here, so we don't mistakenly + // mark MKL v2024 as CBLAS-divergent + build_symbol_name(symbol_name, "cblas_zdotc_sub", extra_underscore_suffix); + if (lookup_symbol(handle, symbol_name) != NULL ) { + return LBT_CBLAS_CONFORMANT; + } + const char * lp64_suffixes[] = { "", "_", "__", }; diff --git a/src/exported_funcs.inc b/src/exported_funcs.inc index 3ea2530..669c2c3 100644 --- a/src/exported_funcs.inc +++ b/src/exported_funcs.inc @@ -5164,6 +5164,214 @@ #endif +#ifndef CBLAS_FUNCS +#define CBLAS_FUNCS(XX) \ + XX(cblas_caxpby, 2609) \ + XX(cblas_caxpy, 2610) \ + XX(cblas_caxpyc, 2611) \ + XX(cblas_ccopy, 2612) \ + XX(cblas_cdotc, 2613) \ + XX(cblas_cdotc_sub, 2614) \ + XX(cblas_cdotu, 2615) \ + XX(cblas_cdotu_sub, 2616) \ + XX(cblas_cgbmv, 2617) \ + XX(cblas_cgeadd, 2618) \ + XX(cblas_cgemm, 2619) \ + XX(cblas_cgemm3m, 2620) \ + XX(cblas_cgemmt, 2621) \ + XX(cblas_cgemv, 2622) \ + XX(cblas_cgerc, 2623) \ + XX(cblas_cgeru, 2624) \ + XX(cblas_chbmv, 2625) \ + XX(cblas_chemm, 2626) \ + XX(cblas_chemv, 2627) \ + XX(cblas_cher, 2628) \ + XX(cblas_cher2, 2629) \ + XX(cblas_cher2k, 2630) \ + XX(cblas_cherk, 2631) \ + XX(cblas_chpmv, 2632) \ + XX(cblas_chpr, 2633) \ + XX(cblas_chpr2, 2634) \ + XX(cblas_cimatcopy, 2635) \ + XX(cblas_comatcopy, 2636) \ + XX(cblas_crotg, 2637) \ + XX(cblas_cscal, 2638) \ + XX(cblas_csrot, 2639) \ + XX(cblas_csscal, 2640) \ + XX(cblas_cswap, 2641) \ + XX(cblas_csymm, 2642) \ + XX(cblas_csyr2k, 2643) \ + XX(cblas_csyrk, 2644) \ + XX(cblas_ctbmv, 2645) \ + XX(cblas_ctbsv, 2646) \ + XX(cblas_ctpmv, 2647) \ + XX(cblas_ctpsv, 2648) \ + XX(cblas_ctrmm, 2649) \ + XX(cblas_ctrmv, 2650) \ + XX(cblas_ctrsm, 2651) \ + XX(cblas_ctrsv, 2652) \ + XX(cblas_damax, 2653) \ + XX(cblas_damin, 2654) \ + XX(cblas_dasum, 2655) \ + XX(cblas_daxpby, 2656) \ + XX(cblas_daxpy, 2657) \ + XX(cblas_dbf16tod, 2658) \ + XX(cblas_dcopy, 2659) \ + XX(cblas_ddot, 2660) \ + XX(cblas_dgbmv, 2661) \ + XX(cblas_dgeadd, 2662) \ + XX(cblas_dgemm, 2663) \ + XX(cblas_dgemmt, 2664) \ + XX(cblas_dgemv, 2665) \ + XX(cblas_dger, 2666) \ + XX(cblas_dimatcopy, 2667) \ + XX(cblas_dnrm2, 2668) \ + XX(cblas_domatcopy, 2669) \ + XX(cblas_drot, 2670) \ + XX(cblas_drotg, 2671) \ + XX(cblas_drotm, 2672) \ + XX(cblas_drotmg, 2673) \ + XX(cblas_dsbmv, 2674) \ + XX(cblas_dscal, 2675) \ + XX(cblas_dsdot, 2676) \ + XX(cblas_dspmv, 2677) \ + XX(cblas_dspr, 2678) \ + XX(cblas_dspr2, 2679) \ + XX(cblas_dsum, 2680) \ + XX(cblas_dswap, 2681) \ + XX(cblas_dsymm, 2682) \ + XX(cblas_dsymv, 2683) \ + XX(cblas_dsyr, 2684) \ + XX(cblas_dsyr2, 2685) \ + XX(cblas_dsyr2k, 2686) \ + XX(cblas_dsyrk, 2687) \ + XX(cblas_dtbmv, 2688) \ + XX(cblas_dtbsv, 2689) \ + XX(cblas_dtpmv, 2690) \ + XX(cblas_dtpsv, 2691) \ + XX(cblas_dtrmm, 2692) \ + XX(cblas_dtrmv, 2693) \ + XX(cblas_dtrsm, 2694) \ + XX(cblas_dtrsv, 2695) \ + XX(cblas_dzamax, 2696) \ + XX(cblas_dzamin, 2697) \ + XX(cblas_dzasum, 2698) \ + XX(cblas_dznrm2, 2699) \ + XX(cblas_dzsum, 2700) \ + XX(cblas_icamax, 2701) \ + XX(cblas_icamin, 2702) \ + XX(cblas_icmax, 2703) \ + XX(cblas_icmin, 2704) \ + XX(cblas_idamax, 2705) \ + XX(cblas_idamin, 2706) \ + XX(cblas_idmax, 2707) \ + XX(cblas_idmin, 2708) \ + XX(cblas_isamax, 2709) \ + XX(cblas_isamin, 2710) \ + XX(cblas_ismax, 2711) \ + XX(cblas_ismin, 2712) \ + XX(cblas_izamax, 2713) \ + XX(cblas_izamin, 2714) \ + XX(cblas_izmax, 2715) \ + XX(cblas_izmin, 2716) \ + XX(cblas_samax, 2717) \ + XX(cblas_samin, 2718) \ + XX(cblas_sasum, 2719) \ + XX(cblas_saxpby, 2720) \ + XX(cblas_saxpy, 2721) \ + XX(cblas_sbdot, 2722) \ + XX(cblas_sbdtobf16, 2723) \ + XX(cblas_sbf16tos, 2724) \ + XX(cblas_sbgemm, 2725) \ + XX(cblas_sbgemv, 2726) \ + XX(cblas_sbstobf16, 2727) \ + XX(cblas_scamax, 2728) \ + XX(cblas_scamin, 2729) \ + XX(cblas_scasum, 2730) \ + XX(cblas_scnrm2, 2731) \ + XX(cblas_scopy, 2732) \ + XX(cblas_scsum, 2733) \ + XX(cblas_sdot, 2734) \ + XX(cblas_sdsdot, 2735) \ + XX(cblas_sgbmv, 2736) \ + XX(cblas_sgeadd, 2737) \ + XX(cblas_sgemm, 2738) \ + XX(cblas_sgemmt, 2739) \ + XX(cblas_sgemv, 2740) \ + XX(cblas_sger, 2741) \ + XX(cblas_simatcopy, 2742) \ + XX(cblas_snrm2, 2743) \ + XX(cblas_somatcopy, 2744) \ + XX(cblas_srot, 2745) \ + XX(cblas_srotg, 2746) \ + XX(cblas_srotm, 2747) \ + XX(cblas_srotmg, 2748) \ + XX(cblas_ssbmv, 2749) \ + XX(cblas_sscal, 2750) \ + XX(cblas_sspmv, 2751) \ + XX(cblas_sspr, 2752) \ + XX(cblas_sspr2, 2753) \ + XX(cblas_ssum, 2754) \ + XX(cblas_sswap, 2755) \ + XX(cblas_ssymm, 2756) \ + XX(cblas_ssymv, 2757) \ + XX(cblas_ssyr, 2758) \ + XX(cblas_ssyr2, 2759) \ + XX(cblas_ssyr2k, 2760) \ + XX(cblas_ssyrk, 2761) \ + XX(cblas_stbmv, 2762) \ + XX(cblas_stbsv, 2763) \ + XX(cblas_stpmv, 2764) \ + XX(cblas_stpsv, 2765) \ + XX(cblas_strmm, 2766) \ + XX(cblas_strmv, 2767) \ + XX(cblas_strsm, 2768) \ + XX(cblas_strsv, 2769) \ + XX(cblas_xerbla, 2770) \ + XX(cblas_zaxpby, 2771) \ + XX(cblas_zaxpy, 2772) \ + XX(cblas_zcopy, 2773) \ + XX(cblas_zdotc, 2774) \ + XX(cblas_zdotc_sub, 2775) \ + XX(cblas_zdotu, 2776) \ + XX(cblas_zdotu_sub, 2777) \ + XX(cblas_zdscal, 2778) \ + XX(cblas_zgbmv, 2779) \ + XX(cblas_zgeadd, 2780) \ + XX(cblas_zgemm, 2781) \ + XX(cblas_zgemm3m, 2782) \ + XX(cblas_zgemmt, 2783) \ + XX(cblas_zgemv, 2784) \ + XX(cblas_zgerc, 2785) \ + XX(cblas_zgeru, 2786) \ + XX(cblas_zhbmv, 2787) \ + XX(cblas_zhemm, 2788) \ + XX(cblas_zhemv, 2789) \ + XX(cblas_zher, 2790) \ + XX(cblas_zher2, 2791) \ + XX(cblas_zher2k, 2792) \ + XX(cblas_zherk, 2793) \ + XX(cblas_zhpmv, 2794) \ + XX(cblas_zhpr, 2795) \ + XX(cblas_zhpr2, 2796) \ + XX(cblas_zimatcopy, 2797) \ + XX(cblas_zomatcopy, 2798) \ + XX(cblas_zscal, 2799) \ + XX(cblas_zswap, 2800) \ + XX(cblas_zsymm, 2801) \ + XX(cblas_zsyr2k, 2802) \ + XX(cblas_zsyrk, 2803) \ + XX(cblas_ztbmv, 2804) \ + XX(cblas_ztbsv, 2805) \ + XX(cblas_ztpmv, 2806) \ + XX(cblas_ztpsv, 2807) \ + XX(cblas_ztrmm, 2808) \ + XX(cblas_ztrmv, 2809) \ + XX(cblas_ztrsm, 2810) \ + XX(cblas_ztrsv, 2811) \ + +#endif + #ifndef CBLAS_WORKAROUND_FUNCS #define CBLAS_WORKAROUND_FUNCS(XX) \ XX(cblas_cdotc_sub, 2614) \ diff --git a/src/libblastrampoline.c b/src/libblastrampoline.c index cc6787b..fb0682a 100644 --- a/src/libblastrampoline.c +++ b/src/libblastrampoline.c @@ -237,6 +237,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v if (verbose) { printf(" -> Autodetected symbol suffix \"%s\"\n", lib_suffix); } + char extra_underscore_suffix[MAX_SYMBOL_LEN]; + snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", lib_suffix); // Next, we need to figure out if it's a 32-bit or 64-bit BLAS library; // we'll do that by calling `autodetect_interface()`: @@ -392,6 +394,15 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v void *addr = lookup_symbol(handle, symbol_name); void *self_symbol_addr = interface == LBT_INTERFACE_ILP64 ? exported_func64[symbol_idx] \ : exported_func32[symbol_idx]; + if (addr == NULL ) { + // MKL (and other libraries too in the fullness of time, I have no doubt) doesn't like + // to slap `64` directly onto the end of their symbol names; they insert an extra `_` + // if the symbol is not a FORTRAN symbol (which would already have a `_` at the end) + // We catch this case here, as a fallback check: + build_symbol_name(symbol_name, exported_func_names[symbol_idx], extra_underscore_suffix); + addr = lookup_symbol(handle, symbol_name); + } + if (addr != NULL && addr != self_symbol_addr) { lbt_set_forward_by_index(symbol_idx, addr, interface, complex_retstyle, f2c, verbose); BITFIELD_SET(forwards, symbol_idx); @@ -404,8 +415,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v // the FORTRAN equivalents. if (cblas == LBT_CBLAS_DIVERGENT) { int32_t cblas_symbol_idx = 0; - for (cblas_symbol_idx = 0; cblas_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) { - int32_t symbol_idx = cblas_func_idxs[cblas_symbol_idx]; + for (cblas_symbol_idx = 0; cblas_workaround_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) { + int32_t symbol_idx = cblas_workaround_func_idxs[cblas_symbol_idx]; // Report to the user that we're cblas-wrapping this one if (verbose) { @@ -415,9 +426,9 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v } if (interface == LBT_INTERFACE_LP64) { - (*exported_func32_addrs[symbol_idx]) = cblas32_func_wrappers[cblas_symbol_idx]; + (*exported_func32_addrs[symbol_idx]) = cblas32_workaround_func_wrappers[cblas_symbol_idx]; } else { - (*exported_func64_addrs[symbol_idx]) = cblas64_func_wrappers[cblas_symbol_idx]; + (*exported_func64_addrs[symbol_idx]) = cblas64_workaround_func_wrappers[cblas_symbol_idx]; } } } diff --git a/src/libblastrampoline_cblasdata.h b/src/libblastrampoline_cblasdata.h index a9b6d83..5b2ce72 100644 --- a/src/libblastrampoline_cblasdata.h +++ b/src/libblastrampoline_cblasdata.h @@ -10,20 +10,24 @@ CBLAS_WORKAROUND_FUNCS(XX_64) // locations, allowing a cblas index -> function lookup #define XX(name, index) &lbt_##name, #define XX_64(name, index) &lbt_##name##64_, -const void ** cblas32_func_wrappers[] = { +const void ** cblas32_workaround_func_wrappers[] = { CBLAS_WORKAROUND_FUNCS(XX) NULL }; -const void ** cblas64_func_wrappers[] = { +const void ** cblas64_workaround_func_wrappers[] = { CBLAS_WORKAROUND_FUNCS(XX_64) NULL }; #undef XX #undef XX_64 -// Finally, an array that maps cblas index -> exported symbol index +// Finally, arrays that map cblas index -> exported symbol index #define XX(name, index) index, const int cblas_func_idxs[] = { + CBLAS_FUNCS(XX) + -1 +}; +const int cblas_workaround_func_idxs[] = { CBLAS_WORKAROUND_FUNCS(XX) -1 }; diff --git a/src/mkl_v2022_adapters.c b/src/mkl_v2022_adapters.c deleted file mode 100644 index ed03747..0000000 --- a/src/mkl_v2022_adapters.c +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include -#include "libblastrampoline_internal.h" - -/* - * In MKL v2022.0, a new ILP64 interface was released, but it sadly lacked a few symbol names. - * While waiting for a new release, the following intermediate - */ - -/* -double cblas_ddot_64(const int64_t N, const double *X, const int64_t incX, const double *Y, const int64_t incY) -{ - return ddot_64(&N, X, &incX, Y, &incY); -}*/ - -extern double complex (*mkl_cblas_zdotc_sub_64__addr)(const int64_t N, - const double complex *X, const int64_t incX, - const double complex *Y, const int64_t incY); -void mkl_cblas_zdotc_sub_64_(const int64_t N, - const double complex *X, const int64_t incX, - const double complex *Y, const int64_t incY, - double complex * z) -{ - *z = mkl_cblas_zdotc_sub_64__addr(N, X, incX, Y, incY); -} \ No newline at end of file diff --git a/test/direct.jl b/test/direct.jl index c46c85d..2eb535e 100644 --- a/test/direct.jl +++ b/test/direct.jl @@ -288,7 +288,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64 libs = unpack_loaded_libraries(config) @test length(libs) == 1 @test libs[1].interface == LBT_INTERFACE_ILP64 - @test libs[1].cblas == LBT_CBLAS_DIVERGENT + @test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT) @test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT # Call cblas_zdotc_sub, asserting that it does not try to call a forwardless-symbol @@ -318,7 +318,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64 libs = unpack_loaded_libraries(config) @test length(libs) == 1 @test libs[1].interface == LBT_INTERFACE_ILP64 - @test libs[1].cblas == LBT_CBLAS_DIVERGENT + @test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT) @test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT # Call cblas_cdotc_sub64_ to test the full CBLAS workaround -> complex return style handling chain