Skip to content

Commit

Permalink
make Some type a zero-dim broadcast container (e.g. a scalar)
Browse files Browse the repository at this point in the history
Replaces #35778
Replaces #39184
Fixes #39151
Refs #35675
Refs #43200
  • Loading branch information
vtjnash committed Nov 23, 2021
1 parent 08ea2d8 commit 59b8341
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ Standard library changes
* TCP socket objects now expose `closewrite` functionality and support half-open mode usage ([#40783]).
* Intersect returns a result with the eltype of the type-promoted eltypes of the two inputs ([#41769]).
* `Iterators.countfrom` now accepts any type that defines `+`. ([#37747])
* `Some`containers now support broadcast as zero dimensional immutable containers. `Some(x)`
should be preferred to `Ref(x)` when you wish to exempt `x` from broadcasting ([#35778]).

#### InteractiveUtils
* A new macro `@time_imports` for reporting any time spent importing packages and their dependencies ([#41612])
Expand Down
31 changes: 23 additions & 8 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,9 @@ Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]
Index into `A` with `I`, collapsing broadcasted indices to their singleton indices as appropriate.
"""
Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,Some,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(::Ref{Type{T}}, I) where {T} = T
Base.@propagate_inbounds _broadcast_getindex(::Some{Type{T}}, I) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I) = A[I[1]]
Expand Down Expand Up @@ -661,6 +662,20 @@ Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:An
args = _getindex(tail(tail(bc.args)), I)
return _broadcast_getindex_evalf(bc.f, T, S, args...)
end
Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Some{Type{T}},Vararg{Any}}}, I) where {T}
args = _getindex(tail(bc.args), I)
return _broadcast_getindex_evalf(bc.f, T, args...)
end
Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Any,Some{Type{T}},Vararg{Any}}}, I) where {T}
arg1 = _broadcast_getindex(bc.args[1], I)
args = _getindex(tail(tail(bc.args)), I)
return _broadcast_getindex_evalf(bc.f, arg1, T, args...)
end
Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Some{Type{T}},Some{Type{S}},Vararg{Any}}}, I) where {T,S}
args = _getindex(tail(tail(bc.args)), I)
return _broadcast_getindex_evalf(bc.f, T, S, args...)
end


# Utilities for _broadcast_getindex
Base.@propagate_inbounds _getindex(args::Tuple, I) = (_broadcast_getindex(args[1], I), _getindex(tail(args), I)...)
Expand Down Expand Up @@ -691,15 +706,15 @@ julia> Broadcast.broadcastable([1,2,3]) # like `identity` since arrays already s
3
julia> Broadcast.broadcastable(Int) # Types don't support axes, indexing, or iteration but are commonly used as scalars
Base.RefValue{Type{Int64}}(Int64)
Base.Some{Type{Int64}}(Int64)
julia> Broadcast.broadcastable("hello") # Strings break convention of matching iteration and act like a scalar instead
Base.RefValue{String}("hello")
Base.Some{String}("hello")
```
"""
broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing,Val,Ptr,AbstractPattern,Pair,IO}) = Ref(x)
broadcastable(::Type{T}) where {T} = Ref{Type{T}}(T)
broadcastable(x::Union{AbstractArray,Number,AbstractChar,Ref,Tuple,Broadcasted}) = x
broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing,Val,Ptr,AbstractPattern,Pair,IO}) = Some(x)
broadcastable(::Type{T}) where {T} = Some{Type{T}}(T)
broadcastable(x::Union{AbstractArray,Number,AbstractChar,Some,Ref,Tuple,Broadcasted}) = x
# Default to collecting iterables — which will error for non-iterables
broadcastable(x) = collect(x)
broadcastable(::Union{AbstractDict, NamedTuple}) = throw(ArgumentError("broadcasting over dictionaries and `NamedTuple`s is reserved"))
Expand All @@ -722,7 +737,7 @@ combine_eltypes(f, args::Tuple) =
"""
broadcast(f, As...)
Broadcast the function `f` over the arrays, tuples, collections, [`Ref`](@ref)s and/or scalars `As`.
Broadcast the function `f` over the arrays, tuples, collections, [`Some`](@ref)s and/or scalars `As`.
Broadcasting applies the function `f` over the elements of the container arguments and the
scalars themselves in `As`. Singleton and missing dimensions are expanded to match the
Expand Down Expand Up @@ -781,7 +796,7 @@ julia> abs.((1, -2))
julia> broadcast(+, 1.0, (0, -2.0))
(1.0, -1.0)
julia> (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1]))
julia> (+).([[0,2], [1,3]], Some{Vector{Int}}([1,-1]))
2-element Vector{Vector{Int64}}:
[1, 1]
[2, 2]
Expand Down
29 changes: 20 additions & 9 deletions base/some.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
Some{T}
Some{T} <: AbstractArray{T,0}
A wrapper type used in `Union{Some{T}, Nothing}` to distinguish between the absence
of a value ([`nothing`](@ref)) and the presence of a `nothing` value (i.e. `Some(nothing)`).
Use [`something`](@ref) to access the value wrapped by a `Some` object.
"""
struct Some{T}
struct Some{T} <: AbstractArray{T,0}
value::T
end

Some(::Type{T}) where {T} = Some{Type{T}}(T)

promote_rule(::Type{Some{T}}, ::Type{Some{S}}) where {T, S<:T} = Some{T}
eltype(x::Type{<:Some{T}}) where {T} = @isdefined(T) ? T : Any
size(x::Some) = ()
axes(x::Some) = ()
length(x::Some) = 1
isempty(x::Some) = false
ndims(x::Some) = 0
ndims(::Type{<:Some}) = 0
iterate(r::Some) = (r.value, nothing)
getindex(r::Some) = r.value
iterate(r::Some, s) = nothing
IteratorSize(::Type{<:Some}) = HasShape{0}()

nonnothingtype(::Type{T}) where {T} = typesplit(T, Nothing)
function nonnothingtype_checked(T::Type)
R = nonnothingtype(T)
R >: T && error("could not compute non-nothing type")
return R
end

promote_rule(::Type{Some{T}}, ::Type{Some{S}}) where {T, S<:T} = Some{T}
promote_rule(T::Type{Nothing}, S::Type) = Union{S, Nothing}
function promote_rule(T::Type{>:Nothing}, S::Type)
R = nonnothingtype(T)
Expand All @@ -26,12 +43,6 @@ function promote_rule(T::Type{>:Nothing}, S::Type)
return Union{R, Nothing}
end

function nonnothingtype_checked(T::Type)
R = nonnothingtype(T)
R >: T && error("could not compute non-nothing type")
return R
end

convert(::Type{T}, x::T) where {T>:Nothing} = x
convert(::Type{T}, x) where {T>:Nothing} = convert(nonnothingtype_checked(T), x)
convert(::Type{Nothing}, x) = throw(MethodError(convert, (Nothing, x)))
Expand Down
16 changes: 14 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ end
@test (+).(Ref(1), Ref(2)) == 3
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1])) == [[1,1], [2,2]]

# Some as 0-dimensional array for broadcast
@test (-).(C_NULL, C_NULL)::UInt == 0
@test (+).(1, Some(2)) == 3
@test (+).(Some(1), Some(2)) == 3
@test (+).([[0,2], [1,3]], Some{Vector{Int}}([1,-1])) == [[1,1], [2,2]]


# Check that broadcast!(f, A) populates A via independent calls to f (#12277, #19722),
# and similarly for broadcast!(f, A, numbers...) (#19799).
@test let z = 1; A = broadcast!(() -> z += 1, zeros(2)); A[1] != A[2]; end
Expand Down Expand Up @@ -573,6 +580,11 @@ let io = IOBuffer()
broadcast(x -> print(io, x), [Ref(1.0)])
@test String(take!(io)) == "Base.RefValue{Float64}(1.0)"
end
@test getindex.([Some(1), Some(2)]) == [1, 2]
let io = IOBuffer()
broadcast(x -> print(io, x), [Some(1.0)])
@test String(take!(io)) == "Some(1.0)"
end

# Test that broadcast's promotion mechanism handles closures accepting more than one argument.
# (See issue #19641 and referenced issues and pull requests.)
Expand Down Expand Up @@ -635,7 +647,7 @@ end
@test broadcast(foo, "x", [1, 2, 3]) == ["hello", "hello", "hello"]

@test isequal(
[Set([1]), Set([2])] .∪ Ref(Set([3])),
[Set([1]), Set([2])] .∪ Some(Set([3])),
[Set([1, 3]), Set([2, 3])])
end

Expand Down Expand Up @@ -916,7 +928,7 @@ end
@test reduce(paren, bcraw) == foldl(paren, xs)

# issue #41055
bc = Broadcast.instantiate(Broadcast.broadcasted(Base.literal_pow, Ref(^), [1,2], Ref(Val(2))))
bc = Broadcast.instantiate(Broadcast.broadcasted(Base.literal_pow, Some(^), [1,2], Some(Val(2))))
@test sum(bc, dims=1, init=0) == [5]
bc = Broadcast.instantiate(Broadcast.broadcasted(*, ['a','b'], 'c'))
@test prod(bc, dims=1, init="") == ["acbc"]
Expand Down

0 comments on commit 59b8341

Please sign in to comment.