Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fallback search for extra-underscore-suffixed symbols #137

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion ext/gensymbol/generate_func_list.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
NUM_COMPLEX128_SYMBOLS="${NUM_SYMBOLS}"

# Emit the CBLAS_FUNCS X-macro: one entry per exported symbol whose name
# starts with `cblas_`.
# Reset the per-section counter so NUM_CBLAS_SYMBOLS below reflects only this section.
NUM_SYMBOLS=0
# Select the `cblas_`-prefixed names out of the exported function list.
CBLAS_FUNCS="$(grep -e '^cblas_.*$' <<< "${EXPORTED_FUNCS}")"
echo >> "${OUTPUT_FILE}"
# Guard the definition so it is skipped when CBLAS_FUNCS is already defined.
echo "#ifndef CBLAS_FUNCS" >> "${OUTPUT_FILE}"
echo "#define CBLAS_FUNCS(XX) \\" >> "${OUTPUT_FILE}"
# Intentionally unquoted: word splitting yields one function name per iteration.
for func_name in ${CBLAS_FUNCS}; do
# NOTE(review): output_func is defined outside this view; presumably it appends
# an `XX(name, index) \` line to OUTPUT_FILE and increments NUM_SYMBOLS -- confirm.
output_func "${func_name}"
done
# The blank line terminates the backslash-continued macro body before the #endif.
echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
# Remember this section's count for the final summary message.
NUM_CBLAS_SYMBOLS="${NUM_SYMBOLS}"

NUM_SYMBOLS=0
# We manually curate a list of cblas functions that we have defined adapters for
# in `src/cblas_adapters.c`. This is our compromise between the crushing workload
Expand All @@ -123,9 +135,10 @@ echo >> "${OUTPUT_FILE}"
echo "#endif" >> "${OUTPUT_FILE}"
NUM_CBLAS_WORKAROUND_SYMBOLS="${NUM_SYMBOLS}"


# Report to the user and cleanup
echo
NUM_F2C_SYMBOLS="$((NUM_FLOAT32_SYMBOLS + NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))"
NUM_CMPLX_SYMBOLS="$((NUM_COMPLEX64_SYMBOLS + NUM_COMPLEX128_SYMBOLS))"
echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)."
echo "Done, with ${NUM_EXPORTED} symbols generated (${NUM_F2C_SYMBOLS} f2c, ${NUM_CMPLX_SYMBOLS} complex-returning, ${NUM_CBLAS_SYMBOLS} cblas, ${NUM_CBLAS_WORKAROUND_SYMBOLS} cblas-workaround functions)."
rm -f tempsymbols.def
9 changes: 9 additions & 0 deletions src/autodetection.c
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) {

#ifdef CBLAS_DIVERGENCE_AUTODETECTION
char symbol_name[MAX_SYMBOL_LEN];
char extra_underscore_suffix[MAX_SYMBOL_LEN];
snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", suffix);

build_symbol_name(symbol_name, "zdotc_", suffix);
if (lookup_symbol(handle, symbol_name) != NULL ) {
Expand All @@ -371,6 +373,13 @@ int32_t autodetect_cblas_divergence(void * handle, const char * suffix) {
return LBT_CBLAS_CONFORMANT;
}

// Do the fallback extra-underscore suffix search here, so we don't mistakenly
// mark MKL v2024 as CBLAS-divergent
build_symbol_name(symbol_name, "cblas_zdotc_sub", extra_underscore_suffix);
if (lookup_symbol(handle, symbol_name) != NULL ) {
return LBT_CBLAS_CONFORMANT;
}

const char * lp64_suffixes[] = {
"", "_", "__",
};
Expand Down
208 changes: 208 additions & 0 deletions src/exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5164,6 +5164,214 @@

#endif

#ifndef CBLAS_FUNCS
#define CBLAS_FUNCS(XX) \
XX(cblas_caxpby, 2609) \
XX(cblas_caxpy, 2610) \
XX(cblas_caxpyc, 2611) \
XX(cblas_ccopy, 2612) \
XX(cblas_cdotc, 2613) \
XX(cblas_cdotc_sub, 2614) \
XX(cblas_cdotu, 2615) \
XX(cblas_cdotu_sub, 2616) \
XX(cblas_cgbmv, 2617) \
XX(cblas_cgeadd, 2618) \
XX(cblas_cgemm, 2619) \
XX(cblas_cgemm3m, 2620) \
XX(cblas_cgemmt, 2621) \
XX(cblas_cgemv, 2622) \
XX(cblas_cgerc, 2623) \
XX(cblas_cgeru, 2624) \
XX(cblas_chbmv, 2625) \
XX(cblas_chemm, 2626) \
XX(cblas_chemv, 2627) \
XX(cblas_cher, 2628) \
XX(cblas_cher2, 2629) \
XX(cblas_cher2k, 2630) \
XX(cblas_cherk, 2631) \
XX(cblas_chpmv, 2632) \
XX(cblas_chpr, 2633) \
XX(cblas_chpr2, 2634) \
XX(cblas_cimatcopy, 2635) \
XX(cblas_comatcopy, 2636) \
XX(cblas_crotg, 2637) \
XX(cblas_cscal, 2638) \
XX(cblas_csrot, 2639) \
XX(cblas_csscal, 2640) \
XX(cblas_cswap, 2641) \
XX(cblas_csymm, 2642) \
XX(cblas_csyr2k, 2643) \
XX(cblas_csyrk, 2644) \
XX(cblas_ctbmv, 2645) \
XX(cblas_ctbsv, 2646) \
XX(cblas_ctpmv, 2647) \
XX(cblas_ctpsv, 2648) \
XX(cblas_ctrmm, 2649) \
XX(cblas_ctrmv, 2650) \
XX(cblas_ctrsm, 2651) \
XX(cblas_ctrsv, 2652) \
XX(cblas_damax, 2653) \
XX(cblas_damin, 2654) \
XX(cblas_dasum, 2655) \
XX(cblas_daxpby, 2656) \
XX(cblas_daxpy, 2657) \
XX(cblas_dbf16tod, 2658) \
XX(cblas_dcopy, 2659) \
XX(cblas_ddot, 2660) \
XX(cblas_dgbmv, 2661) \
XX(cblas_dgeadd, 2662) \
XX(cblas_dgemm, 2663) \
XX(cblas_dgemmt, 2664) \
XX(cblas_dgemv, 2665) \
XX(cblas_dger, 2666) \
XX(cblas_dimatcopy, 2667) \
XX(cblas_dnrm2, 2668) \
XX(cblas_domatcopy, 2669) \
XX(cblas_drot, 2670) \
XX(cblas_drotg, 2671) \
XX(cblas_drotm, 2672) \
XX(cblas_drotmg, 2673) \
XX(cblas_dsbmv, 2674) \
XX(cblas_dscal, 2675) \
XX(cblas_dsdot, 2676) \
XX(cblas_dspmv, 2677) \
XX(cblas_dspr, 2678) \
XX(cblas_dspr2, 2679) \
XX(cblas_dsum, 2680) \
XX(cblas_dswap, 2681) \
XX(cblas_dsymm, 2682) \
XX(cblas_dsymv, 2683) \
XX(cblas_dsyr, 2684) \
XX(cblas_dsyr2, 2685) \
XX(cblas_dsyr2k, 2686) \
XX(cblas_dsyrk, 2687) \
XX(cblas_dtbmv, 2688) \
XX(cblas_dtbsv, 2689) \
XX(cblas_dtpmv, 2690) \
XX(cblas_dtpsv, 2691) \
XX(cblas_dtrmm, 2692) \
XX(cblas_dtrmv, 2693) \
XX(cblas_dtrsm, 2694) \
XX(cblas_dtrsv, 2695) \
XX(cblas_dzamax, 2696) \
XX(cblas_dzamin, 2697) \
XX(cblas_dzasum, 2698) \
XX(cblas_dznrm2, 2699) \
XX(cblas_dzsum, 2700) \
XX(cblas_icamax, 2701) \
XX(cblas_icamin, 2702) \
XX(cblas_icmax, 2703) \
XX(cblas_icmin, 2704) \
XX(cblas_idamax, 2705) \
XX(cblas_idamin, 2706) \
XX(cblas_idmax, 2707) \
XX(cblas_idmin, 2708) \
XX(cblas_isamax, 2709) \
XX(cblas_isamin, 2710) \
XX(cblas_ismax, 2711) \
XX(cblas_ismin, 2712) \
XX(cblas_izamax, 2713) \
XX(cblas_izamin, 2714) \
XX(cblas_izmax, 2715) \
XX(cblas_izmin, 2716) \
XX(cblas_samax, 2717) \
XX(cblas_samin, 2718) \
XX(cblas_sasum, 2719) \
XX(cblas_saxpby, 2720) \
XX(cblas_saxpy, 2721) \
XX(cblas_sbdot, 2722) \
XX(cblas_sbdtobf16, 2723) \
XX(cblas_sbf16tos, 2724) \
XX(cblas_sbgemm, 2725) \
XX(cblas_sbgemv, 2726) \
XX(cblas_sbstobf16, 2727) \
XX(cblas_scamax, 2728) \
XX(cblas_scamin, 2729) \
XX(cblas_scasum, 2730) \
XX(cblas_scnrm2, 2731) \
XX(cblas_scopy, 2732) \
XX(cblas_scsum, 2733) \
XX(cblas_sdot, 2734) \
XX(cblas_sdsdot, 2735) \
XX(cblas_sgbmv, 2736) \
XX(cblas_sgeadd, 2737) \
XX(cblas_sgemm, 2738) \
XX(cblas_sgemmt, 2739) \
XX(cblas_sgemv, 2740) \
XX(cblas_sger, 2741) \
XX(cblas_simatcopy, 2742) \
XX(cblas_snrm2, 2743) \
XX(cblas_somatcopy, 2744) \
XX(cblas_srot, 2745) \
XX(cblas_srotg, 2746) \
XX(cblas_srotm, 2747) \
XX(cblas_srotmg, 2748) \
XX(cblas_ssbmv, 2749) \
XX(cblas_sscal, 2750) \
XX(cblas_sspmv, 2751) \
XX(cblas_sspr, 2752) \
XX(cblas_sspr2, 2753) \
XX(cblas_ssum, 2754) \
XX(cblas_sswap, 2755) \
XX(cblas_ssymm, 2756) \
XX(cblas_ssymv, 2757) \
XX(cblas_ssyr, 2758) \
XX(cblas_ssyr2, 2759) \
XX(cblas_ssyr2k, 2760) \
XX(cblas_ssyrk, 2761) \
XX(cblas_stbmv, 2762) \
XX(cblas_stbsv, 2763) \
XX(cblas_stpmv, 2764) \
XX(cblas_stpsv, 2765) \
XX(cblas_strmm, 2766) \
XX(cblas_strmv, 2767) \
XX(cblas_strsm, 2768) \
XX(cblas_strsv, 2769) \
XX(cblas_xerbla, 2770) \
XX(cblas_zaxpby, 2771) \
XX(cblas_zaxpy, 2772) \
XX(cblas_zcopy, 2773) \
XX(cblas_zdotc, 2774) \
XX(cblas_zdotc_sub, 2775) \
XX(cblas_zdotu, 2776) \
XX(cblas_zdotu_sub, 2777) \
XX(cblas_zdscal, 2778) \
XX(cblas_zgbmv, 2779) \
XX(cblas_zgeadd, 2780) \
XX(cblas_zgemm, 2781) \
XX(cblas_zgemm3m, 2782) \
XX(cblas_zgemmt, 2783) \
XX(cblas_zgemv, 2784) \
XX(cblas_zgerc, 2785) \
XX(cblas_zgeru, 2786) \
XX(cblas_zhbmv, 2787) \
XX(cblas_zhemm, 2788) \
XX(cblas_zhemv, 2789) \
XX(cblas_zher, 2790) \
XX(cblas_zher2, 2791) \
XX(cblas_zher2k, 2792) \
XX(cblas_zherk, 2793) \
XX(cblas_zhpmv, 2794) \
XX(cblas_zhpr, 2795) \
XX(cblas_zhpr2, 2796) \
XX(cblas_zimatcopy, 2797) \
XX(cblas_zomatcopy, 2798) \
XX(cblas_zscal, 2799) \
XX(cblas_zswap, 2800) \
XX(cblas_zsymm, 2801) \
XX(cblas_zsyr2k, 2802) \
XX(cblas_zsyrk, 2803) \
XX(cblas_ztbmv, 2804) \
XX(cblas_ztbsv, 2805) \
XX(cblas_ztpmv, 2806) \
XX(cblas_ztpsv, 2807) \
XX(cblas_ztrmm, 2808) \
XX(cblas_ztrmv, 2809) \
XX(cblas_ztrsm, 2810) \
XX(cblas_ztrsv, 2811) \

#endif

#ifndef CBLAS_WORKAROUND_FUNCS
#define CBLAS_WORKAROUND_FUNCS(XX) \
XX(cblas_cdotc_sub, 2614) \
Expand Down
19 changes: 15 additions & 4 deletions src/libblastrampoline.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
if (verbose) {
printf(" -> Autodetected symbol suffix \"%s\"\n", lib_suffix);
}
char extra_underscore_suffix[MAX_SYMBOL_LEN];
snprintf(extra_underscore_suffix, MAX_SYMBOL_LEN, "_%s", lib_suffix);

// Next, we need to figure out if it's a 32-bit or 64-bit BLAS library;
// we'll do that by calling `autodetect_interface()`:
Expand Down Expand Up @@ -392,6 +394,15 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
void *addr = lookup_symbol(handle, symbol_name);
void *self_symbol_addr = interface == LBT_INTERFACE_ILP64 ? exported_func64[symbol_idx] \
: exported_func32[symbol_idx];
if (addr == NULL ) {
// MKL (and, in the fullness of time, no doubt other libraries too) does not
// append `64` directly to the end of its symbol names; it inserts an extra `_`
// when the symbol is not a FORTRAN symbol (which would already end in `_`).
// We catch this case here, as a fallback check:
build_symbol_name(symbol_name, exported_func_names[symbol_idx], extra_underscore_suffix);
addr = lookup_symbol(handle, symbol_name);
}

if (addr != NULL && addr != self_symbol_addr) {
lbt_set_forward_by_index(symbol_idx, addr, interface, complex_retstyle, f2c, verbose);
BITFIELD_SET(forwards, symbol_idx);
Expand All @@ -404,8 +415,8 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
// the FORTRAN equivalents.
if (cblas == LBT_CBLAS_DIVERGENT) {
int32_t cblas_symbol_idx = 0;
for (cblas_symbol_idx = 0; cblas_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) {
int32_t symbol_idx = cblas_func_idxs[cblas_symbol_idx];
for (cblas_symbol_idx = 0; cblas_workaround_func_idxs[cblas_symbol_idx] != -1; cblas_symbol_idx += 1) {
int32_t symbol_idx = cblas_workaround_func_idxs[cblas_symbol_idx];

// Report to the user that we're cblas-wrapping this one
if (verbose) {
Expand All @@ -415,9 +426,9 @@ LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t v
}

if (interface == LBT_INTERFACE_LP64) {
(*exported_func32_addrs[symbol_idx]) = cblas32_func_wrappers[cblas_symbol_idx];
(*exported_func32_addrs[symbol_idx]) = cblas32_workaround_func_wrappers[cblas_symbol_idx];
} else {
(*exported_func64_addrs[symbol_idx]) = cblas64_func_wrappers[cblas_symbol_idx];
(*exported_func64_addrs[symbol_idx]) = cblas64_workaround_func_wrappers[cblas_symbol_idx];
}
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/libblastrampoline_cblasdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@ CBLAS_WORKAROUND_FUNCS(XX_64)
// locations, allowing a cblas index -> function lookup
#define XX(name, index) &lbt_##name,
#define XX_64(name, index) &lbt_##name##64_,
const void ** cblas32_func_wrappers[] = {
const void ** cblas32_workaround_func_wrappers[] = {
CBLAS_WORKAROUND_FUNCS(XX)
NULL
};
const void ** cblas64_func_wrappers[] = {
const void ** cblas64_workaround_func_wrappers[] = {
CBLAS_WORKAROUND_FUNCS(XX_64)
NULL
};
#undef XX
#undef XX_64

// Finally, an array that maps cblas index -> exported symbol index
// Finally, arrays that map cblas index -> exported symbol index
// XX reduces each (name, index) entry to just its exported-symbol index, so the
// macro lists below expand to comma-separated index lists.
#define XX(name, index) index,
// Exported-symbol index of every entry in CBLAS_FUNCS, terminated by -1.
const int cblas_func_idxs[] = {
CBLAS_FUNCS(XX)
-1
};
// Exported-symbol index of every entry in CBLAS_WORKAROUND_FUNCS, terminated by -1.
const int cblas_workaround_func_idxs[] = {
CBLAS_WORKAROUND_FUNCS(XX)
-1
};
Expand Down
25 changes: 0 additions & 25 deletions src/mkl_v2022_adapters.c

This file was deleted.

4 changes: 2 additions & 2 deletions test/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
libs = unpack_loaded_libraries(config)
@test length(libs) == 1
@test libs[1].interface == LBT_INTERFACE_ILP64
@test libs[1].cblas == LBT_CBLAS_DIVERGENT
@test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT)
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT

# Call cblas_zdotc_sub, asserting that it does not try to call a forwardless-symbol
Expand Down Expand Up @@ -318,7 +318,7 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
libs = unpack_loaded_libraries(config)
@test length(libs) == 1
@test libs[1].interface == LBT_INTERFACE_ILP64
@test libs[1].cblas == LBT_CBLAS_DIVERGENT
@test libs[1].cblas ∈ (LBT_CBLAS_DIVERGENT, LBT_CBLAS_CONFORMANT)
@test libs[1].complex_retstyle == LBT_COMPLEX_RETSTYLE_ARGUMENT

# Call cblas_cdotc_sub64_ to test the full CBLAS workaround -> complex return style handling chain
Expand Down