Skip to content

Commit

Permalink
Handle sparse outer products specially in broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
jmert committed Dec 26, 2018
1 parent e90995c commit 2c8c620
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
56 changes: 53 additions & 3 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange,
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using LinearAlgebra

Expand Down Expand Up @@ -92,6 +93,10 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)

is_specialcase_sparse_broadcast(f, rest...) = false
is_specialcase_sparse_broadcast(::typeof(*), ::SparseVectorUnion,
::AdjOrTransSparseVectorUnion) = true

# Dispatch on broadcast operations by number of arguments
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
Broadcasted{Style,Axes,F,Tuple{}}
Expand Down Expand Up @@ -810,6 +815,49 @@ end
_finishempty!(C::SparseVector) = C
_finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C)

# special case - vector outer product
_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y)
@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, y)
@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, y)
function _outer(trans::Tf, x, y) where Tf
w = parent(y)
nx = length(x)
nw = length(w)
rowvalx = nonzeroinds(x)
rowvalw = nonzeroinds(w)
nzvalsx = nonzeros(x)
nzvalsw = nonzeros(w)
nnzx = length(nzvalsx)
nnzw = length(nzvalsw)

nnzC = nnzx * nnzw
Tv = typeof(one(eltype(x)) * one(eltype(w)))
Ti = promote_type(indtype(x), indtype(w))
colptrC = zeros(Ti, nw + 1)
rowvalC = Vector{Ti}(undef, nnzC)
nzvalsC = Vector{Tv}(undef, nnzC)

idx = 0
@inbounds colptrC[1] = 1
@inbounds for jj = 1:nnzw
wval = nzvalsw[jj]
iszero(wval) && continue
col = rowvalw[jj]
wval = trans(wval)

for ii = 1:nnzx
xval = nzvalsx[ii]
iszero(xval) && continue
idx += 1
colptrC[col+1] += 1
rowvalC[idx] = rowvalx[ii]
nzvalsC[idx] = xval * wval
end
end
cumsum!(colptrC, colptrC)

return SparseMatrixCSC(nx, nw, colptrC, rowvalC, nzvalsC)
end

# (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices
function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N}
Expand Down Expand Up @@ -1079,8 +1127,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(

function copy(bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
if is_supported_sparse_broadcast(bcf.args...)
broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
if is_specialcase_sparse_broadcast(bcf.f, bcf.args...)
return _copy(bcf.f, bcf.args...)
elseif is_supported_sparse_broadcast(bcf.args...)
return broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
else
return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
end
Expand Down
1 change: 1 addition & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
# union of such a view and a SparseVector so we define an alias for such a union as well
const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}}
const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}}

### Basic properties

Expand Down

0 comments on commit 2c8c620

Please sign in to comment.