Skip to content

Commit

Permalink
Improve the dispatch for sparse routines (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Apr 3, 2024
1 parent 949a457 commit b6a393f
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ SpecialFunctions = "1.3, 2"
StaticArrays = "1"
julia = "1.8"
oneAPI_Level_Zero_Loader_jll = "1.9"
oneAPI_Support_jll = "~0.3.2"
oneAPI_Support_jll = "~0.3.3"

[extras]
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
50 changes: 50 additions & 0 deletions lib/mkl/array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
export oneSparseMatrixCSR

abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}

mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
rowPtr::oneVector{Ti}
colVal::oneVector{Ti}
nzVal::oneVector{Tv}
dims::NTuple{2,Int}
nnz::Ti
end

Base.length(A::oneSparseMatrixCSR) = prod(A.dims)
Base.size(A::oneSparseMatrixCSR) = A.dims

function Base.size(A::oneSparseMatrixCSR, d::Integer)
if d == 1 || d == 2
return A.dims[d]
else
throw(ArgumentError("dimension must be 1 or 2, got $d"))
end
end

SparseArrays.nnz(A::oneSparseMatrixCSR) = A.nnz
SparseArrays.nonzeros(A::oneSparseMatrixCSR) = A.nzVal

for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC]
@eval Base.show(io::IOContext, x::$gpu) =
show(io, $cpu(x))

@eval function Base.show(io::IO, mime::MIME"text/plain", S::$gpu)
xnnz = nnz(S)
m, n = size(S)
print(io, m, "×", n, " ", typeof(S), " with ", xnnz, " stored ",
xnnz == 1 ? "entry" : "entries")
if !(m == 0 || n == 0)
println(io, ":")
io = IOContext(io, :typeinfo => eltype(S))
if ndims(S) == 1
show(io, $cpu(S))
else
# so that we get the nice Braille pattern
Base.print_array(io, $cpu(S))
end
end
end
end
24 changes: 24 additions & 0 deletions lib/mkl/interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# interfacing with other packages

using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, MulAddMul

function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end

function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end

if VERSION v"1.10-"
for SparseMatrixType in (:oneSparseMatrixCSR,)
@eval begin
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
end
end
end
end
2 changes: 2 additions & 0 deletions lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
const onemklComplex = Union{ComplexF32,ComplexF64}
const onemklHalf = Union{Float16,ComplexF16}

include("array.jl")
include("utils.jl")
include("wrappers_blas.jl")
include("wrappers_lapack.jl")
include("wrappers_sparse.jl")
include("linalg.jl")
include("interfaces.jl")

function band(A::StridedArray, kl, ku)
m, n = size(A)
Expand Down
29 changes: 14 additions & 15 deletions lib/mkl/wrappers_sparse.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
export oneSparseMatrixCSR

mutable struct oneSparseMatrixCSR{T}
handle::matrix_handle_t
type::Type{T}
m::Int
n::Int
end

for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
(:onemklSsparse_set_csr_data_64, :Float32 , :Int64),
(:onemklDsparse_set_csr_data , :Float64 , :Int32),
Expand All @@ -21,12 +12,20 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
onemklXsparse_init_matrix_handle(handle_ptr)
m, n = size(A)
At = SparseMatrixCSC(A |> transpose)
row_ptr = oneVector{$intty}(At.colptr)
col_ind = oneVector{$intty}(At.rowval)
val = oneVector{$elty}(At.nzval)
queue = global_queue(context(val), device(val))
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', row_ptr, col_ind, val)
return oneSparseMatrixCSR{$elty}(handle_ptr[], $elty, m, n)
rowPtr = oneVector{$intty}(At.colptr)
colVal = oneVector{$intty}(At.rowval)
nzVal = oneVector{$elty}(At.nzval)
nnzA = length(At.nzval)
queue = global_queue(context(nzVal), device(nzVal))
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
return oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
end

function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
At = SparseMatrixCSC(reverse(A.dims)..., Array(A.rowPtr), Array(A.colVal), Array(A.nzVal))
A_csc = SparseMatrixCSC(At |> transpose)
return A_csc
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,8 @@ end
A = sprand(T, 20, 10, 0.5)
A = SparseMatrixCSC{T, S}(A)
B = oneSparseMatrixCSR(A)
A2 = SparseMatrixCSC(B)
@test A == A2
end
end

Expand Down

0 comments on commit b6a393f

Please sign in to comment.