Skip to content

Commit

Permalink
refactor cartesian.jl to use dispatch in macros (#24450)
Browse files Browse the repository at this point in the history
add a couple helpful type declarations
  • Loading branch information
JeffBezanson authored Nov 3, 2017
1 parent ae51693 commit d929f0b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 50 deletions.
2 changes: 1 addition & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ similar(a::AbstractArray, ::Type{T}, dims::Dims{N}) where {T,N} = Array{T,N}(

to_shape(::Tuple{}) = ()
to_shape(dims::Dims) = dims
to_shape(dims::DimsOrInds) = map(to_shape, dims)
to_shape(dims::DimsOrInds) = map(to_shape, dims)::DimsOrInds
# each dimension
to_shape(i::Int) = i
to_shape(i::Integer) = Int(i)
Expand Down
2 changes: 1 addition & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ else
end

_array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer))
_array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr))
_array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr))::Array{T}

function collect(itr::Generator)
isz = iteratorsize(itr.iter)
Expand Down
60 changes: 16 additions & 44 deletions base/cartesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,8 @@ julia> @macroexpand Base.Cartesian.@nref 3 A i
:(A[i_1, i_2, i_3])
```
"""
macro nref(N, A, sym)
_nref(N, A, sym)
end

function _nref(N::Int, A::Symbol, ex)
vars = [ inlineanonymous(ex,i) for i = 1:N ]
macro nref(N::Int, A::Symbol, ex)
vars = Any[ inlineanonymous(ex,i) for i = 1:N ]
Expr(:escape, Expr(:ref, A, vars...))
end

Expand All @@ -105,14 +101,10 @@ while `@ncall 2 func a b i->c[i]` yields
func(a, b, c[1], c[2])
"""
macro ncall(N, f, sym...)
_ncall(N, f, sym...)
end

function _ncall(N::Int, f, args...)
macro ncall(N::Int, f, args...)
pre = args[1:end-1]
ex = args[end]
vars = [ inlineanonymous(ex,i) for i = 1:N ]
vars = Any[ inlineanonymous(ex,i) for i = 1:N ]
Expr(:escape, Expr(:call, f, pre..., vars...))
end

Expand All @@ -132,12 +124,8 @@ quote
end
```
"""
macro nexprs(N, ex)
_nexprs(N, ex)
end

function _nexprs(N::Int, ex::Expr)
exs = [ inlineanonymous(ex,i) for i = 1:N ]
macro nexprs(N::Int, ex::Expr)
exs = Any[ inlineanonymous(ex,i) for i = 1:N ]
Expr(:escape, Expr(:block, exs...))
end

Expand All @@ -159,17 +147,13 @@ while `@nextract 3 x d->y[2d-1]` yields
x_3 = y[5]
"""
macro nextract(N, esym, isym)
_nextract(N, esym, isym)
end

function _nextract(N::Int, esym::Symbol, isym::Symbol)
aexprs = [Expr(:escape, Expr(:(=), inlineanonymous(esym, i), :(($isym)[$i]))) for i = 1:N]
macro nextract(N::Int, esym::Symbol, isym::Symbol)
aexprs = Any[ Expr(:escape, Expr(:(=), inlineanonymous(esym, i), :(($isym)[$i]))) for i = 1:N ]
Expr(:block, aexprs...)
end

function _nextract(N::Int, esym::Symbol, ex::Expr)
aexprs = [Expr(:escape, Expr(:(=), inlineanonymous(esym, i), inlineanonymous(ex,i))) for i = 1:N]
macro nextract(N::Int, esym::Symbol, ex::Expr)
aexprs = Any[ Expr(:escape, Expr(:(=), inlineanonymous(esym, i), inlineanonymous(ex,i))) for i = 1:N ]
Expr(:block, aexprs...)
end

Expand All @@ -182,15 +166,11 @@ evaluate to `true`.
`@nall 3 d->(i_d > 1)` would generate the expression `(i_1 > 1 && i_2 > 1 && i_3 > 1)`. This
can be convenient for bounds-checking.
"""
macro nall(N, criterion)
_nall(N, criterion)
end

function _nall(N::Int, criterion::Expr)
macro nall(N::Int, criterion::Expr)
if criterion.head != :->
throw(ArgumentError("second argument must be an anonymous function expression yielding the criterion"))
end
conds = [Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N]
conds = Any[ Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N ]
Expr(:&&, conds...)
end

Expand All @@ -202,15 +182,11 @@ evaluate to `true`.
`@nany 3 d->(i_d > 1)` would generate the expression `(i_1 > 1 || i_2 > 1 || i_3 > 1)`.
"""
macro nany(N, criterion)
_nany(N, criterion)
end

function _nany(N::Int, criterion::Expr)
macro nany(N::Int, criterion::Expr)
if criterion.head != :->
error("Second argument must be an anonymous function expression yielding the criterion")
end
conds = [Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N]
conds = Any[ Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N ]
Expr(:||, conds...)
end

Expand All @@ -220,12 +196,8 @@ end
Generates an `N`-tuple. `@ntuple 2 i` would generate `(i_1, i_2)`, and `@ntuple 2 k->k+1`
would generate `(2,3)`.
"""
macro ntuple(N, ex)
_ntuple(N, ex)
end

function _ntuple(N::Int, ex)
vars = [ inlineanonymous(ex,i) for i = 1:N ]
macro ntuple(N::Int, ex)
vars = Any[ inlineanonymous(ex,i) for i = 1:N ]
Expr(:escape, Expr(:tuple, vars...))
end

Expand Down
4 changes: 0 additions & 4 deletions base/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,6 @@ precompile(Tuple{typeof(Base.length), Tuple{DataType, DataType}})
precompile(Tuple{Type{BoundsError}, Array{Int64, 2}, Tuple{Base.UnitRange{Int64}, Int64}})
precompile(Tuple{typeof(Base.throw_boundserror), Array{Int64, 2}, Tuple{Base.UnitRange{Int64}, Int64}})
precompile(Tuple{getfield(Base.Cartesian, Symbol("#@nexprs")), Int64, Expr})
precompile(Tuple{typeof(Base.Cartesian._nexprs), Int64, Expr})
precompile(Tuple{typeof(Core.Inference.builtin_tfunction), typeof(===), Array{Any, 1}, Core.Inference.InferenceState, Core.Inference.InferenceParams})
precompile(Tuple{typeof(Core.Inference.typeinf_frame), Core.MethodInstance, Bool, Bool, Core.Inference.InferenceParams})
precompile(Tuple{typeof(Core.Inference.typeinf), Core.Inference.InferenceState})
Expand All @@ -837,14 +836,11 @@ precompile(Tuple{typeof(Base.Cartesian.lreplace!), String, Base.Cartesian.LRepla
precompile(Tuple{typeof(Base.Cartesian.exprresolve), Expr})
precompile(Tuple{Type{BoundsError}, Array{Expr, 1}, Base.UnitRange{Int64}})
precompile(Tuple{getfield(Base.Cartesian, Symbol("#@ncall")), Int64, Symbol, Symbol})
precompile(Tuple{typeof(Base.Cartesian._ncall), Int64, Symbol, Symbol})
precompile(Tuple{typeof(Base.getindex), Tuple{Symbol}, Base.UnitRange{Int64}})
precompile(Tuple{getfield(Base.Cartesian, Symbol("#@ncall")), Int64, Symbol, Symbol, Expr})
precompile(Tuple{typeof(Base.Cartesian._ncall), Int64, Symbol, Symbol, Expr})
precompile(Tuple{typeof(Base.endof), Tuple{Symbol, Expr}})
precompile(Tuple{typeof(Base.getindex), Tuple{Symbol, Expr}, Base.UnitRange{Int64}})
precompile(Tuple{getfield(Base.Cartesian, Symbol("#@nloops")), Int64, Symbol, Expr, Expr})
precompile(Tuple{typeof(Base.Cartesian._nloops), Int64, Symbol, Expr, Expr})
precompile(Tuple{typeof(Base.endof), Tuple{Expr}})
precompile(Tuple{typeof(Base.endof), Tuple{Symbol, Symbol, Symbol}})
precompile(Tuple{typeof(Base.getindex), Tuple{Symbol, Symbol, Symbol}, Base.UnitRange{Int64}})
Expand Down

0 comments on commit d929f0b

Please sign in to comment.