Skip to content

Commit

Permalink
Default to iterating over broadcast arguments
Browse files Browse the repository at this point in the history
Broadcast now calls a special helper function, `broadcastable`, on each argument. It ensures that the arguments support indexing and have a shape, or otherwise have defined a custom `BroadcastStyle`.
  • Loading branch information
mbauman committed Mar 16, 2018
1 parent cf65ee1 commit 7d01143
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 39 deletions.
66 changes: 34 additions & 32 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,7 @@ BroadcastStyle(::Type{<:Tuple}) = Style{Tuple}()

struct Unknown <: BroadcastStyle end
BroadcastStyle(::Type{Union{}}) = Unknown() # ambiguity resolution

"""
`Broadcast.Scalar()` is a [`BroadcastStyle`](@ref) indicating that an object is not
treated as a container for the purposes of broadcasting. This is the default for objects
that have not customized `BroadcastStyle`.
"""
struct Scalar <: BroadcastStyle end
BroadcastStyle(::Type) = Scalar()
BroadcastStyle(::Type{<:Ptr}) = Scalar()
BroadcastStyle(::Type) = Unknown()

"""
`Broadcast.AbstractArrayStyle{N} <: BroadcastStyle` is the abstract supertype for any style
Expand Down Expand Up @@ -102,15 +94,14 @@ behaves as an `N`-dimensional array for broadcasting. Specifically, `DefaultArra
used for any
AbstractArray type that hasn't defined a specialized style, and in the absence of
overrides from other `broadcast` arguments the resulting output type is `Array`.
When there are multiple inputs to `broadcast`, `DefaultArrayStyle` "wins" over [`Broadcast.Scalar`](@ref)
but "loses" to any other [`Broadcast.ArrayStyle`](@ref).
When there are multiple inputs to `broadcast`, `DefaultArrayStyle` "loses" to any other [`Broadcast.ArrayStyle`](@ref).
"""
struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end
(::Type{<:DefaultArrayStyle})(::Val{N}) where N = DefaultArrayStyle{N}()
const DefaultVectorStyle = DefaultArrayStyle{1}
const DefaultMatrixStyle = DefaultArrayStyle{2}
BroadcastStyle(::Type{<:AbstractArray{T,N}}) where {T,N} = DefaultArrayStyle{N}()
BroadcastStyle(::Type{<:Ref}) = DefaultArrayStyle{0}()
BroadcastStyle(::Type{<:Union{Ref,Number}}) = DefaultArrayStyle{0}()

# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle`
# objects were supplied as arguments, and that no rule was defined for resolving the
Expand Down Expand Up @@ -141,10 +132,8 @@ BroadcastStyle(::BroadcastStyle, ::BroadcastStyle) = Unknown()
BroadcastStyle(::Unknown, ::Unknown) = Unknown()
BroadcastStyle(::S, ::Unknown) where S<:BroadcastStyle = S()
# Precedence rules
BroadcastStyle(::Style{Tuple}, ::Scalar) = Style{Tuple}()
BroadcastStyle(a::AbstractArrayStyle{0}, ::Style{Tuple}) = typeof(a)(Val(1))
BroadcastStyle(a::AbstractArrayStyle{0}, b::Style{Tuple}) = b
BroadcastStyle(a::AbstractArrayStyle, ::Style{Tuple}) = a
BroadcastStyle(a::AbstractArrayStyle, ::Scalar) = a
BroadcastStyle(::A, ::A) where A<:ArrayStyle = A()
BroadcastStyle(::ArrayStyle, ::ArrayStyle) = Unknown()
BroadcastStyle(::A, ::A) where A<:AbstractArrayStyle = A()
Expand Down Expand Up @@ -183,10 +172,9 @@ broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) =
broadcast_indices() = ()
broadcast_indices(::Type{T}) where T = ()
broadcast_indices(A) = broadcast_indices(combine_styles(A), A)
broadcast_indices(::Scalar, A) = ()
broadcast_indices(::Style{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::DefaultArrayStyle{0}, A::Ref) = ()
broadcast_indices(::AbstractArrayStyle, A) = Base.axes(A)
broadcast_indices(::BroadcastStyle, A) = Base.axes(A)
"""
Base.broadcast_indices(::SrcStyle, A)
Expand Down Expand Up @@ -328,8 +316,7 @@ end

Base.@propagate_inbounds _broadcast_getindex(::Type{T}, I) where T = T
Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(combine_styles(A), A, I)
Base.@propagate_inbounds _broadcast_getindex(::DefaultArrayStyle{0}, A::Ref, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::Union{Unknown,Scalar}, A, I) = A
Base.@propagate_inbounds _broadcast_getindex(::DefaultArrayStyle{0}, A, I) = A[]
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) = A[1]

Expand Down Expand Up @@ -395,6 +382,16 @@ end
end
end

"""
broadcastable(x)
Return either `x` or an object like `x` such that it supports `axes` and indexing.
"""
broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing}) = Ref(x)
broadcastable(x::Ptr) = Ref{Ptr}(x) # Cannot use Ref(::Ptr) until ambiguous deprecation goes through
broadcastable(::Type{T}) where {T} = Ref{Type{T}}(T)
broadcastable(x::AbstractArray) = x

"""
broadcast!(f, dest, As...)
Expand All @@ -404,7 +401,10 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
@inline function broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N}
As′ = map(broadcastable, As)
broadcast!(f, dest, combine_styles(As′...), As′...)
end
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)

# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
Expand All @@ -419,14 +419,14 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
return dest
end

# Optimization for the all-Scalar case.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
# Optimization for the case where all arguments are 0-dimensional
@inline function broadcast!(f, dest, ::AbstractArrayStyle{0}, As::Vararg{Any, N}) where N
if dest isa AbstractArray
if f isa typeof(identity) && N == 1
return fill!(dest, As[1])
return fill!(dest, As[1][])
else
@inbounds for I in eachindex(dest)
dest[I] = f(As...)
dest[I] = f(map(getindex, As)...)
end
return dest
end
Expand Down Expand Up @@ -501,9 +501,8 @@ maptoTuple(f, a, b...) = Tuple{f(a), maptoTuple(f, b...).types...}
# A, broadcast_indices(A)
# )::_broadcast_getindex_eltype(A)
_broadcast_getindex_eltype(A) = _broadcast_getindex_eltype(combine_styles(A), A)
_broadcast_getindex_eltype(::Scalar, ::Type{T}) where T = Type{T}
_broadcast_getindex_eltype(::Union{Unknown,Scalar}, A) = typeof(A)
_broadcast_getindex_eltype(::BroadcastStyle, A) = eltype(A) # Tuple, Array, etc.
_broadcast_getindex_eltype(::DefaultArrayStyle{0}, ::Ref{T}) where {T} = T

# Inferred eltype of result of broadcast(f, xs...)
combine_eltypes(f, A, As...) =
Expand Down Expand Up @@ -584,12 +583,17 @@ julia> string.(("one","two","three","four"), ": ", 1:4)
```
"""
@inline broadcast(f, A, Bs...) =
broadcast(f, combine_styles(A, Bs...), nothing, nothing, A, Bs...)
@inline function broadcast(f, A, Bs...)
A′ = broadcastable(A)
Bs′ = map(broadcastable, Bs)
broadcast(f, combine_styles(A′, Bs′...), nothing, nothing, A′, Bs′...)
end

# In the scalar case we unwrap the arguments and just call `f`
@inline broadcast(f, ::AbstractArrayStyle{0}, ::Nothing, ::Nothing, A, Bs...) = f(A[], map(getindex, Bs)...)

@inline broadcast(f, s::BroadcastStyle, ::Nothing, ::Nothing, A, Bs...) =
broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...),
A, Bs...)
broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...), A, Bs...)

const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict}

Expand Down Expand Up @@ -628,8 +632,6 @@ function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::In
_broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1)
end

broadcast(f, ::Union{Scalar,Unknown}, ::Nothing, ::Nothing, a...) = f(a...)

@inline broadcast(f, ::Style{Tuple}, ::Nothing, ::Nothing, A, Bs...) =
tuplebroadcast(f, longest_tuple(A, Bs...), A, Bs...)
@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} =
Expand Down
2 changes: 1 addition & 1 deletion base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
end

function unsafe_getindex(r::LinRange, i::Integer)
lerpi.(i-1, r.lendiv, r.start, r.stop)
lerpi(i-1, r.lendiv, r.start, r.stop)
end

function lerpi(j::Integer, d::Integer, a::T, b::T) where T
Expand Down
2 changes: 2 additions & 0 deletions base/version.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ function print(io::IO, v::VersionNumber)
end
show(io::IO, v::VersionNumber) = print(io, "v\"", v, "\"")

Broadcast.broadcastable(v::VersionNumber) = Ref(v)

const VERSION_REGEX = r"^
v? # prefix (optional)
(\d+) # major (required)
Expand Down
3 changes: 3 additions & 0 deletions stdlib/Dates/src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,6 @@ end
# AbstractArray{TimeType}, AbstractArray{TimeType}
(-)(x::OrdinalRange{T}, y::OrdinalRange{T}) where {T<:TimeType} = Vector(x) - Vector(y)
(-)(x::AbstractRange{T}, y::AbstractRange{T}) where {T<:TimeType} = Vector(x) - Vector(y)

# Allow dates and times to broadcast as unwrapped scalars
Base.Broadcast.broadcastable(x::AbstractTime) = Ref(x)
1 change: 1 addition & 0 deletions stdlib/Dates/src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ function Base.show(io::IO, df::DateFormat)
end
print(io, '"')
end
Base.Broadcast.broadcastable(x::DateFormat) = Ref(x)

"""
dateformat"Y-m-d H:M:S"
Expand Down
3 changes: 3 additions & 0 deletions stdlib/Pkg3/src/resolve/FieldValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,7 @@ function secondmax(f::Field, msk::BitVector = trues(length(f)))
return m2 - m
end

# Support broadcasting like a scalar by default
Base.Broadcast.broadcastable(a::FieldValue) = Ref(a)

end
12 changes: 10 additions & 2 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,8 @@ end

nonscalararg(::SparseVecOrMat) = true
nonscalararg(::Any) = false
scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true
scalarwrappedarg(::Any) = false

@inline function _capturescalars()
return (), () -> ()
Expand All @@ -941,6 +943,8 @@ end
let (rest, f) = _capturescalars(mixedargs...)
if nonscalararg(arg)
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
elseif scalarwrappedarg(arg)
return rest, (tail...) -> (arg[], f(tail...)...) # unwrap and add back scalararg after (in makeargs)
else
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
end
Expand All @@ -949,6 +953,8 @@ end
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
if nonscalararg(arg)
return (arg,), (head,) -> (head,) # pass-through
elseif scalarwrappedarg(arg)
return (), () -> (arg[],) # unwrap
else
return (), () -> (arg,) # add scalararg
end
Expand Down Expand Up @@ -977,9 +983,10 @@ Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()

Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

Expand All @@ -992,7 +999,8 @@ is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_spa
is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
is_supported_sparse_broadcast(x, rest...) = BroadcastStyle(typeof(x)) === Broadcast.Scalar() && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)
function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N}
if is_supported_sparse_broadcast(As...)
return broadcast(f, map(_sparsifystructured, As)...)
Expand Down
8 changes: 4 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@ end

# Ref as 0-dimensional array for broadcast
@test (-).(C_NULL, C_NULL)::UInt == 0
@test (+).(1, Ref(2)) == fill(3)
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).(1, Ref(2)) == 3
@test (+).(Ref(1), Ref(2)) == 3
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1])) == [[1,1], [2,2]]

# Check that broadcast!(f, A) populates A via independent calls to f (#12277, #19722),
Expand Down Expand Up @@ -545,7 +545,7 @@ end
# Test that broadcast treats type arguments as scalars, i.e. containertype yields Any,
# even for subtypes of abstract array. (https://github.com/JuliaStats/DataArrays.jl/issues/229)
@testset "treat type arguments as scalars, DataArrays issue 229" begin
@test Broadcast.combine_styles(AbstractArray) == Broadcast.Scalar()
@test Broadcast.combine_styles(AbstractArray) == Broadcast.Unknown()
@test broadcast(==, [1], AbstractArray) == BitArray([false])
@test broadcast(==, 1, AbstractArray) == false
end
Expand Down Expand Up @@ -574,7 +574,7 @@ end
@test broadcast(foo, "x", [1, 2, 3]) == ["hello", "hello", "hello"]

@test isequal(
[Set([1]), Set([2])] .∪ Set([3]),
[Set([1]), Set([2])] .∪ Ref(Set([3])),
[Set([1, 3]), Set([2, 3])])
end

Expand Down

0 comments on commit 7d01143

Please sign in to comment.