diff --git a/base/subarray.jl b/base/subarray.jl index 538bf8fd65059..568aeeb377b61 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -333,25 +333,32 @@ should transform to A[B[endof(B)]] """ -function replace_ref_end!(ex,withex=nothing) +replace_ref_end!(ex) = replace_ref_end_!(ex, nothing)[1] +# replace_ref_end_!(ex,withex) returns (new ex, whether withex was used) +function replace_ref_end_!(ex, withex) + used_withex = false if isa(ex,Symbol) && ex == :end withex === nothing && error("Invalid use of end") - return withex + return withex, true elseif isa(ex,Expr) if ex.head == :ref - S = ex.args[1] = replace_ref_end!(ex.args[1],withex) + ex.args[1], used_withex = replace_ref_end_!(ex.args[1],withex) + S = isa(ex.args[1],Symbol) ? ex.args[1]::Symbol : gensym(:S) # temp var to cache ex.args[1] if needed + used_S = false # whether we actually need S # new :ref, so redefine withex nargs = length(ex.args)-1 if nargs == 0 - return ex + return ex, used_withex elseif nargs == 1 # replace with endof(S) - ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S))) + ex.args[2], used_S = replace_ref_end_!(ex.args[2],:($endof($S))) else n = 1 J = endof(ex.args) for j = 2:J-1 - exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n))) + exj, used = replace_ref_end_!(ex.args[j],:($size($S,$n))) + used_S |= used + ex.args[j] = exj if isa(exj,Expr) && exj.head == :... # splatted object exjs = exj.args[1] @@ -364,16 +371,23 @@ function replace_ref_end!(ex,withex=nothing) n += 1 end end - ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n))) + ex.args[J], used = replace_ref_end_!(ex.args[J],:($trailingsize($S,$n))) + used_S |= used + end + if used_S && S !== ex.args[1] + S0 = ex.args[1] + ex.args[1] = S + ex = Expr(:let, ex, :($S = $S0)) end else # recursive search for i = eachindex(ex.args) - ex.args[i] = replace_ref_end!(ex.args[i],withex) + ex.args[i], used = replace_ref_end_!(ex.args[i],withex) + used_withex |= used end end end - ex + ex, used_withex end """ @@ -385,9 +399,15 @@ an assignment (e.g. `@view(A[1,2:end]) = ...`). See also [`@views`](@ref) to switch an entire block of code to use views for slicing. """ macro view(ex) - if isa(ex, Expr) && ex.head == :ref + if Meta.isexpr(ex, :ref) ex = replace_ref_end!(ex) - Expr(:&&, true, esc(Expr(:call,:(Base.view),ex.args...))) + if Meta.isexpr(ex, :ref) + ex = Expr(:call, view, ex.args...) + else # ex replaced by let ...; foo[...]; end + assert(Meta.isexpr(ex, :let) && Meta.isexpr(ex.args[1], :ref)) + ex.args[1] = Expr(:call, view, ex.args[1].args...) + end + Expr(:&&, true, esc(ex)) else throw(ArgumentError("Invalid use of @view macro: argument must be a reference expression A[...].")) end @@ -404,21 +424,53 @@ end @propagate_inbounds maybeview(A) = getindex(A) @propagate_inbounds maybeview(A::AbstractArray) = getindex(A) +# _views implements the transformation for the @views macro. +# @views calls esc(_views(...)) to work around #20241, +# so any function calls we insert (to maybeview, or to +# size and endof in replace_ref_end!) must be interpolated +# as values rather than as symbols to ensure that they are called +# from Base rather than from the caller's scope. _views(x) = x -_views(x::Symbol) = esc(x) function _views(ex::Expr) if ex.head in (:(=), :(.=)) - # don't use view on the lhs of an assignment - Expr(ex.head, esc(ex.args[1]), _views(ex.args[2])) + # don't use view for ref on the lhs of an assignment, + # but still use views for the args of the ref: + lhs = ex.args[1] + Expr(ex.head, Meta.isexpr(lhs, :ref) ? + Expr(:ref, _views.(lhs.args)...) : _views(lhs), + _views(ex.args[2])) elseif ex.head == :ref - ex = replace_ref_end!(ex) - Expr(:call, :maybeview, _views.(ex.args)...) + Expr(:call, maybeview, _views.(ex.args)...) else h = string(ex.head) - if last(h) == '=' - # don't use view on the lhs of an op-assignment - Expr(first(h) == '.' ? :(.=) : :(=), esc(ex.args[1]), - Expr(:call, esc(Symbol(h[1:end-1])), _views.(ex.args)...)) + # don't use view on the lhs of an op-assignment a[i...] += ... + if last(h) == '=' && Meta.isexpr(ex.args[1], :ref) + lhs = ex.args[1] + + # temp vars to avoid recomputing a and i, + # which will be assigned in a let block: + a = gensym(:a) + i = [gensym(:i) for k = 1:length(lhs.args)-1] + + # for splatted indices like a[i, j...], we need to + # splat the corresponding temp var. + I = similar(i, Any) + for k = 1:length(i) + if Meta.isexpr(lhs.args[k+1], :...) + I[k] = Expr(:..., i[k]) + lhs.args[k+1] = lhs.args[k+1].args[1] # unsplat + else + I[k] = i[k] + end + end + + Expr(:let, + Expr(first(h) == '.' ? :(.=) : :(=), :($a[$(I...)]), + Expr(:call, Symbol(h[1:end-1]), + :($maybeview($a, $(I...))), + _views.(ex.args[2:end])...)), + :($a = $(_views(lhs.args[1]))), + [:($(i[k]) = $(_views(lhs.args[k+1]))) for k=1:length(i)]...) else Expr(ex.head, _views.(ex.args)...) end @@ -439,5 +491,5 @@ that appear explicitly in the given `expression`, not array slicing that occurs in functions called by that code. """ macro views(x) - _views(x) + esc(_views(replace_ref_end!(x))) end diff --git a/test/subarray.jl b/test/subarray.jl index 3791d38e8c01f..cfaa8cff5b2c2 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -473,8 +473,6 @@ end @test collect(view(view(reshape(1:13^3, 13, 13, 13), 3:7, 6:6, :), 1:2:5, :, 1:2:5)) == cat(3,[68,70,72],[406,408,410],[744,746,748]) - - # tests @view (and replace_ref_end!) X = reshape(1:24,2,3,4) Y = 4:-1:1 @@ -494,10 +492,16 @@ u = (1,2:3) @test X[(1,)...,(2,)...,2:end] == @view X[(1,)...,(2,)...,2:end] # test macro hygiene -let size=(x,y)-> error("should not happen") +let size=(x,y)-> error("should not happen"), Base=nothing @test X[1:end,2,2] == @view X[1:end,2,2] end +# test that side effects occur only once +let foo = [X] + @test X[2:end-1] == @view (push!(foo,X)[1])[2:end-1] + @test foo == [X, X] +end + # test @views macro @views let f!(x) = x[1:end-1] .+= x[2:end].^2 x = [1,2,3,4] @@ -512,6 +516,16 @@ end @test x == [5,6,19,4] f!(x[3:end]) @test x == [5,6,35,4] + x[Y[2:3]] .= 7:8 + @test x == [5,8,7,4] + x[(3,)..., ()...] .+= 3 + @test x == [5,8,10,4] + i = Int[] + # test that lhs expressions in update operations are evaluated only once: + x[push!(i,4)[1]] += 5 + @test x == [5,8,10,9] && i == [4] + x[push!(i,3)[end]] += 2 + @test x == [5,8,12,9] && i == [4,3] end @views @test isa(X[1:3], SubArray) @test X[1:end] == @views X[1:end] @@ -523,9 +537,8 @@ end @test X[1:end,2,Y[2:end]] == @views X[1:end,2,Y[2:end]] @test X[u...,2:end] == @views X[u...,2:end] @test X[(1,)...,(2,)...,2:end] == @views X[(1,)...,(2,)...,2:end] - # test macro hygiene -let size=(x,y)-> error("should not happen") +let size=(x,y)-> error("should not happen"), Base=nothing @test X[1:end,2,2] == @views X[1:end,2,2] end