Skip to content

Commit

Permalink
Add fallback search for extra-underscore-suffixed symbols
Browse files Browse the repository at this point in the history
MKL v2024 has ILP64-suffixed symbols, but the suffix applied to FORTRAN
symbols (`64`, mapping `dgemm_` to `dgemm_64`) is different from the
suffix applied to non-FORTRAN symbols (`_64`, mapping `cblas_dgemm` to
`cblas_dgemm_64`).  This wreaks havoc with LBT, which assumes that all
symbols undergo the same name transformation.  To work around this, we
add a fallback path to our symbol forwarding routine; if a symbol does
not exist in a library, we look for the same symbol but with an extra
underscore inserted between the symbol name and the suffix (e.g.
`cblas_dgemm_64` rather than `cblas_dgemm64`). Because this all happens only
after we have already found `isamax_` and `dpotrf_` in
`autodetect_symbol_suffix()`, we have some measure of assurance that we
will not be blindly guessing ridiculous symbol names.
  • Loading branch information
staticfloat committed Jul 31, 2024
1 parent eaf1c3a commit 53735ae
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 35 deletions.
15 changes: 14 additions & 1 deletion ext/gensymbol/generate_func_list.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
NUM_COMPLEX128_SYMBOLS="${NUM_SYMBOLS}"

# Emit the `CBLAS_FUNCS(XX)` X-macro into OUTPUT_FILE: one entry for every name
# in EXPORTED_FUNCS that starts with `cblas_`, wrapped in an `#ifndef` guard.
NUM_SYMBOLS=0
# Select only the lines of EXPORTED_FUNCS that begin with `cblas_`
CBLAS_FUNCS="$(grep -e '^cblas_.*$' <<< "${EXPORTED_FUNCS}")"
echo >> "${OUTPUT_FILE}"
echo "#ifndef CBLAS_FUNCS" >> "${OUTPUT_FILE}"
echo "#define CBLAS_FUNCS(XX) \\" >> "${OUTPUT_FILE}"
# Word-splitting of the unquoted expansion yields one function name per iteration
for func_name in ${CBLAS_FUNCS}; do
output_func "${func_name}"
done
echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
# NOTE(review): presumably `output_func` increments NUM_SYMBOLS once per emitted
# entry, making this the count of cblas symbols -- confirm against its definition.
NUM_CBLAS_SYMBOLS="${NUM_SYMBOLS}"

NUM_SYMBOLS=0
# We manually curate a list of cblas functions that we have defined adapters for
# in `src/cblas_adapters.c`. This is our compromise between the crushing workload
Expand All @@ -123,9 +135,10 @@ echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
NUM_CBLAS_WORKAROUND_SYMBOLS="${NUM_SYMBOLS}"


# Report to the user and cleanup
echo
NUM_F2C_SYMBOLS="$((NUM_FLOAT32_SYMBOLS + NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))"
NUM_CMPLX_SYMBOLS="$((NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))"
echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)."
echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_SYMBOLS} cblas, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)."
rm -f tempsymbols.def
9 changes: 9 additions & 0 deletions src/autodetection.c
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) {

#ifdef CBLAS_DIVERGENCE_AUTODETECTION
char symbol_name[MAX_SYMBOL_LEN];
char extra_underscore_suffix[MAX_SYMBOL_LEN];
snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", suffix);

build_symbol_name(symbol_name, "zdotc_", suffix);
if (lookup_symbol(handle, symbol_name) != NULL ) {
Expand All @@ -371,6 +373,13 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) {
return LBT_CBLAS_CONFORMANT;
}

// Do the fallback extra-underscore suffix search here, so we don't mistakenly
// mark MKL v2024 as CBLAS-divergent
build_symbol_name(symbol_name, "cblas_zdotc_sub", extra_underscore_suffix);
if (lookup_symbol(handle, symbol_name) != NULL ) {
return LBT_CBLAS_CONFORMANT;
}

const char * lp64_suffixes[] = {
"", "_", "__",
};
Expand Down
208 changes: 208 additions & 0 deletions src/exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5164,6 +5164,214 @@

#endif

#ifndef CBLAS_FUNCS
#define CBLAS_FUNCS(XX) \
XX(cblas_caxpby, 2609) \
XX(cblas_caxpy, 2610) \
XX(cblas_caxpyc, 2611) \
XX(cblas_ccopy, 2612) \
XX(cblas_cdotc, 2613) \
XX(cblas_cdotc_sub, 2614) \
XX(cblas_cdotu, 2615) \
XX(cblas_cdotu_sub, 2616) \
XX(cblas_cgbmv, 2617) \
XX(cblas_cgeadd, 2618) \
XX(cblas_cgemm, 2619) \
XX(cblas_cgemm3m, 2620) \
XX(cblas_cgemmt, 2621) \
XX(cblas_cgemv, 2622) \
XX(cblas_cgerc, 2623) \
XX(cblas_cgeru, 2624) \
XX(cblas_chbmv, 2625) \
XX(cblas_chemm, 2626) \
XX(cblas_chemv, 2627) \
XX(cblas_cher, 2628) \
XX(cblas_cher2, 2629) \
XX(cblas_cher2k, 2630) \
XX(cblas_cherk, 2631) \
XX(cblas_chpmv, 2632) \
XX(cblas_chpr, 2633) \
XX(cblas_chpr2, 2634) \
XX(cblas_cimatcopy, 2635) \
XX(cblas_comatcopy, 2636) \
XX(cblas_crotg, 2637) \
XX(cblas_cscal, 2638) \
XX(cblas_csrot, 2639) \
XX(cblas_csscal, 2640) \
XX(cblas_cswap, 2641) \
XX(cblas_csymm, 2642) \
XX(cblas_csyr2k, 2643) \
XX(cblas_csyrk, 2644) \
XX(cblas_ctbmv, 2645) \
XX(cblas_ctbsv, 2646) \
XX(cblas_ctpmv, 2647) \
XX(cblas_ctpsv, 2648) \
XX(cblas_ctrmm, 2649) \
XX(cblas_ctrmv, 2650) \
XX(cblas_ctrsm, 2651) \
XX(cblas_ctrsv, 2652) \
XX(cblas_damax, 2653) \
XX(cblas_damin, 2654) \
XX(cblas_dasum, 2655) \
XX(cblas_daxpby, 2656) \
XX(cblas_daxpy, 2657) \
XX(cblas_dbf16tod, 2658) \
XX(cblas_dcopy, 2659) \
XX(cblas_ddot, 2660) \
XX(cblas_dgbmv, 2661) \
XX(cblas_dgeadd, 2662) \
XX(cblas_dgemm, 2663) \
XX(cblas_dgemmt, 2664) \
XX(cblas_dgemv, 2665) \
XX(cblas_dger, 2666) \
XX(cblas_dimatcopy, 2667) \
XX(cblas_dnrm2, 2668) \
XX(cblas_domatcopy, 2669) \
XX(cblas_drot, 2670) \
XX(cblas_drotg, 2671) \
XX(cblas_drotm, 2672) \
XX(cblas_drotmg, 2673) \
XX(cblas_dsbmv, 2674) \
XX(cblas_dscal, 2675) \
XX(cblas_dsdot, 2676) \
XX(cblas_dspmv, 2677) \
XX(cblas_dspr, 2678) \
XX(cblas_dspr2, 2679) \
XX(cblas_dsum, 2680) \
XX(cblas_dswap, 2681) \
XX(cblas_dsymm, 2682) \
XX(cblas_dsymv, 2683) \
XX(cblas_dsyr, 2684) \
XX(cblas_dsyr2, 2685) \
XX(cblas_dsyr2k, 2686) \
XX(cblas_dsyrk, 2687) \
XX(cblas_dtbmv, 2688) \
XX(cblas_dtbsv, 2689) \
XX(cblas_dtpmv, 2690) \
XX(cblas_dtpsv, 2691) \
XX(cblas_dtrmm, 2692) \
XX(cblas_dtrmv, 2693) \
XX(cblas_dtrsm, 2694) \
XX(cblas_dtrsv, 2695) \
XX(cblas_dzamax, 2696) \
XX(cblas_dzamin, 2697) \
XX(cblas_dzasum, 2698) \
XX(cblas_dznrm2, 2699) \
XX(cblas_dzsum, 2700) \
XX(cblas_icamax, 2701) \
XX(cblas_icamin, 2702) \
XX(cblas_icmax, 2703) \
XX(cblas_icmin, 2704) \
XX(cblas_idamax, 2705) \
XX(cblas_idamin, 2706) \
XX(cblas_idmax, 2707) \
XX(cblas_idmin, 2708) \
XX(cblas_isamax, 2709) \
XX(cblas_isamin, 2710) \
XX(cblas_ismax, 2711) \
XX(cblas_ismin, 2712) \
XX(cblas_izamax, 2713) \
XX(cblas_izamin, 2714) \
XX(cblas_izmax, 2715) \
XX(cblas_izmin, 2716) \
XX(cblas_samax, 2717) \
XX(cblas_samin, 2718) \
XX(cblas_sasum, 2719) \
XX(cblas_saxpby, 2720) \
XX(cblas_saxpy, 2721) \
XX(cblas_sbdot, 2722) \
XX(cblas_sbdtobf16, 2723) \
XX(cblas_sbf16tos, 2724) \
XX(cblas_sbgemm, 2725) \
XX(cblas_sbgemv, 2726) \
XX(cblas_sbstobf16, 2727) \
XX(cblas_scamax, 2728) \
XX(cblas_scamin, 2729) \
XX(cblas_scasum, 2730) \
XX(cblas_scnrm2, 2731) \
XX(cblas_scopy, 2732) \
XX(cblas_scsum, 2733) \
XX(cblas_sdot, 2734) \
XX(cblas_sdsdot, 2735) \
XX(cblas_sgbmv, 2736) \
XX(cblas_sgeadd, 2737) \
XX(cblas_sgemm, 2738) \
XX(cblas_sgemmt, 2739) \
XX(cblas_sgemv, 2740) \
XX(cblas_sger, 2741) \
XX(cblas_simatcopy, 2742) \
XX(cblas_snrm2, 2743) \
XX(cblas_somatcopy, 2744) \
XX(cblas_srot, 2745) \
XX(cblas_srotg, 2746) \
XX(cblas_srotm, 2747) \
XX(cblas_srotmg, 2748) \
XX(cblas_ssbmv, 2749) \
XX(cblas_sscal, 2750) \
XX(cblas_sspmv, 2751) \
XX(cblas_sspr, 2752) \
XX(cblas_sspr2, 2753) \
XX(cblas_ssum, 2754) \
XX(cblas_sswap, 2755) \
XX(cblas_ssymm, 2756) \
XX(cblas_ssymv, 2757) \
XX(cblas_ssyr, 2758) \
XX(cblas_ssyr2, 2759) \
XX(cblas_ssyr2k, 2760) \
XX(cblas_ssyrk, 2761) \
XX(cblas_stbmv, 2762) \
XX(cblas_stbsv, 2763) \
XX(cblas_stpmv, 2764) \
XX(cblas_stpsv, 2765) \
XX(cblas_strmm, 2766) \
XX(cblas_strmv, 2767) \
XX(cblas_strsm, 2768) \
XX(cblas_strsv, 2769) \
XX(cblas_xerbla, 2770) \
XX(cblas_zaxpby, 2771) \
XX(cblas_zaxpy, 2772) \
XX(cblas_zcopy, 2773) \
XX(cblas_zdotc, 2774) \
XX(cblas_zdotc_sub, 2775) \
XX(cblas_zdotu, 2776) \
XX(cblas_zdotu_sub, 2777) \
XX(cblas_zdscal, 2778) \
XX(cblas_zgbmv, 2779) \
XX(cblas_zgeadd, 2780) \
XX(cblas_zgemm, 2781) \
XX(cblas_zgemm3m, 2782) \
XX(cblas_zgemmt, 2783) \
XX(cblas_zgemv, 2784) \
XX(cblas_zgerc, 2785) \
XX(cblas_zgeru, 2786) \
XX(cblas_zhbmv, 2787) \
XX(cblas_zhemm, 2788) \
XX(cblas_zhemv, 2789) \
XX(cblas_zher, 2790) \
XX(cblas_zher2, 2791) \
XX(cblas_zher2k, 2792) \
XX(cblas_zherk, 2793) \
XX(cblas_zhpmv, 2794) \
XX(cblas_zhpr, 2795) \
XX(cblas_zhpr2, 2796) \
XX(cblas_zimatcopy, 2797) \
XX(cblas_zomatcopy, 2798) \
XX(cblas_zscal, 2799) \
XX(cblas_zswap, 2800) \
XX(cblas_zsymm, 2801) \
XX(cblas_zsyr2k, 2802) \
XX(cblas_zsyrk, 2803) \
XX(cblas_ztbmv, 2804) \
XX(cblas_ztbsv, 2805) \
XX(cblas_ztpmv, 2806) \
XX(cblas_ztpsv, 2807) \
XX(cblas_ztrmm, 2808) \
XX(cblas_ztrmv, 2809) \
XX(cblas_ztrsm, 2810) \
XX(cblas_ztrsv, 2811) \

#endif

#ifndef CBLAS_WORKAROUND_FUNCS
#define CBLAS_WORKAROUND_FUNCS(XX) \
XX(cblas_cdotc_sub, 2614) \
Expand Down
19 changes: 15 additions & 4 deletions src/libblastrampoline.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
if (verbose) {
printf(" -> Autodetected symbol suffix \"%s\"\n", lib_suffix);
}
char extra_underscore_suffix[MAX_SYMBOL_LEN];
snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", lib_suffix);

// Next, we need to figure out if it's a 32-bit or 64-bit BLAS library;
// we'll do that by calling `autodetect_interface()`:
Expand Down Expand Up @@ -392,6 +394,15 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
void *addr = lookup_symbol(handle, symbol_name);
void *self_symbol_addr = interface == LBT_INTERFACE_ILP64 ? exported_func64[symbol_idx] \
: exported_func32[symbol_idx];
if (addr == NULL ) {
// MKL (and other libraries too in the fullness of time, I have no doubt) doesn't like
// to slap `64` directly onto the end of its symbol names; it inserts an extra `_`
// if the symbol is not a FORTRAN symbol (which would already have a `_` at the end).
// We catch this case here, as a fallback check:
build_symbol_name(symbol_name, exported_func_names[symbol_idx], extra_underscore_suffix);
addr = lookup_symbol(handle, symbol_name);
}

if (addr != NULL && addr != self_symbol_addr) {
lbt_set_forward_by_index(symbol_idx, addr, interface, complex_retstyle, f2c, verbose);
BITFIELD_SET(forwards, symbol_idx);
Expand All @@ -404,8 +415,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
// the FORTRAN equivalents.
if (cblas == LBT_CBLAS_DIVERGENT) {
int32_t cblas_symbol_idx = 0;
for (cblas_symbol_idx = 0; cblas_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) {
int32_t symbol_idx = cblas_func_idxs[cblas_symbol_idx];
for (cblas_symbol_idx = 0; cblas_workaround_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) {
int32_t symbol_idx = cblas_workaround_func_idxs[cblas_symbol_idx];

// Report to the user that we're cblas-wrapping this one
if (verbose) {
Expand All @@ -415,9 +426,9 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
}

if (interface == LBT_INTERFACE_LP64) {
(*exported_func32_addrs[symbol_idx]) = cblas32_func_wrappers[cblas_symbol_idx];
(*exported_func32_addrs[symbol_idx]) = cblas32_workaround_func_wrappers[cblas_symbol_idx];
} else {
(*exported_func64_addrs[symbol_idx]) = cblas64_func_wrappers[cblas_symbol_idx];
(*exported_func64_addrs[symbol_idx]) = cblas64_workaround_func_wrappers[cblas_symbol_idx];
}
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/libblastrampoline_cblasdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@ CBLAS_WORKAROUND_FUNCS(XX_64)
// locations, allowing a cblas index -> function lookup
#define XX(name, index) &lbt_##name,
#define XX_64(name, index) &lbt_##name##64_,
const void ** cblas32_func_wrappers[] = {
const void ** cblas32_workaround_func_wrappers[] = {
CBLAS_WORKAROUND_FUNCS(XX)
NULL
};
const void ** cblas64_func_wrappers[] = {
const void ** cblas64_workaround_func_wrappers[] = {
CBLAS_WORKAROUND_FUNCS(XX_64)
NULL
};
#undef XX
#undef XX_64

// Finally, an array that maps cblas index -> exported symbol index
// Finally, arrays that map cblas index -> exported symbol index
#define XX(name, index) index,
const int cblas_func_idxs[] = {
CBLAS_FUNCS(XX)
-1
};
const int cblas_workaround_func_idxs[] = {
CBLAS_WORKAROUND_FUNCS(XX)
-1
};
Expand Down
25 changes: 0 additions & 25 deletions src/mkl_v2022_adapters.c

This file was deleted.

4 changes: 2 additions & 2 deletions test/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
libs = unpack_loaded_libraries(config)
@test length(libs) == 1
@test libs[1].interface == LBT_INTERFACE_ILP64
@test libs[1].cblas == LBT_CBLAS_DIVERGENT
@test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT)
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT

# Call cblas_zdotc_sub, asserting that it does not try to call a forwardless-symbol
Expand Down Expand Up @@ -318,7 +318,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
libs = unpack_loaded_libraries(config)
@test length(libs) == 1
@test libs[1].interface == LBT_INTERFACE_ILP64
@test libs[1].cblas == LBT_CBLAS_DIVERGENT
@test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT)
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT

# Call cblas_cdotc_sub64_ to test the full CBLAS workaround -> complex return style handling chain
Expand Down

0 comments on commit 53735ae

Please sign in to comment.