Skip to content

Commit

Permalink
enable all tests and implement c/z for getri
Browse files Browse the repository at this point in the history
  • Loading branch information
kballeda authored and amontoison committed Oct 18, 2023
1 parent 08be651 commit db9359e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 38 deletions.
48 changes: 44 additions & 4 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,7 @@ extern "C" void onemklSgetri(syclQueue_t device_queue, int64_t n, float *a, int6
auto getri_status = oneapi::mkl::lapack::getri(device_queue->val, n, a, lda, ipiv, getri_scratchpad_dev,
getri_scratchpad_size);
__FORCE_MKL_FLUSH__(getri_status);
//free(getrf_scratchpad_dev);
//free(getri_scratchpad_dev);
free(getri_scratchpad_dev, context);
}

extern "C" void onemklDgetri(syclQueue_t device_queue, int64_t n, double *a, int64_t lda, int64_t *ipiv) {
Expand All @@ -458,8 +457,49 @@ extern "C" void onemklDgetri(syclQueue_t device_queue, int64_t n, double *a, int
auto getri_status = oneapi::mkl::lapack::getri(device_queue->val, n, a, lda, ipiv, getri_scratchpad_dev,
getri_scratchpad_size);
__FORCE_MKL_FLUSH__(getri_status);
//free(getrf_scratchpad_dev);
//free(getri_scratchpad_dev);
free(getri_scratchpad_dev, context);
}

// Compute the in-place inverse of the n-by-n single-precision complex matrix
// `a` (leading dimension `lda`) on the queue wrapped by `device_queue`.
//
// The matrix is first LU-factorized with getrf, which writes the pivot indices
// into `ipiv`; getri then consumes that factorization and `ipiv`, operating on
// `a` in place. Both scratchpads are allocated with malloc_device and released
// before returning.
extern "C" void onemklCgetri(syclQueue_t device_queue, int64_t n, float _Complex *a, int64_t lda, int64_t *ipiv) {
    auto main_queue = device_queue->val;
    auto device = main_queue.get_device();
    auto context = main_queue.get_context();

    // The C-facing `float _Complex` buffer is handed to oneMKL as std::complex<float>.
    auto a_cplx = reinterpret_cast<std::complex<float> *>(a);

    // Query the workspace sizes (in elements) required by each LAPACK routine.
    int64_t getri_scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size<std::complex<float> >(device_queue->val,
                                                                                                    n, lda);
    int64_t getrf_scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size<std::complex<float> >(device_queue->val,
                                                                                                    n, n, lda);

    auto getrf_scratchpad_dev = (std::complex<float> *) malloc_device(getrf_scratchpad_size * sizeof(std::complex<float>),
                                                                      device, context);
    auto getri_scratchpad_dev = (std::complex<float> *) malloc_device(getri_scratchpad_size * sizeof(std::complex<float>),
                                                                      device, context);

    // Step 1: LU factorization of `a`.
    auto getrf_status = oneapi::mkl::lapack::getrf(device_queue->val, n, n, a_cplx,
                                                   lda, ipiv, getrf_scratchpad_dev, getrf_scratchpad_size);
    __FORCE_MKL_FLUSH__(getrf_status);

    // Step 2: inverse computed from the LU factors.
    auto getri_status = oneapi::mkl::lapack::getri(device_queue->val, n, a_cplx,
                                                   lda, ipiv, getri_scratchpad_dev, getri_scratchpad_size);
    __FORCE_MKL_FLUSH__(getri_status);

    // Release BOTH scratchpads; the getrf workspace was previously leaked on
    // every call.
    // NOTE(review): this relies on the same completion guarantee the existing
    // getri free already relied on - confirm __FORCE_MKL_FLUSH__/free ordering.
    free(getrf_scratchpad_dev, context);
    free(getri_scratchpad_dev, context);
}

// Compute the in-place inverse of the n-by-n double-precision complex matrix
// `a` (leading dimension `lda`) on the queue wrapped by `device_queue`.
//
// The matrix is first LU-factorized with getrf, which writes the pivot indices
// into `ipiv`; getri then consumes that factorization and `ipiv`, operating on
// `a` in place. Both scratchpads are allocated with malloc_device and released
// before returning.
extern "C" void onemklZgetri(syclQueue_t device_queue, int64_t n, double _Complex *a, int64_t lda, int64_t *ipiv) {
    auto main_queue = device_queue->val;
    auto device = main_queue.get_device();
    auto context = main_queue.get_context();

    // The C-facing `double _Complex` buffer is handed to oneMKL as std::complex<double>.
    auto a_cplx = reinterpret_cast<std::complex<double> *>(a);

    // Query the workspace sizes (in elements) required by each LAPACK routine.
    int64_t getri_scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size<std::complex<double> >(device_queue->val,
                                                                                                     n, lda);
    int64_t getrf_scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size<std::complex<double> >(device_queue->val,
                                                                                                     n, n, lda);

    auto getrf_scratchpad_dev = (std::complex<double> *) malloc_device(getrf_scratchpad_size * sizeof(std::complex<double>),
                                                                       device, context);
    auto getri_scratchpad_dev = (std::complex<double> *) malloc_device(getri_scratchpad_size * sizeof(std::complex<double>),
                                                                       device, context);

    // Step 1: LU factorization of `a`.
    auto getrf_status = oneapi::mkl::lapack::getrf(device_queue->val, n, n, a_cplx,
                                                   lda, ipiv, getrf_scratchpad_dev, getrf_scratchpad_size);
    __FORCE_MKL_FLUSH__(getrf_status);

    // Step 2: inverse computed from the LU factors.
    auto getri_status = oneapi::mkl::lapack::getri(device_queue->val, n, a_cplx,
                                                   lda, ipiv, getri_scratchpad_dev, getri_scratchpad_size);
    __FORCE_MKL_FLUSH__(getri_status);

    // Release BOTH scratchpads; the getrf workspace was previously leaked on
    // every call.
    // NOTE(review): this relies on the same completion guarantee the existing
    // getri free already relied on - confirm __FORCE_MKL_FLUSH__/free ordering.
    free(getrf_scratchpad_dev, context);
    free(getri_scratchpad_dev, context);
}

extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
Expand Down
4 changes: 4 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ void onemklSgetri(syclQueue_t device_queue, int64_t n,
float *a, int64_t lda, int64_t *ipiv);
void onemklDgetri(syclQueue_t device_queue, int64_t n,
double *a, int64_t lda, int64_t *ipiv);
// Single-precision complex getri: LU-factorizes the n-by-n matrix `a`
// (leading dimension `lda`) with getrf and then runs getri on it in place.
// `ipiv` is the pivot-index buffer shared by both steps.
void onemklCgetri(syclQueue_t device_queue, int64_t n,
float _Complex *a, int64_t lda, int64_t *ipiv);
// Double-precision complex getri: LU-factorizes the n-by-n matrix `a`
// (leading dimension `lda`) with getrf and then runs getri on it in place.
// `ipiv` is the pivot-index buffer shared by both steps.
void onemklZgetri(syclQueue_t device_queue, int64_t n,
double _Complex *a, int64_t lda, int64_t *ipiv);

// XXX: how to expose half in C?
// int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
Expand Down
4 changes: 3 additions & 1 deletion lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ end

for (fname, elty) in
((:onemklSgetri, :Float32),
(:onemklDgetri, :Float64))
(:onemklDgetri, :Float64),
(:onemklCgetri, :ComplexF32),
(:onemklZgetri, :ComplexF64))
@eval begin
function getri!(n::Number,
a::oneStridedVecOrMat{$elty})
Expand Down
10 changes: 10 additions & 0 deletions lib/support/liboneapi_support.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ function onemklDgetri(device_queue, n, a, lda, ipiv)
a::ZePtr{Cdouble}, lda::Int64, ipiv::ZePtr{Int64})::Cvoid
end

"""
    onemklCgetri(device_queue, n, a, lda, ipiv)

Low-level binding for the `onemklCgetri` symbol of `liboneapi_support`.

The arguments are forwarded unchanged to the C entry point as
`(syclQueue_t, Int64, ZePtr{ComplexF32}, Int64, ZePtr{Int64})`; the call
returns `nothing`.
"""
function onemklCgetri(device_queue, n, a, lda, ipiv)
    return ccall((:onemklCgetri, liboneapi_support), Cvoid,
                 (syclQueue_t, Int64, ZePtr{ComplexF32}, Int64, ZePtr{Int64}),
                 device_queue, n, a, lda, ipiv)
end

"""
    onemklZgetri(device_queue, n, a, lda, ipiv)

Low-level binding for the `onemklZgetri` symbol of `liboneapi_support`.

The arguments are forwarded unchanged to the C entry point as
`(syclQueue_t, Int64, ZePtr{ComplexF64}, Int64, ZePtr{Int64})`; the call
returns `nothing`.
"""
function onemklZgetri(device_queue, n, a, lda, ipiv)
    return ccall((:onemklZgetri, liboneapi_support), Cvoid,
                 (syclQueue_t, Int64, ZePtr{ComplexF64}, Int64, ZePtr{Int64}),
                 device_queue, n, a, lda, ipiv)
end

function onemklSgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C,
ldc)
@ccall liboneapi_support.onemklSgemm(device_queue::syclQueue_t, transA::onemklTranspose,
Expand Down
60 changes: 27 additions & 33 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ end

@testset "Blas-Extension" begin
@testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
#=

@testset "geqrf" begin
A = rand(T, m, n)
d_A = oneArray(A)
Expand All @@ -1091,18 +1091,15 @@ end
hA, ipiv = LinearAlgebra.LAPACK.getrf!(A)
@test hA ≈ Array(d_A)
end
=#
if T <: Float32

@testset "getri" begin
    # Invert a random square matrix through oneMKL and compare the device
    # buffer against the inverse computed on the host.
    n = 8
    A = rand(T, n, n)
    d_A = oneArray(A)
    # The test expects `d_A` to hold the result after the call; the returned
    # values are not checked.
    daout, ipiv = oneMKL.getri!(n, d_A)
    hC = inv(A)
    # Loose tolerance, kept from the original test.
    @test hC ≈ Array(d_A) rtol=1e-2
end
end

#=
@testset "gelsBatched" begin
Expand All @@ -1118,38 +1115,35 @@ end
end
d_A, d_C = oneMKL.gels_batched!('N', d_A, d_C)
end
=#
@testset "dgmm_batch" begin
group_count = 10
# generate matrices
bA = [rand(T, m, n) for i in 1:group_count]
bC = [rand(T, m, n) for i in 1:group_count]
bX = [rand(T, m) for i in 1:group_count]

if T <: Union{Float32, ComplexF32, ComplexF64}
@testset "dgmm_batch" begin
group_count = 10
# generate matrices
bA = [rand(T, m, n) for i in 1:group_count]
bC = [rand(T, m, n) for i in 1:group_count]
bX = [rand(T, m) for i in 1:group_count]
# move to device
bd_A = oneArray{T, 2}[]
bd_C = oneArray{T, 2}[]
bd_X = oneArray{T, 1}[]
bd_bad = oneArray{T, 2}[]
for i in 1:length(bA)
push!(bd_A, oneArray(bA[i]))
push!(bd_C, oneArray(bC[i]))
if i < length(bA) - 2
push!(bd_bad, oneArray(bC[i]))
end
end
for i in 1:length(bX)
push!(bd_X, oneArray(bX[i]))
# move to device
bd_A = oneArray{T, 2}[]
bd_C = oneArray{T, 2}[]
bd_X = oneArray{T, 1}[]
bd_bad = oneArray{T, 2}[]
for i in 1:length(bA)
push!(bd_A, oneArray(bA[i]))
push!(bd_C, oneArray(bC[i]))
if i < length(bA) - 2
push!(bd_bad, oneArray(bC[i]))
end
oneMKL.dgmm_batch!('L',m, n, bd_A, bd_X, bd_C)
end
for i in 1:length(bX)
push!(bd_X, oneArray(bX[i]))
end
oneMKL.dgmm_batch!('L',m, n, bd_A, bd_X, bd_C)

for i in 1:group_count
hC = diagm(0 => bX[i]) * bA[i]
@test hC ≈ Array(bd_C[i])
end
for i in 1:group_count
hC = diagm(0 => bX[i]) * bA[i]
@test hC ≈ Array(bd_C[i])
end
end
=#
end
end

0 comments on commit db9359e

Please sign in to comment.