From 246ff49d26b8171fe79edcb2b408bcccbc6d3d0a Mon Sep 17 00:00:00 2001 From: pabloferz Date: Sat, 22 Oct 2016 08:27:56 -0500 Subject: [PATCH] Make Ref behave as a scalar wrapper for broadcast --- base/broadcast.jl | 3 +++ test/broadcast.jl | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/base/broadcast.jl b/base/broadcast.jl index 48bc79868dec8a..061a20363d5c45 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -30,6 +30,7 @@ end # logic for deciding the resulting container type containertype(x) = containertype(typeof(x)) containertype(::Type) = Any +containertype{T<:Ref}(::Type{T}) = Array containertype{T<:Tuple}(::Type{T}) = Tuple containertype{T<:AbstractArray}(::Type{T}) = Array containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2)) @@ -49,6 +50,7 @@ broadcast_indices(A) = broadcast_indices(containertype(A), A) broadcast_indices(::Type{Any}, A) = () broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),) broadcast_indices(::Type{Array}, A) = indices(A) +broadcast_indices(::Type{Array}, A::Ref) = () @inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...) # shape (i.e., tuple-of-indices) inputs broadcast_shape(shape::Tuple) = shape @@ -121,6 +123,7 @@ map_newindexer(shape, ::Tuple{}) = (), () end @inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I) +@inline _broadcast_getindex(::Type{Array}, A::Ref, I) = A[] @inline _broadcast_getindex(::Type{Any}, A, I) = A @inline _broadcast_getindex(::Any, A, I) = A[I] diff --git a/test/broadcast.jl b/test/broadcast.jl index 7dca2095bff7e2..b8cb989e30899d 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -337,3 +337,9 @@ end @test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0) @test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0] @test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"] + +# Ref as 0-dimensional array for broadcast +@test (+).(1, Ref(2)) == fill(3) +@test (+).(Ref(1), Ref(2)) == fill(3) +@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]] +@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([-1,1])) == [[1,1], [2,2]]