From 746dbb0c83a6f9a43a19a0328dba3bf150ae0e56 Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Mon, 26 Dec 2016 13:47:02 -0800 Subject: [PATCH] Extend sparse broadcast! to combinations of broadcast scalars and sparse vectors/matrices. Makes broadcast! dispatch on container type (as broadcast), and inject generic sparse broadcast! for the appropriate container type. --- base/broadcast.jl | 4 +++- base/sparse/higherorderfns.jl | 11 ++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 4ad21b7926720..1dbddff61fe01 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -202,7 +202,9 @@ Note that `dest` is only used to store the result, and does not supply arguments to `f` unless it is also listed in the `As`, as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. """ -@inline function broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N}) +@inline broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N}) = + broadcast_c!(f, containertype(C, A, Bs...), C, A, Bs...) +@inline function broadcast_c!{N}(f, ::Type, C::AbstractArray, A, Bs::Vararg{Any,N}) shape = indices(C) @boundscheck check_broadcast_indices(shape, A, Bs...) keeps, Idefaults = map_newindexer(shape, A, Bs) diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 7aa041df6207a..6a302bfe682ed 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -5,7 +5,8 @@ module HigherOrderFns # This module provides higher order functions specialized for sparse arrays, # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present. import Base: map, map!, broadcast, broadcast! -import Base.Broadcast: containertype, promote_containertype, broadcast_indices, broadcast_c +import Base.Broadcast: containertype, promote_containertype, + broadcast_indices, broadcast_c, broadcast_c! using Base: front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype @@ -852,11 +853,15 @@ promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array -# broadcast entry point for combinations of sparse arrays and other types -function broadcast_c(f, ::Type{AbstractSparseArray}, mixedargs...) +# broadcast[!] entry points for combinations of sparse arrays and other types +@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N}) parevalf, passedargstup = capturescalars(f, mixedargs) return broadcast(parevalf, passedargstup...) end +@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) + parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs) + return broadcast!(parevalf, dest, passedsrcargstup...) +end # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and # broadcast scalar arguments (mixedargs), and returns a function (parevalf) and a reduced # argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs