Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oneMKL] Interface lapack routines #376

Merged
merged 3 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2291,25 +2291,21 @@ extern "C" int64_t onemklZgeqrf_scratchpad_size(syclQueue_t device_queue, int64_

// C-linkage wrapper around oneapi::mkl::lapack::geqrf for single-precision
// complex data. The C `float _Complex` buffers are reinterpreted as
// std::complex<float> before being forwarded on the queue in device_queue->val.
// The value returned by geqrf is handed to __FORCE_MKL_FLUSH__ (defined
// elsewhere) and is not otherwise inspected; this function always returns 0.
extern "C" int onemklCgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float _Complex *tau, float _Complex *scratchpad, int64_t scratchpad_size) {
    using cplx = std::complex<float>;

    // View the C complex buffers through the C++ complex type geqrf is called with.
    auto *a_cplx = reinterpret_cast<cplx *>(a);
    auto *tau_cplx = reinterpret_cast<cplx *>(tau);
    auto *work_cplx = reinterpret_cast<cplx *>(scratchpad);

    auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a_cplx, lda,
                                             tau_cplx, work_cplx, scratchpad_size, {});
    __FORCE_MKL_FLUSH__(status);
    return 0;
}

// C-linkage wrapper around oneapi::mkl::lapack::geqrf for double-precision real data.
// All arguments are forwarded unchanged on the queue held in device_queue->val; the
// trailing `{}` is an empty initializer list (presumably the event-dependency list —
// confirm against the oneMKL geqrf signature).
// The value returned by geqrf is handed to __FORCE_MKL_FLUSH__ (macro defined elsewhere)
// and is not otherwise inspected; this function always returns 0.
extern "C" int onemklDgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *tau, double *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

// C-linkage wrapper around oneapi::mkl::lapack::geqrf for single-precision real data.
// All arguments are forwarded unchanged on the queue held in device_queue->val; the
// trailing `{}` is an empty initializer list (presumably the event-dependency list —
// confirm against the oneMKL geqrf signature).
// The value returned by geqrf is handed to __FORCE_MKL_FLUSH__ (macro defined elsewhere)
// and is not otherwise inspected; this function always returns 0.
extern "C" int onemklSgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *tau, float *scratchpad, int64_t scratchpad_size) {
auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a, lda, tau, scratchpad, scratchpad_size, {});
__FORCE_MKL_FLUSH__(status);
return 0;
}

// C-linkage wrapper around oneapi::mkl::lapack::geqrf for double-precision
// complex data. The C `double _Complex` buffers are reinterpreted as
// std::complex<double> before being forwarded on the queue in device_queue->val.
// The value returned by geqrf is handed to __FORCE_MKL_FLUSH__ (defined
// elsewhere) and is not otherwise inspected; this function always returns 0.
extern "C" int onemklZgeqrf(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda, double _Complex *tau, double _Complex *scratchpad, int64_t scratchpad_size) {
    using zplx = std::complex<double>;

    // View the C complex buffers through the C++ complex type geqrf is called with.
    auto *a_zplx = reinterpret_cast<zplx *>(a);
    auto *tau_zplx = reinterpret_cast<zplx *>(tau);
    auto *work_zplx = reinterpret_cast<zplx *>(scratchpad);

    auto status = oneapi::mkl::lapack::geqrf(device_queue->val, m, n, a_zplx, lda,
                                             tau_zplx, work_zplx, scratchpad_size, {});
    __FORCE_MKL_FLUSH__(status);
    return 0;
}

Expand Down
4 changes: 4 additions & 0 deletions lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ using ..SYCL: syclQueue_t
using GPUArrays

using LinearAlgebra
using LinearAlgebra: checksquare
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag, chkuplo

using SparseArrays

# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16
Expand All @@ -21,6 +24,7 @@ const onemklHalf = Union{Float16,ComplexF16}

include("utils.jl")
include("wrappers_blas.jl")
include("wrappers_lapack.jl")
include("wrappers_sparse.jl")
include("linalg.jl")

Expand Down
120 changes: 115 additions & 5 deletions lib/mkl/wrappers_blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ function symm(side::Char,
end

## syrk
for (fname, elty) in ((:onemklDsyrk,:Float64),
(:onemklSsyrk,:Float32),
(:onemklCsyrk,:ComplexF32),
(:onemklZsyrk,:ComplexF64))
for (fname, elty) in ((:onemklSsyrk, :Float32),
(:onemklDsyrk, :Float64),
(:onemklCsyrk, :ComplexF32),
(:onemklZsyrk, :ComplexF64))
@eval begin
function syrk!(uplo::Char,
trans::Char,
Expand Down Expand Up @@ -703,10 +703,28 @@ for (fname, elty) in ((:onemklSger, :Float32),
end
end

# spr
for (fname, elty) in ((:onemklSspr, :Float32),
                      (:onemklDspr, :Float64))
    @eval begin
        # spr!(uplo, alpha, x, A)
        #
        # Packed symmetric rank-1 update: forwards to the oneMKL routine bound to
        # `fname` for this element type and returns `A`.
        #
        # `A` is the packed storage of an n-by-n matrix, so its length must be
        # exactly n*(n+1)/2, and `x` must have length n.
        #
        # Throws `DimensionMismatch` if `length(A)` is not of the form n*(n+1)/2
        # or if `length(x) != n`.
        function spr!(uplo::Char,
                      alpha::Number,
                      x::oneStridedVector{$elty},
                      A::oneStridedVector{$elty})
            # Recover n with exact integer arithmetic: 8*len + 1 == (2n + 1)^2
            # whenever len == n*(n+1)/2.
            len = length(A)
            n = (isqrt(8 * len + 1) - 1) ÷ 2
            # Reject lengths that are not a valid packed size. Rounding a float
            # estimate can overshoot (len == 5 would give n == 3, which needs 6
            # elements) and let the library run past the end of A.
            n * (n + 1) ÷ 2 == len ||
                throw(DimensionMismatch(string("Packed matrix length ", len,
                                               " is not of the form n*(n+1)/2")))
            length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
            incx = stride(x, 1)
            queue = global_queue(context(x), device(x))
            $fname(sycl_queue(queue), uplo, n, alpha, x, incx, A)
            A
        end
    end
end

#symv
for (fname, elty) in ((:onemklSsymv,:Float32),
(:onemklDsymv,:Float64))
# Note that the complex symv are not BLAS but auiliary functions in LAPACK
# Note that the complex symv are not BLAS but auxiliary functions in LAPACK
@eval begin
function symv!(uplo::Char,
alpha::Number,
Expand Down Expand Up @@ -898,6 +916,67 @@ function gbmv(trans::Char,
gbmv(trans, m, kl, ku, one(T), a, x)
end

# spmv
for (fname, elty) in ((:onemklSspmv, :Float32),
                      (:onemklDspmv, :Float64))
    @eval begin
        # spmv!(uplo, alpha, A, x, beta, y)
        #
        # Packed symmetric matrix-vector product: forwards to the oneMKL routine
        # bound to `fname` for this element type and returns `y`.
        #
        # `A` is the packed storage of an n-by-n matrix, so its length must be
        # exactly n*(n+1)/2; `x` and `y` must both have length n.
        #
        # Throws `DimensionMismatch` if `length(A)` is not of the form n*(n+1)/2
        # or if either vector length differs from n.
        function spmv!(uplo::Char,
                       alpha::Number,
                       A::oneStridedVector{$elty},
                       x::oneStridedVector{$elty},
                       beta::Number,
                       y::oneStridedVector{$elty})
            # Recover n with exact integer arithmetic: 8*len + 1 == (2n + 1)^2
            # whenever len == n*(n+1)/2.
            len = length(A)
            n = (isqrt(8 * len + 1) - 1) ÷ 2
            # Reject lengths that are not a valid packed size. Rounding a float
            # estimate can overshoot (len == 5 would give n == 3, which needs 6
            # elements) and let the library read past the end of A.
            n * (n + 1) ÷ 2 == len ||
                throw(DimensionMismatch(string("Packed matrix length ", len,
                                               " is not of the form n*(n+1)/2")))
            length(x) == n ||
                throw(DimensionMismatch(string("x has length ", length(x),
                                               " but the packed matrix has dimension ", n)))
            length(y) == n ||
                throw(DimensionMismatch(string("y has length ", length(y),
                                               " but the packed matrix has dimension ", n)))
            incx = stride(x, 1)
            incy = stride(y, 1)
            queue = global_queue(context(x), device(x))
            $fname(sycl_queue(queue), uplo, n, alpha, A, x, incx, beta, y, incy)
            y
        end
    end
end

# Out-of-place packed symmetric matrix-vector product with an explicit `alpha`.
# Allocates the result with `similar(x)` and calls `spmv!` with `beta = zero(T)`.
function spmv(uplo::Char, alpha::Number,
              A::oneStridedVector{T}, x::oneStridedVector{T}) where T
    beta = zero(T)
    y = similar(x)
    return spmv!(uplo, alpha, A, x, beta, y)
end

# Out-of-place packed symmetric matrix-vector product with `alpha` defaulting to `one(T)`.
function spmv(uplo::Char, A::oneStridedVector{T}, x::oneStridedVector{T}) where T
    alpha = one(T)
    return spmv(uplo, alpha, A, x)
end

# tbsv, (TB) triangular banded matrix solve
for (fname, elty) in ((:onemklStbsv, :Float32),
                      (:onemklDtbsv, :Float64),
                      (:onemklCtbsv, :ComplexF32),
                      (:onemklZtbsv, :ComplexF64))
    @eval begin
        # tbsv!(uplo, trans, diag, k, A, x)
        #
        # Triangular banded solve: validates the band layout of `A` against `k`
        # and the length of `x`, forwards to the oneMKL routine bound to `fname`,
        # and returns `x`.
        #
        # Throws `DimensionMismatch` when 1+k is outside 1:size(A, 2), when `A`
        # has fewer than 1+k rows, or when `length(x) != size(A, 2)`.
        function tbsv!(uplo::Char,
                       trans::Char,
                       diag::Char,
                       k::Integer,
                       A::oneStridedMatrix{$elty},
                       x::oneStridedVector{$elty})
            nrows, ncols = size(A)
            nbands = 1 + k
            1 <= nbands <= ncols || throw(DimensionMismatch("Incorrect number of bands"))
            nrows >= nbands || throw(DimensionMismatch("Array A has fewer than 1+k rows"))
            ncols == length(x) || throw(DimensionMismatch(""))
            # Leading dimension is clamped to at least 1 before being passed on.
            ld = max(1, stride(A, 2))
            xstep = stride(x, 1)
            q = global_queue(context(x), device(x))
            $fname(sycl_queue(q), uplo, trans, diag, ncols, k, A, ld, x, xstep)
            return x
        end
    end
end
# Out-of-place triangular banded solve: runs `tbsv!` on a copy of `x`, leaving `x` untouched.
function tbsv(uplo::Char, trans::Char, diag::Char, k::Integer,
              A::oneStridedMatrix{T}, x::oneStridedVector{T}) where T
    rhs = copy(x)
    return tbsv!(uplo, trans, diag, k, A, rhs)
end

# tbmv
### tbmv, (TB) triangular banded matrix-vector multiplication
for (fname, elty) in ((:onemklStbmv,:Float32),
Expand Down Expand Up @@ -1150,6 +1229,37 @@ function gemm(transA::Char,
B::oneStridedVecOrMat{T}) where T
gemm(transA, transB, one(T), A, B)
end

## dgmm
for (fname, elty) in ((:onemklSdgmm, :Float32),
                      (:onemklDdgmm, :Float64),
                      (:onemklCdgmm, :ComplexF32),
                      (:onemklZdgmm, :ComplexF64))
    @eval begin
        # dgmm!(mode, A, X, C)
        #
        # Forwards to the oneMKL dgmm routine bound to `fname` for this element
        # type, writing into `C`, and returns `C`.
        #
        # `A` and `C` must have the same size (m, n). `X` must have length m when
        # `mode == 'L'` and length n when `mode == 'R'`; any other `mode` is
        # passed through without a length check here.
        #
        # Throws `DimensionMismatch`, with a message naming the offending
        # operand, when any of these size requirements is violated.
        function dgmm!(mode::Char,
                       A::oneStridedMatrix{$elty},
                       X::oneStridedVector{$elty},
                       C::oneStridedMatrix{$elty})
            m, n = size(C)
            mA, nA = size(A)
            lx = length(X)
            if ((mA != m) || (nA != n))
                throw(DimensionMismatch(string("A has size ", (mA, nA),
                                               " but C has size ", (m, n))))
            end
            if ((mode == 'L') && (lx != m))
                throw(DimensionMismatch(string("X has length ", lx,
                                               " but mode 'L' requires length ", m)))
            end
            if ((mode == 'R') && (lx != n))
                throw(DimensionMismatch(string("X has length ", lx,
                                               " but mode 'R' requires length ", n)))
            end
            # Leading dimensions are clamped to at least 1 before being passed on.
            lda = max(1, stride(A, 2))
            incx = stride(X, 1)
            ldc = max(1, stride(C, 2))
            queue = global_queue(context(A), device(A))
            $fname(sycl_queue(queue), mode, m, n, A, lda, X, incx, C, ldc)
            C
        end
    end
end
# Out-of-place dgmm: allocates an output with the same size as `A` and delegates to `dgmm!`.
function dgmm(mode::Char, A::oneStridedMatrix{T}, X::oneStridedVector{T}) where T
    C = similar(A, size(A))
    return dgmm!(mode, A, X, C)
end

for (fname, elty) in
((:onemklSgemmBatchStrided, Float32),
(:onemklDgemmBatchStrided, Float64),
Expand Down
Loading