Skip to content

Commit

Permalink
Restrict sparse broadcast promotion to Array
Browse files Browse the repository at this point in the history
This should be reverted someday
  • Loading branch information
timholy committed Nov 26, 2017
1 parent 7f487b6 commit 776139b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
39 changes: 38 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a
BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
typeof(a)(_max(Val(M),Val(N)))

# FIXME
# The following definitions are necessary to limit SparseArray broadcasting to "plain Arrays"
# (see https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382).
# They should be deleted once the sparse broadcast infrastucture is capable of handling
# arbitrary AbstractArrays.
struct VectorStyle <: AbstractArrayStyle{1} end
struct MatrixStyle <: AbstractArrayStyle{2} end
const VMStyle = Union{VectorStyle,MatrixStyle}
# These lose to DefaultArrayStyle
VectorStyle(::Val{N}) where N = DefaultArrayStyle{N}()
MatrixStyle(::Val{N}) where N = DefaultArrayStyle{N}()

BroadcastStyle(::Type{<:Vector}) = VectorStyle()
BroadcastStyle(::Type{<:Matrix}) = MatrixStyle()

BroadcastStyle(::MatrixStyle, ::VectorStyle) = MatrixStyle()
BroadcastStyle(a::AbstractArrayStyle{Any}, ::VectorStyle) = a
BroadcastStyle(a::AbstractArrayStyle{Any}, ::MatrixStyle) = a
BroadcastStyle(a::AbstractArrayStyle{N}, ::VectorStyle) where N = typeof(a)(_max(Val(N), Val(1)))
BroadcastStyle(a::AbstractArrayStyle{N}, ::MatrixStyle) where N = typeof(a)(_max(Val(N), Val(2)))
BroadcastStyle(::VectorStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(1)))
BroadcastStyle(::MatrixStyle, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle(_max(Val(N), Val(2)))
# to avoid the VectorStyle(::Val) constructor we also need the following
BroadcastStyle(::VectorStyle, ::MatrixStyle) = MatrixStyle()
# end FIXME

## Allocating the output container
"""
broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...)
Expand All @@ -181,6 +207,17 @@ broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) wher
broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) =
similar(BitArray, inds)

# FIXME: delete when we get rid of VectorStyle and MatrixStyle
broadcast_similar(f, ::VectorStyle, ::Type{ElType}, inds::Indices{1}, As...) where ElType =
similar(Vector{ElType}, inds)
broadcast_similar(f, ::MatrixStyle, ::Type{ElType}, inds::Indices{2}, As...) where ElType =
similar(Matrix{ElType}, inds)
broadcast_similar(f, ::VectorStyle, ::Type{Bool}, inds::Indices{1}, As...) =
similar(BitArray, inds)
broadcast_similar(f, ::MatrixStyle, ::Type{Bool}, inds::Indices{2}, As...) =
similar(BitArray, inds)
# end FIXME

## Computing the result's indices. Most types probably won't need to specialize this.
broadcast_indices() = ()
broadcast_indices(::Type{T}) where T = ()
Expand Down Expand Up @@ -582,7 +619,7 @@ Nullable{Complex{Float64}}()
broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...),
A, Bs...)

const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict}
const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict,VectorStyle,MatrixStyle}

@inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType
if !Base._isleaftype(ElType)
Expand Down
15 changes: 12 additions & 3 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -989,9 +989,18 @@ PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
BroadcastStyle(::Type{<:Base.RowVector{T,<:Vector}}) where T = Broadcast.MatrixStyle() # RowVector not yet defined when broadcast.jl loaded
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1)))
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2)))
# end FIXME

broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)
Expand Down

0 comments on commit 776139b

Please sign in to comment.