Skip to content

Commit

Permalink
Special-case MKL threading setters/getters
Browse files Browse the repository at this point in the history
This uses the `MKL_Domain_*` functions to get/set the thread count for
the BLAS domain only, so FFT, PARDISO, etc... domains in MKL are not
affected bt `lbt_set_num_threads()`.  It also adds a test to show that
this behavior is reasonable when MKL is loaded.
  • Loading branch information
ViralBShah authored and staticfloat committed Jul 31, 2024
1 parent 5a08240 commit ead47b5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

#define MAX_THREADING_NAMES 32

#ifndef max
#define max(x, y) ((x) > (y) ? (x) : (y))
#endif

// We need to ask MKL to get/set threads for only the BLAS domain, we therefore pass in
// this constant to the relevant threading functions to limit our thread setting.
#define MKL_DOMAIN_BLAS 1 /* From mkl_types.h */

/*
* We provide a flexible thread getter/setter interface here; by calling `lbt_set_num_threads()`
* libblastrampoline will propagate the call through to its loaded libraries as long as the
Expand All @@ -10,15 +18,17 @@
*/
static char * getter_names[MAX_THREADING_NAMES] = {
"openblas_get_num_threads",
"MKL_Get_Max_Threads",
"bli_thread_get_num_threads",
// We special-case MKL in the lookup loop below
//"MKL_Domain_Get_Max_Threads",
NULL
};

static char * setter_names[MAX_THREADING_NAMES] = {
"openblas_set_num_threads",
"MKL_Set_Num_Threads",
"bli_thread_set_num_threads",
// We special-case MKL in the lookup loop below
//"MKL_Domain_Set_Num_Threads",
NULL
};

Expand Down Expand Up @@ -62,9 +72,16 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
int (*fptr)() = lookup_symbol(lib->handle, symbol_name);
if (fptr != NULL) {
int new_threads = fptr();
max_threads = max_threads > new_threads ? max_threads : new_threads;
max_threads = max(max_threads, new_threads);
}
}

// Special-case MKL, as we need to specifically ask for the "BLAS" domain
int (*fptr)(int) = lookup_symbol(lib->handle, "MKL_Domain_Get_Max_Threads");
if (fptr != NULL) {
int new_threads = fptr(MKL_DOMAIN_BLAS);
max_threads = max(max_threads, new_threads);
}
}
return max_threads;
}
Expand All @@ -76,15 +93,21 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
*/
LBT_DLLEXPORT void lbt_set_num_threads(int32_t nthreads) {
const lbt_config_t * config = lbt_get_config();
char symbol_name[MAX_SYMBOL_LEN];
for (int lib_idx=0; config->loaded_libs[lib_idx] != NULL; ++lib_idx) {
lbt_library_info_t * lib = config->loaded_libs[lib_idx];
for (int symbol_idx=0; setter_names[symbol_idx] != NULL; ++symbol_idx) {
char symbol_name[MAX_SYMBOL_LEN];
build_symbol_name(symbol_name, setter_names[symbol_idx], lib->suffix);
void (*fptr)(int) = lookup_symbol(lib->handle, symbol_name);
if (fptr != NULL) {
fptr(nthreads);
}
}

// Special-case MKL, as we need to specifically ask for the "BLAS" domain
int (*fptr)(int, int) = lookup_symbol(lib->handle, "MKL_Domain_Set_Num_Threads");
if (fptr != NULL) {
fptr(nthreads, MKL_DOMAIN_BLAS);
}
}
}
12 changes: 12 additions & 0 deletions test/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,16 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
@test result[1] ComplexF32(1.47 + 3.83im)
@test isempty(stacktraces)
end

@testset "MKL threading domains" begin
nthreads = lbt_get_num_threads(lbt_handle)
if nthreads <= 1
nthreads = 2
else
nthreads = div(nthreads, 2)
end
lbt_set_num_threads(lbt_handle, nthreads)
@test ccall((:MKL_Domain_Get_Max_Threads, libmkl_rt), Cint, (Cint,), 1) == nthreads
@test ccall((:MKL_Domain_Get_Max_Threads, libmkl_rt), Cint, (Cint,), 2) != nthreads
end
end

0 comments on commit ead47b5

Please sign in to comment.