Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes for at-view and at-views #20247

Merged
merged 2 commits into from
Jan 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 73 additions & 21 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,25 +333,32 @@ should transform to
A[B[endof(B)]]

"""
function replace_ref_end!(ex,withex=nothing)
replace_ref_end!(ex) = replace_ref_end_!(ex, nothing)[1]
# replace_ref_end_!(ex,withex) returns (new ex, whether withex was used)
function replace_ref_end_!(ex, withex)
used_withex = false
if isa(ex,Symbol) && ex == :end
withex === nothing && error("Invalid use of end")
return withex
return withex, true
elseif isa(ex,Expr)
if ex.head == :ref
S = ex.args[1] = replace_ref_end!(ex.args[1],withex)
ex.args[1], used_withex = replace_ref_end_!(ex.args[1],withex)
S = isa(ex.args[1],Symbol) ? ex.args[1]::Symbol : gensym(:S) # temp var to cache ex.args[1] if needed
used_S = false # whether we actually need S
# new :ref, so redefine withex
nargs = length(ex.args)-1
if nargs == 0
return ex
return ex, used_withex
elseif nargs == 1
# replace with endof(S)
ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S)))
ex.args[2], used_S = replace_ref_end_!(ex.args[2],:($endof($S)))
else
n = 1
J = endof(ex.args)
for j = 2:J-1
exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n)))
exj, used = replace_ref_end_!(ex.args[j],:($size($S,$n)))
used_S |= used
ex.args[j] = exj
if isa(exj,Expr) && exj.head == :...
# splatted object
exjs = exj.args[1]
Expand All @@ -364,16 +371,23 @@ function replace_ref_end!(ex,withex=nothing)
n += 1
end
end
ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n)))
ex.args[J], used = replace_ref_end_!(ex.args[J],:($trailingsize($S,$n)))
used_S |= used
end
if used_S && S !== ex.args[1]
S0 = ex.args[1]
ex.args[1] = S
ex = Expr(:let, ex, :($S = $S0))
end
else
# recursive search
for i = eachindex(ex.args)
ex.args[i] = replace_ref_end!(ex.args[i],withex)
ex.args[i], used = replace_ref_end_!(ex.args[i],withex)
used_withex |= used
end
end
end
ex
ex, used_withex
end

"""
Expand All @@ -385,9 +399,15 @@ an assignment (e.g. `@view(A[1,2:end]) = ...`). See also [`@views`](@ref)
to switch an entire block of code to use views for slicing.
"""
macro view(ex)
if isa(ex, Expr) && ex.head == :ref
if Meta.isexpr(ex, :ref)
ex = replace_ref_end!(ex)
Expr(:&&, true, esc(Expr(:call,:(Base.view),ex.args...)))
if Meta.isexpr(ex, :ref)
ex = Expr(:call, view, ex.args...)
else # ex replaced by let ...; foo[...]; end
assert(Meta.isexpr(ex, :let) && Meta.isexpr(ex.args[1], :ref))
ex.args[1] = Expr(:call, view, ex.args[1].args...)
end
Expr(:&&, true, esc(ex))
else
throw(ArgumentError("Invalid use of @view macro: argument must be a reference expression A[...]."))
end
Expand All @@ -404,21 +424,53 @@ end
@propagate_inbounds maybeview(A) = getindex(A)
@propagate_inbounds maybeview(A::AbstractArray) = getindex(A)

# _views implements the transformation for the @views macro.
# @views calls esc(_views(...)) to work around #20241,
# so any function calls we insert (to maybeview, or to
# size and endof in replace_ref_end!) must be interpolated
# as values rather than as symbols to ensure that they are called
# from Base rather than from the caller's scope.
_views(x) = x
_views(x::Symbol) = esc(x)
function _views(ex::Expr)
if ex.head in (:(=), :(.=))
# don't use view on the lhs of an assignment
Expr(ex.head, esc(ex.args[1]), _views(ex.args[2]))
# don't use view for ref on the lhs of an assignment,
# but still use views for the args of the ref:
lhs = ex.args[1]
Expr(ex.head, Meta.isexpr(lhs, :ref) ?
Expr(:ref, _views.(lhs.args)...) : _views(lhs),
_views(ex.args[2]))
elseif ex.head == :ref
ex = replace_ref_end!(ex)
Expr(:call, :maybeview, _views.(ex.args)...)
Expr(:call, maybeview, _views.(ex.args)...)
else
h = string(ex.head)
if last(h) == '='
# don't use view on the lhs of an op-assignment
Expr(first(h) == '.' ? :(.=) : :(=), esc(ex.args[1]),
Expr(:call, esc(Symbol(h[1:end-1])), _views.(ex.args)...))
# don't use view on the lhs of an op-assignment a[i...] += ...
if last(h) == '=' && Meta.isexpr(ex.args[1], :ref)
lhs = ex.args[1]

# temp vars to avoid recomputing a and i,
# which will be assigned in a let block:
a = gensym(:a)
i = [gensym(:i) for k = 1:length(lhs.args)-1]

# for splatted indices like a[i, j...], we need to
# splat the corresponding temp var.
I = similar(i, Any)
for k = 1:length(i)
if Meta.isexpr(lhs.args[k+1], :...)
I[k] = Expr(:..., i[k])
lhs.args[k+1] = lhs.args[k+1].args[1] # unsplat
else
I[k] = i[k]
end
end

Expr(:let,
Expr(first(h) == '.' ? :(.=) : :(=), :($a[$(I...)]),
Expr(:call, Symbol(h[1:end-1]),
:($maybeview($a, $(I...))),
_views.(ex.args[2:end])...)),
:($a = $(_views(lhs.args[1]))),
[:($(i[k]) = $(_views(lhs.args[k+1]))) for k=1:length(i)]...)
else
Expr(ex.head, _views.(ex.args)...)
end
Expand All @@ -439,5 +491,5 @@ that appear explicitly in the given `expression`, not array slicing that
occurs in functions called by that code.
"""
macro views(x)
_views(x)
esc(_views(replace_ref_end!(x)))
end
23 changes: 18 additions & 5 deletions test/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,6 @@ end
@test collect(view(view(reshape(1:13^3, 13, 13, 13), 3:7, 6:6, :), 1:2:5, :, 1:2:5)) ==
cat(3,[68,70,72],[406,408,410],[744,746,748])



# tests @view (and replace_ref_end!)
X = reshape(1:24,2,3,4)
Y = 4:-1:1
Expand All @@ -494,10 +492,16 @@ u = (1,2:3)
@test X[(1,)...,(2,)...,2:end] == @view X[(1,)...,(2,)...,2:end]

# test macro hygiene
let size=(x,y)-> error("should not happen")
let size=(x,y)-> error("should not happen"), Base=nothing
@test X[1:end,2,2] == @view X[1:end,2,2]
end

# test that side effects occur only once
let foo = [X]
@test X[2:end-1] == @view (push!(foo,X)[1])[2:end-1]
@test foo == [X, X]
end

# test @views macro
@views let f!(x) = x[1:end-1] .+= x[2:end].^2
x = [1,2,3,4]
Expand All @@ -512,6 +516,16 @@ end
@test x == [5,6,19,4]
f!(x[3:end])
@test x == [5,6,35,4]
x[Y[2:3]] .= 7:8
@test x == [5,8,7,4]
x[(3,)..., ()...] .+= 3
@test x == [5,8,10,4]
i = Int[]
# test that lhs expressions in update operations are evaluated only once:
x[push!(i,4)[1]] += 5
@test x == [5,8,10,9] && i == [4]
x[push!(i,3)[end]] += 2
@test x == [5,8,12,9] && i == [4,3]
end
@views @test isa(X[1:3], SubArray)
@test X[1:end] == @views X[1:end]
Expand All @@ -523,9 +537,8 @@ end
@test X[1:end,2,Y[2:end]] == @views X[1:end,2,Y[2:end]]
@test X[u...,2:end] == @views X[u...,2:end]
@test X[(1,)...,(2,)...,2:end] == @views X[(1,)...,(2,)...,2:end]

# test macro hygiene
let size=(x,y)-> error("should not happen")
let size=(x,y)-> error("should not happen"), Base=nothing
@test X[1:end,2,2] == @views X[1:end,2,2]
end

Expand Down