diff --git a/src/higherorderfns.jl b/src/higherorderfns.jl index 3a45d195..c79a6123 100644 --- a/src/higherorderfns.jl +++ b/src/higherorderfns.jl @@ -973,35 +973,39 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f( # and rebroadcast. otherwise, divert to generic AbstractArray broadcast code. struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end -const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} -Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse() - PromoteToSparse(::Val{0}) = PromoteToSparse() PromoteToSparse(::Val{1}) = PromoteToSparse() PromoteToSparse(::Val{2}) = PromoteToSparse() PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() -Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() -Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() - -# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle -# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse() -# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() -# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() -Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded -Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Vector} where T}) = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded +const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} +Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse() Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse() -Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N = - Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1))) -Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N = - Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2))) -# end FIXME -broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} = - broadcast(f, map(_sparsifystructured, As)...) +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse() +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() + +# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray +# could report itself as a DefaultArrayStyle(). +# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details +is_supported_sparse_broadcast() = true +is_supported_sparse_broadcast(::AbstractArray, rest...) = false +is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...) +is_supported_sparse_broadcast(x, rest...) = BroadcastStyle(typeof(x)) === Broadcast.Scalar() && is_supported_sparse_broadcast(rest...) +function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} + if is_supported_sparse_broadcast(As...) + return broadcast(f, map(_sparsifystructured, As)...) + else + return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...) + end +end # For broadcast! with ::Any inputs, we need a layer of indirection to determine whether # the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,