From 2c8c620522455d7d892c10e48061c56090130edc Mon Sep 17 00:00:00 2001 From: Justin Willmert Date: Wed, 26 Dec 2018 12:06:58 -0600 Subject: [PATCH] Handle sparse outer products specially in broadcast --- stdlib/SparseArrays/src/higherorderfns.jl | 56 +++++++++++++++++++++-- stdlib/SparseArrays/src/sparsevector.jl | 1 + 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/stdlib/SparseArrays/src/higherorderfns.jl b/stdlib/SparseArrays/src/higherorderfns.jl index 63d318af434f90..9cc6309bcc2b07 100644 --- a/stdlib/SparseArrays/src/higherorderfns.jl +++ b/stdlib/SparseArrays/src/higherorderfns.jl @@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto! using Base: front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, - AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange + AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange, + SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros using Base.Broadcast: BroadcastStyle, Broadcasted, flatten using LinearAlgebra @@ -92,6 +93,10 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...) is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...) +is_specialcase_sparse_broadcast(f, rest...) = false +is_specialcase_sparse_broadcast(::typeof(*), ::SparseVectorUnion, + ::AdjOrTransSparseVectorUnion) = true + # Dispatch on broadcast operations by number of arguments const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} = Broadcasted{Style,Axes,F,Tuple{}} @@ -810,6 +815,49 @@ end _finishempty!(C::SparseVector) = C _finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C) +# special case - vector outer product +_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y) +@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, y) +@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, y) +function _outer(trans::Tf, x, y) where Tf + w = parent(y) + nx = length(x) + nw = length(w) + rowvalx = nonzeroinds(x) + rowvalw = nonzeroinds(w) + nzvalsx = nonzeros(x) + nzvalsw = nonzeros(w) + nnzx = length(nzvalsx) + nnzw = length(nzvalsw) + + nnzC = nnzx * nnzw + Tv = typeof(one(eltype(x)) * one(eltype(w))) + Ti = promote_type(indtype(x), indtype(w)) + colptrC = zeros(Ti, nw + 1) + rowvalC = Vector{Ti}(undef, nnzC) + nzvalsC = Vector{Tv}(undef, nnzC) + + idx = 0 + @inbounds colptrC[1] = 1 + @inbounds for jj = 1:nnzw + wval = nzvalsw[jj] + iszero(wval) && continue + col = rowvalw[jj] + wval = trans(wval) + + for ii = 1:nnzx + xval = nzvalsx[ii] + iszero(xval) && continue + idx += 1 + colptrC[col+1] += 1 + rowvalC[idx] = rowvalx[ii] + nzvalsC[idx] = xval * wval + end + end + cumsum!(colptrC, colptrC) + + return SparseMatrixCSC(nx, nw, colptrC, rowvalC, nzvalsC) +end # (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N} @@ -1079,8 +1127,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f( function copy(bc::Broadcasted{PromoteToSparse}) bcf = flatten(bc) - if is_supported_sparse_broadcast(bcf.args...) - broadcast(bcf.f, map(_sparsifystructured, bcf.args)...) + if is_specialcase_sparse_broadcast(bcf.f, bcf.args...) + return _copy(bcf.f, bcf.args...) + elseif is_supported_sparse_broadcast(bcf.args...) + return broadcast(bcf.f, map(_sparsifystructured, bcf.args)...) else return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc)) end diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl index 11d9c2a3dbdbe0..277c5c736c819c 100644 --- a/stdlib/SparseArrays/src/sparsevector.jl +++ b/stdlib/SparseArrays/src/sparsevector.jl @@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} = # union of such a view and a SparseVector so we define an alias for such a union as well const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false} const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}} +const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}} ### Basic properties