Skip to content

Commit

Permalink
[oneMKL] Interface lapack routines
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Mar 20, 2024
1 parent e5f6dc7 commit dba535a
Show file tree
Hide file tree
Showing 11 changed files with 806 additions and 221 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ SpecialFunctions = "1.3, 2"
StaticArrays = "1"
julia = "1.8"
oneAPI_Level_Zero_Loader_jll = "1.9"
oneAPI_Support_jll = "~0.3.1"
oneAPI_Support_jll = "~0.3.2"

[extras]
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
8 changes: 4 additions & 4 deletions deps/generate_interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ function generate_headers(library::String, filename::String, output::String)
occursin("int64_t", header) && (suffix = "_64")
end
header = replace(header, "$(name_routine)(" => "onemkl$(version)$(name_routine)$(suffix)(")
if name_routine ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
if name_routine ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "geqrf")
header = replace(header, "void onemkl" => "int onemkl")
end
if library == "sparse"
Expand Down Expand Up @@ -369,7 +369,7 @@ function generate_cpp(library::String, filename::String, output::String)
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
else
if name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
if name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "geqrf")
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters);\n")
else
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
Expand All @@ -378,8 +378,8 @@ function generate_cpp(library::String, filename::String, output::String)
if occursin("scratchpad_size", name)
write(oneapi_cpp, " return scratchpad_size;\n")
else
(name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
(name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " return 0;\n")
(name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "geqrf")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
(name ∈ ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "geqrf")) && write(oneapi_cpp, " return 0;\n")
end
write(oneapi_cpp, "}")
write(oneapi_cpp, "\n\n")
Expand Down
24 changes: 8 additions & 16 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2138,28 +2138,20 @@ extern "C" int64_t onemklZgeqrf_scratchpad_size(syclQueue_t device_queue, int64_
return scratchpad_size;
}

extern "C" int onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float _Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<float>*>(a), lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(scratchpad), scratchpad_size);
__FORCE_MKL_FLUSH__(status);
return 0;
extern "C" void onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float _Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size) {
oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<float>*>(a), lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(scratchpad), scratchpad_size);
}

extern "C" int onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau, double *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size);
__FORCE_MKL_FLUSH__(status);
return 0;
extern "C" void onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau, double *scratchpad, int64_t scratchpad_size) {
oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size);
}

extern "C" int onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau, float *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size);
__FORCE_MKL_FLUSH__(status);
return 0;
extern "C" void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau, float *scratchpad, int64_t scratchpad_size) {
oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size);
}

extern "C" int onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda, double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<double>*>(a), lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(scratchpad), scratchpad_size);
__FORCE_MKL_FLUSH__(status);
return 0;
extern "C" void onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda, double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size) {
oneapi::mkl::lapack::geqrf(device_queue->val, m, n, reinterpret_cast<std::complex<double>*>(a), lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(scratchpad), scratchpad_size);
}

extern "C" int64_t onemklSgesvd_scratchpad_size(syclQueue_t device_queue, onemklJobsvd jobu, onemklJobsvd jobvt, int64_t m, int64_t n, int64_t lda, int64_t ldu, int64_t ldvt) {
Expand Down
16 changes: 8 additions & 8 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1133,17 +1133,17 @@ int64_t onemklCgeqrf_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_

int64_t onemklZgeqrf_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda);

int onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float
_Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size);
void onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda,
float _Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size);

int onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau,
double *scratchpad, int64_t scratchpad_size);
void onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau,
double *scratchpad, int64_t scratchpad_size);

int onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau,
float *scratchpad, int64_t scratchpad_size);
void onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau,
float *scratchpad, int64_t scratchpad_size);

int onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda,
double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size);
void onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda,
double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size);

int64_t onemklSgesvd_scratchpad_size(syclQueue_t device_queue, onemklJobsvd jobu, onemklJobsvd
jobvt, int64_t m, int64_t n, int64_t lda, int64_t ldu, int64_t
Expand Down
9 changes: 8 additions & 1 deletion lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@ using ..SYCL: syclQueue_t

using GPUArrays

using LinearAlgebra
using LinearAlgebra: checksquare
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo

# Exclude Float16 for now, since many oneMKL functions (e.g. copy, scal) do not take Float16
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
const onemklComplex = Union{ComplexF32,ComplexF64}
const onemklHalf = Union{Float16,ComplexF16}
include("wrappers.jl")

include("utils.jl")
include("wrappers_blas.jl")
include("wrappers_lapack.jl")
include("linalg.jl")

function band(A::StridedArray, kl, ku)
Expand Down
71 changes: 71 additions & 0 deletions lib/mkl/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#
# Auxiliary
#

"""
    convert(::Type{onemklSide}, side::Char)

Translate a LAPACK/BLAS-style side character into the matching `onemklSide` value.

`'L'` maps to `ONEMKL_SIDE_LEFT` and `'R'` maps to `ONEMKL_SIDE_RIGHT`.

# Throws
- `ArgumentError` if `side` is any other character.
"""
function Base.convert(::Type{onemklSide}, side::Char)
    if side == 'L'
        return ONEMKL_SIDE_LEFT
    elseif side == 'R'
        return ONEMKL_SIDE_RIGHT
    else
        # The message names the argument that was actually rejected.
        throw(ArgumentError("Unknown side $side"))
    end
end

"""
    convert(::Type{onemklTranspose}, trans::Char)

Translate a LAPACK/BLAS-style transpose character into the matching
`onemklTranspose` value: `'N'` (no transpose), `'T'` (transpose) or
`'C'` (conjugate transpose).

# Throws
- `ArgumentError` if `trans` is any other character.
"""
function Base.convert(::Type{onemklTranspose}, trans::Char)
    trans == 'N' && return ONEMKL_TRANSPOSE_NONTRANS
    trans == 'T' && return ONEMKL_TRANSPOSE_TRANS
    # NOTE(review): the `ONEMLK_` prefix differs from the `ONEMKL_` prefix used by the
    # sibling constants; presumably it mirrors the spelling in the enum definition.
    # Confirm against the generated bindings before renaming.
    trans == 'C' && return ONEMLK_TRANSPOSE_CONJTRANS
    throw(ArgumentError("Unknown transpose $trans"))
end

"""
    convert(::Type{onemklUplo}, uplo::Char)

Translate a LAPACK/BLAS-style triangle selector into the matching `onemklUplo` value.

`'U'` maps to `ONEMKL_UPLO_UPPER` and `'L'` maps to `ONEMKL_UPLO_LOWER`.

# Throws
- `ArgumentError` if `uplo` is any other character.
"""
function Base.convert(::Type{onemklUplo}, uplo::Char)
    if uplo == 'U'
        return ONEMKL_UPLO_UPPER
    elseif uplo == 'L'
        return ONEMKL_UPLO_LOWER
    else
        # The message names the argument that was actually rejected.
        throw(ArgumentError("Unknown uplo $uplo"))
    end
end

"""
    convert(::Type{onemklDiag}, diag::Char)

Translate a LAPACK/BLAS-style diagonal selector into the matching `onemklDiag` value.

`'N'` maps to `ONEMKL_DIAG_NONUNIT` and `'U'` maps to `ONEMKL_DIAG_UNIT`.

# Throws
- `ArgumentError` if `diag` is any other character.
"""
function Base.convert(::Type{onemklDiag}, diag::Char)
    if diag == 'N'
        return ONEMKL_DIAG_NONUNIT
    elseif diag == 'U'
        return ONEMKL_DIAG_UNIT
    else
        # The message names the argument that was actually rejected.
        throw(ArgumentError("Unknown diag $diag"))
    end
end

"""
    convert(::Type{onemklIndex}, index::Char)

Translate an index-base character into the matching `onemklIndex` value.

`'O'` maps to `ONEMKL_INDEX_ONE` and `'Z'` maps to `ONEMKL_INDEX_ZERO`.

# Throws
- `ArgumentError` if `index` is any other character.
"""
function Base.convert(::Type{onemklIndex}, index::Char)
    index == 'O' && return ONEMKL_INDEX_ONE
    index == 'Z' && return ONEMKL_INDEX_ZERO
    throw(ArgumentError("Unknown index $index"))
end

"""
    convert(::Type{onemklLayout}, index::Char)

Translate a layout character into the matching `onemklLayout` value.

`'R'` maps to `ONEMKL_LAYOUT_ROW` and `'C'` maps to `ONEMKL_LAYOUT_COL`.

# Throws
- `ArgumentError` if `index` is any other character.
"""
function Base.convert(::Type{onemklLayout}, index::Char)
    if index == 'R'
        return ONEMKL_LAYOUT_ROW
    elseif index == 'C'
        return ONEMKL_LAYOUT_COL
    else
        # Interpolate the actual parameter; the previous `$layout` was not a defined
        # name here, so this branch raised UndefVarError instead of ArgumentError.
        throw(ArgumentError("Unknown layout $index"))
    end
end

# Build a `oneArray` whose elements are the pointers of the arrays in `batch`.
# `pointer` is broadcast over the host-side vector first; the resulting vector of
# pointers is then wrapped in a `oneArray` and returned.
@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T}
    return oneArray(pointer.(batch))
end
Loading

0 comments on commit dba535a

Please sign in to comment.