Skip to content

Commit

Permalink
Merge pull request #5671 from JuliaLang/teh/cartesian
Browse files Browse the repository at this point in the history
More Cartesian housekeeping, plus implementation of copy!
  • Loading branch information
timholy committed Feb 8, 2014
2 parents 43e8d16 + 36ddd12 commit 836f0ae
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 91 deletions.
4 changes: 2 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ broadcast!_function(f::Function) = (B, As...) -> broadcast!(f, B, As...)
broadcast_function(f::Function) = (As...) -> broadcast(f, As...)

broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(Array(eltype(src), broadcast_shape(I...)), src, I...)
@ngenerate N function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::NTuple{N, AbstractArray}...)
@ngenerate N typeof(dest) function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::NTuple{N, AbstractArray}...)
check_broadcast_shape(size(dest), I...) # unnecessary if this function is never called directly
checkbounds(src, I...)
@nloops N i dest d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
Expand All @@ -119,7 +119,7 @@ broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex
dest
end

@ngenerate N function broadcast_setindex!(A::AbstractArray, x, I::NTuple{N, AbstractArray}...)
@ngenerate N typeof(A) function broadcast_setindex!(A::AbstractArray, x, I::NTuple{N, AbstractArray}...)
checkbounds(A, I...)
shape = broadcast_shape(I...)
@nextract N shape d->(length(shape) < d ? 1 : shape[d])
Expand Down
70 changes: 54 additions & 16 deletions base/cartesian.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
module Cartesian

export @ngenerate, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, ngenerate
export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, ngenerate

const CARTESIAN_DIMS = 4

### @ngenerate, for auto-generation of separate versions of functions for different dimensionalities
# Examples (deliberately trivial):
# @ngenerate N myndims{T,N}(A::Array{T,N}) = N
# @ngenerate N returntype myndims{T,N}(A::Array{T,N}) = N
# or alternatively
# function gen_body(N::Int)
# quote
# return $N
# end
# end
# eval(ngenerate(:N, :(myndims{T,N}(A::Array{T,N})), gen_body))
# eval(ngenerate(:N, returntypeexpr, :(myndims{T,N}(A::Array{T,N})), gen_body))
# The latter allows you to use a single gen_body function for both ngenerate and
# when your function maintains its own method cache (e.g., reduction or broadcasting).
#
# Special syntax for function prototypes:
# @ngenerate N function myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# @ngenerate N returntype function myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# for N = 3 translates to
# function myfunction(A::AbstractArray, I_1::Int, I_2::Int, I_3::Int)
# and for the generic (cached) case as
Expand All @@ -27,19 +29,52 @@ export @ngenerate, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, n
# To avoid ambiguity, it would be preferable to have some specific syntax for this, such as
# myfunction(A::AbstractArray, I::Int...N)
# where N can be an integer or symbol. Currently T...N generates a parser error.

const CARTESIAN_DIMS = 2 # FIXME: increase after testing is complete

macro ngenerate(itersym, funcexpr)
macro ngenerate(itersym, returntypeexpr, funcexpr)
isfuncexpr(funcexpr) || error("Requires a function expression")
esc(ngenerate(itersym, funcexpr.args[1], N->sreplace!(copy(funcexpr.args[2]), itersym, N)))
esc(ngenerate(itersym, returntypeexpr, funcexpr.args[1], N->sreplace!(copy(funcexpr.args[2]), itersym, N)))
end

# @nsplat takes an expression like
# @nsplat N 2:3 myfunction(A, I::NTuple{N,Real}...) = getindex(A, I...)
# and generates
# myfunction(A, I_1::Real, I_2::Real) = getindex(A, I_1, I_2)
# myfunction(A, I_1::Real, I_2::Real, I_3::Real) = getindex(A, I_1, I_2, I_3)
# myfunction(A, I::Real...) = getindex(A, I...)
# An @nsplat function _cannot_ have any other Cartesian macros in it.
# If you omit the range, it uses 1:CARTESIAN_DIMS.
macro nsplat(itersym, args...)
local rng
if length(args) == 1
rng = 1:CARTESIAN_DIMS
funcexpr = args[1]
elseif length(args) == 2
rangeexpr = args[1]
funcexpr = args[2]
if !isa(rangeexpr, Expr) || rangeexpr.head != :(:) || length(rangeexpr.args) != 2
error("First argument must be a from:to expression")
end
rng = rangeexpr.args[1]:rangeexpr.args[2]
else
error("Wrong number of arguments")
end
isfuncexpr(funcexpr) || error("Second argument must be a function expression")
prototype = funcexpr.args[1]
body = funcexpr.args[2]
varname, T = get_splatinfo(prototype, itersym)
isempty(varname) && error("Last argument must be a splat")
explicit = [Expr(:function, resolvesplat!(copy(prototype), varname, T, N),
resolvesplats!(copy(body), varname, N)) for N in rng]
protosplat = resolvesplat!(copy(prototype), varname, T, 0)
protosplat.args[end] = Expr(:..., protosplat.args[end])
splat = Expr(:function, protosplat, body)
esc(Expr(:block, explicit..., splat))
end

generate1(itersym, prototype, bodyfunc, N::Int, varname, T) =
Expr(:function, spliceint!(sreplace!(resolvesplat!(copy(prototype), varname, T, N), itersym, N)),
resolvesplats!(bodyfunc(N), varname, N))

function ngenerate(itersym, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecached::Bool = true)
function ngenerate(itersym, returntypeexpr, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecached::Bool = true)
varname, T = get_splatinfo(prototype, itersym)
# Generate versions for specific dimensions
fdim = [generate1(itersym, prototype, bodyfunc, N, varname, T) for N in dims]
Expand Down Expand Up @@ -71,7 +106,7 @@ function ngenerate(itersym, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecach
_F_
end)
end
$(dictname)[$itersym]($(fargs...))
($(dictname)[$itersym]($(fargs...)))::$returntypeexpr
end)
Expr(:block, fdim..., quote
let $dictname = Dict{Int,Function}()
Expand Down Expand Up @@ -117,7 +152,8 @@ end
# Replace splatted with desplatted for a specific number of arguments
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr), N::Int)
if !isempty(varname)
prototype.args[end] = Expr(:(::), symbol(string(varname, "_1")), T)
prototype.args[end] = N > 0 ? Expr(:(::), symbol(string(varname, "_1")), T) :
Expr(:(::), symbol(varname), T)
for i = 2:N
push!(prototype.args, Expr(:(::), symbol(string(varname, "_", i)), T))
end
Expand Down Expand Up @@ -274,13 +310,15 @@ function _nref(N::Int, A::Symbol, ex)
end

# Generate f(arg1, arg2, ...)
macro ncall(N, f, sym)
_ncall(N, f, sym)
macro ncall(N, f, sym...)
_ncall(N, f, sym...)
end

function _ncall(N::Int, f, ex)
function _ncall(N::Int, f, args...)
pre = args[1:end-1]
ex = args[end]
vars = [ inlineanonymous(ex,i) for i = 1:N ]
Expr(:escape, Expr(:call, f, vars...))
Expr(:escape, Expr(:call, f, pre..., vars...))
end

# Generate N expressions
Expand Down
94 changes: 42 additions & 52 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
### From array.jl

@ngenerate N function _checksize(A::AbstractArray, I::NTuple{N, Any}...)
@ngenerate N Nothing function checksize(A::AbstractArray, I::NTuple{N, Any}...)
@nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("Index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))")))
nothing
end
checksize(A, I) = (_checksize(A, I); return nothing)
checksize(A, I, J) = (_checksize(A, I, J); return nothing)
checksize(A, I...) = (_checksize(A, I...); return nothing)

unsafe_getindex(v::Real, ind::Integer) = v
unsafe_getindex(v::Ranges, ind::Integer) = first(v) + (ind-1)*step(v)
unsafe_getindex(v::AbstractArray, ind::Integer) = v[ind]

# Version that uses cartesian indexing for src
@ngenerate N function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Int,AbstractVector)}...)
@ngenerate N typeof(dest) function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Int,AbstractVector)}...)
checksize(dest, I...)
k = 1
@nloops N i dest d->(@inbounds j_d = unsafe_getindex(I_d, i_d)) begin
@inbounds dest[k] = (@nref N src j)
k += 1
end
dest
end

# Version that uses linear indexing for src
@ngenerate N function _getindex!(dest::Array, src::Array, I::NTuple{N,Union(Int,AbstractVector)}...)
@ngenerate N typeof(dest) function _getindex!(dest::Array, src::Array, I::NTuple{N,Union(Int,AbstractVector)}...)
checksize(dest, I...)
stride_1 = 1
@nexprs N d->(stride_{d+1} = stride_d*size(src,d))
Expand All @@ -33,31 +31,29 @@ end
@inbounds dest[k] = src[offset_0]
k += 1
end
dest
end

getindex!(dest, src, I) = (checkbounds(src, I); _getindex!(dest, src, to_index(I)); return dest)
getindex!(dest, src, I, J) = (checkbounds(src, I, J); _getindex!(dest, src, to_index(I), to_index(J)); return dest)
getindex!(dest, src, I...) = (checkbounds(src, I...); _getindex!(dest, src, to_index(I)...); return dest)
# It's most efficient to call checkbounds first, then to_index, and finally
# allocate the output. Hence the different variants.
_getindex(A, I::(Union(Int,AbstractVector)...)) =
_getindex!(similar(A, index_shape(I...)), A, I...)

getindex(A::Array, I::Union(Real,AbstractVector)) = getindex!(similar(A, index_shape(I)), A, I)
function getindex(A::Array, I::Union(Real,AbstractVector)...)
@nsplat N function getindex(A::Array, I::NTuple{N,Union(Real,AbstractVector)}...)
checkbounds(A, I...)
Ii = to_index(I)
dest = similar(A, index_shape(Ii...))
_getindex!(dest, A, Ii...)
dest
_getindex(A, to_index(I...))
end
# Version of the above for 2d without the splats
function getindex(A::Array, I::Union(Real,AbstractVector), J::Union(Real,AbstractVector))
checkbounds(A, I, J)
Ii, Ji = to_index(I), to_index(J)
dest = similar(A, index_shape(Ii,Ji))
_getindex!(dest, A, Ii, Ji)
dest

# Also a safe version of getindex!
@nsplat N function getindex!(dest, src, I::NTuple{N,Union(Real,AbstractVector)}...)
checkbounds(src, I...)
_getindex!(dest, src, to_index(I...)...)
end


@ngenerate N function _setindex!(A::Array, x, I::NTuple{N,Union(Int,AbstractArray)}...)
@ngenerate N typeof(A) function setindex!(A::Array, x, J::NTuple{N,Union(Real,AbstractArray)}...)
@ncall N checkbounds A J
@nexprs N d->(I_d = to_index(J_d))
stride_1 = 1
@nexprs N d->(stride_{d+1} = stride_d*size(A,d))
@nexprs N d->(offset_d = 1) # really only need offset_$N = 1
Expand All @@ -67,34 +63,19 @@ end
end
else
X = x
setindex_shape_check(X, I...)
@ncall N setindex_shape_check X I
# TODO? A variant that can use cartesian indexing for RHS
k = 1
@nloops N i d->(1:length(I_d)) d->(@inbounds offset_{d-1} = offset_d + (unsafe_getindex(I_d, i_d)-1)*stride_d) begin
@inbounds A[offset_0] = X[k]
k += 1
end
end
end

function setindex!(A::Array, x, I::Union(Real,AbstractArray), J::Union(Real,AbstractArray))
checkbounds(A, I, J)
_setindex!(A, x, to_index(I), to_index(J))
A
end
function setindex!(A::Array, x, I::Union(Real,AbstractArray))
checkbounds(A, I)
_setindex!(A, x, to_index(I))
A
end
function setindex!(A::Array, x, I::Union(Real,AbstractArray)...)
checkbounds(A, I...)
_setindex!(A, x, to_index(I)...)
A
end


@ngenerate N function findn{T,N}(A::AbstractArray{T,N})
@ngenerate N NTuple{N,Vector{Int}} function findn{T,N}(A::AbstractArray{T,N})
nnzA = countnz(A)
@nexprs N d->(I_d = Array(Int, nnzA))
k = 1
Expand Down Expand Up @@ -127,7 +108,7 @@ function gen_getindex_body(N::Int)
end
end

eval(ngenerate(:N, :(getindex{T}(s::SubArray{T,N}, ind::Integer)), gen_getindex_body, 2:5, false))
eval(ngenerate(:N, nothing, :(getindex{T}(s::SubArray{T,N}, ind::Integer)), gen_getindex_body, 2:5, false))


function gen_setindex!_body(N::Int)
Expand All @@ -145,19 +126,28 @@ function gen_setindex!_body(N::Int)
end
end

eval(ngenerate(:N, :(setindex!{T}(s::SubArray{T,N}, v, ind::Integer)), gen_setindex!_body, 2:5, false))
eval(ngenerate(:N, nothing, :(setindex!{T}(s::SubArray{T,N}, v, ind::Integer)), gen_setindex!_body, 2:5, false))


### from abstractarray.jl

@ngenerate N function _fill!{T,N}(A::AbstractArray{T,N}, x)
@ngenerate N typeof(A) function fill!{T,N}(A::AbstractArray{T,N}, x)
@nloops N i A begin
@inbounds (@nref N A i) = x
end
A
end

fill!(A::AbstractArray, x) = (_fill!(A, x); return A)

@ngenerate N typeof(dest) function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
if @nall N d->(size(dest,d) == size(src,d))
@nloops N i dest begin
@inbounds (@nref N dest i) = (@nref N src i)
end
else
invoke(copy!, (typeof(dest), Any), dest, src)
end
dest
end

### from bitarray.jl

Expand All @@ -173,7 +163,7 @@ function getindex(B::BitArray, I0::Range1)
end

# TODO: extend to I:Union(Real,AbstractArray)... (i.e. not necessarily contiguous)
@ngenerate N function getindex(B::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...)
@ngenerate N BitArray{length(I)+1} function getindex(B::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...)
ndims(B) < N+1 && error("wrong number of dimensions")
checkbounds(B, I0, I...)
X = BitArray(index_shape(I0, I...))
Expand Down Expand Up @@ -207,7 +197,7 @@ end
return X
end

@ngenerate N function getindex(B::BitArray, I::NTuple{N,Union(Real,AbstractVector)}...)
@ngenerate N BitArray{length(I)} function getindex(B::BitArray, I::NTuple{N,Union(Real,AbstractVector)}...)
checkbounds(B, I...)
@nexprs N d->(I_d = to_index(I_d))
X = BitArray(index_shape(I...))
Expand Down Expand Up @@ -237,7 +227,7 @@ function setindex!(B::BitArray, X::BitArray, I0::Range1)
end

# TODO: extend to I:Union(Real,AbstractArray)... (i.e. not necessarily contiguous)
@ngenerate N function setindex!(B::BitArray, X::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...)
@ngenerate N typeof(B) function setindex!(B::BitArray, X::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...)
ndims(B) != N+1 && error("wrong number of dimensions in assigment")
I0 = to_index(I0)
lI = length(I0)
Expand Down Expand Up @@ -275,7 +265,7 @@ end
return B
end

@ngenerate N function setindex!(B::BitArray, X::AbstractArray, I::NTuple{N,Union(Real,AbstractArray)}...)
@ngenerate N typeof(B) function setindex!(B::BitArray, X::AbstractArray, I::NTuple{N,Union(Real,AbstractArray)}...)
checkbounds(B, I...)
@nexprs N d->(I_d = to_index(I_d))
nel = 1
Expand All @@ -294,7 +284,7 @@ end
return B
end

@ngenerate N function setindex!(B::BitArray, x, I::NTuple{N,Union(Real,AbstractArray)}...)
@ngenerate N typeof(B) function setindex!(B::BitArray, x, I::NTuple{N,Union(Real,AbstractArray)}...)
x = convert(Bool, x)
checkbounds(B, I...)
@nexprs N d->(I_d = to_index(I_d))
Expand All @@ -305,7 +295,7 @@ end
return B
end

@ngenerate N function findn{N}(B::BitArray{N})
@ngenerate N NTuple{N,Vector{Int}} function findn{N}(B::BitArray{N})
nnzB = countnz(B)
I = ntuple(N, x->Array(Int, nnzB))
if nnzB > 0
Expand All @@ -322,7 +312,7 @@ end

for (V, PT, BT) in [((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
@eval begin
@ngenerate N function permutedims!{$(V...)}(P::$PT{$(V...)}, B::$BT{$(V...)}, perm)
@ngenerate N typeof(P) function permutedims!{$(V...)}(P::$PT{$(V...)}, B::$BT{$(V...)}, perm)
dimsB = size(B)
(length(perm) == N && isperm(perm)) || error("no valid permutation of dimensions")
dimsP = size(P)
Expand Down
Loading

0 comments on commit 836f0ae

Please sign in to comment.