Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new broadcast API #348

Closed
wants to merge 17 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 74 additions & 39 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,79 @@
## broadcast! ##
################

import Base.Broadcast:
if VERSION < v"0.7.0-DEV.2638"
## Old Broadcast API ##
import Base.Broadcast:
_containertype, promote_containertype, broadcast_indices,
broadcast_c, broadcast_c!

# Add StaticArray as a new output type in Base.Broadcast promotion machinery.
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
# Array, StaticArray -> Array
#
# We could be more precise about the latter, but this isn't really possible
# without using Array{N} rather than Array in Base's promote_containertype.
#
# Base also has broadcast with tuple + Array, but while implementing this would
# be consistent with Base, it's not exactly clear it's a good idea when you can
# just use an SVector instead?
promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray
promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray

broadcast_indices(::Type{StaticArray}, A) = indices(A)


# Override for when output type is deduced to be a StaticArray.
@inline function broadcast_c(f, ::Type{StaticArray}, as...)
_broadcast(f, broadcast_sizes(as...), as...)
# Add StaticArray as a new output type in Base.Broadcast promotion machinery.
# This isn't the precise output type, just a placeholder to return from
# promote_containertype, which will control dispatch to our broadcast_c.
_containertype(::Type{<:StaticArray}) = StaticArray
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray

# With the above, the default promote_containertype gives reasonable defaults:
# StaticArray, StaticArray -> StaticArray
# Array, StaticArray -> Array
#
# We could be more precise about the latter, but this isn't really possible
# without using Array{N} rather than Array in Base's promote_containertype.
#
# Base also has broadcast with tuple + Array, but while implementing this would
# be consistent with Base, it's not exactly clear it's a good idea when you can
# just use an SVector instead?
promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray
promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray

# Override for when output type is deduced to be a StaticArray.
@inline function broadcast_c(f, ::Type{StaticArray}, as...)
_broadcast(f, broadcast_sizes(as...), as...)
end

# TODO: This signature could be relaxed to (::Any, ::Type{StaticArray}, ::Type, ...), though
# we'd need to rework how _broadcast!() and broadcast_sizes() interact with normal AbstractArray.
@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type{StaticArray}, dest, as...)
_broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...)
end
else
## New Broadcast API ##
import Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, broadcast

# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
# A constructor that changes the style parameter N (array dimension) is also required
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()

BroadcastStyle(::Type{<:StaticArray{D, T, N}}) where {D, T, N} = StaticArrayStyle{N}()

# Precedence rules
BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.DefaultArrayStyle{N}) where {M,N} =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(M), Val(N)))
# FIXME: These two rules should be removed once VectorStyle and MatrixStyle are removed from base/broadcast.jl
BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.VectorStyle) where M = Broadcast.Unknown()
BroadcastStyle(::StaticArrayStyle{M}, ::Broadcast.MatrixStyle) where M = Broadcast.Unknown()
# End FIXME

# Add a broadcast method that calls the @generated routine
@inline function broadcast(f, ::StaticArrayStyle, ::Void, ::Void, As...)
_broadcast(f, broadcast_sizes(As...), As...)
end

# Add a specialized broadcast! method that overrides the Base fallback and calls the old routine
@inline function broadcast!(f, C, ::StaticArrayStyle, As...)
_broadcast!(f, Size(C), C, broadcast_sizes(As...), As...)
end
end


##############################################
## Old broadcast machinery for StaticArrays ##
##############################################

broadcast_indices(A::StaticArray) = indices(A)

@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
Expand Down Expand Up @@ -113,21 +155,14 @@ end

if VERSION < v"0.7.0-DEV"
# Workaround for #329
@inline function Base.broadcast(f, ::Type{T}, a::StaticArray) where {T}
map(x->f(T,x), a)
end
end

################
## broadcast! ##
################

# TODO: This signature could be relaxed to (::Any, ::Type{StaticArray}, ::Type, ...), though
# we'd need to rework how _broadcast!() and broadcast_sizes() interact with normal AbstractArray.
@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type{StaticArray}, dest, as...)
_broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...)
@inline function Base.broadcast(f, ::Type{T}, a::StaticArray) where {T}
map(x->f(T,x), a)
end
end

###############################################
## Old broadcast! machinery for StaticArrays ##
###############################################

@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
sizes = [sz.parameters[1] for sz ∈ s.parameters]
Expand Down