Skip to content

Commit

Permalink
Make Ref behave as a scalar wrapper for broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Oct 23, 2016
1 parent 3f12a38 commit e99c20e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ end
# logic for deciding the resulting container type
containertype(x) = containertype(typeof(x))
containertype(::Type) = Any
containertype{T<:Ptr}(::Type{T}) = Any
containertype{T<:Tuple}(::Type{T}) = Tuple
containertype{T<:Ref}(::Type{T}) = Array
containertype{T<:AbstractArray}(::Type{T}) = Array
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))
Expand All @@ -49,6 +51,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::Type{Any}, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::Type{Array}, A::Ref) = ()
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -121,6 +124,7 @@ map_newindexer(shape, ::Tuple{}) = (), ()
end

@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
@inline _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
@inline _broadcast_getindex(::Type{Any}, A, I) = A
@inline _broadcast_getindex(::Any, A, I) = A[I]

Expand Down
7 changes: 7 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,10 @@ end
@test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0)
@test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0]
@test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"]

# Ref as 0-dimensional array for broadcast
@test (-).(C_NULL, C_NULL)::UInt == 0
@test (+).(1, Ref(2)) == fill(3)
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]]
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1])) == [[1,1], [2,2]]

0 comments on commit e99c20e

Please sign in to comment.