Skip to content

Commit

Permalink
Merge pull request #15973 from JuliaLang/teh/more_reshape
Browse files Browse the repository at this point in the history
reshape: helpful error message, more tests
  • Loading branch information
timholy committed Apr 21, 2016
2 parents 17dfdc1 + 64745d7 commit 9b029a3
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 54 deletions.
80 changes: 61 additions & 19 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@ ReshapedArray{T,N}(parent::AbstractArray{T}, dims::NTuple{N,Int}, mi) = Reshaped
typealias ReshapedArrayLF{T,N,P<:AbstractArray} ReshapedArray{T,N,P,Tuple{}}

# Fast iteration on ReshapedArrays: use the parent iterator
immutable ReshapedRange{I,M}
immutable ReshapedArrayIterator{I,M}
iter::I
mi::NTuple{M,SignedMultiplicativeInverse{Int}}
end
ReshapedRange(A::ReshapedArray) = reshapedrange(parent(A), A.mi)
function reshapedrange{M}(P, mi::NTuple{M})
ReshapedArrayIterator(A::ReshapedArray) = _rs_iterator(parent(A), A.mi)
function _rs_iterator{M}(P, mi::NTuple{M})
iter = eachindex(P)
ReshapedRange{typeof(iter),M}(iter, mi)
ReshapedArrayIterator{typeof(iter),M}(iter, mi)
end

immutable ReshapedIndex{T}
parentindex::T
end

# eachindex(A::ReshapedArray) = ReshapedRange(A) # TODO: uncomment this line
start(R::ReshapedRange) = start(R.iter)
@inline done(R::ReshapedRange, i) = done(R.iter, i)
@inline function next(R::ReshapedRange, i)
# eachindex(A::ReshapedArray) = ReshapedArrayIterator(A) # TODO: uncomment this line
start(R::ReshapedArrayIterator) = start(R.iter)
@inline done(R::ReshapedArrayIterator, i) = done(R.iter, i)
@inline function next(R::ReshapedArrayIterator, i)
item, inext = next(R.iter, i)
ReshapedIndex(item), inext
end
length(R::ReshapedRange) = length(R.iter)
length(R::ReshapedArrayIterator) = length(R.iter)

function reshape(parent::AbstractArray, dims::Dims)
prod(dims) == length(parent) || throw(DimensionMismatch("parent has $(length(parent)) elements, which is incompatible with size $dims"))
Expand Down Expand Up @@ -84,19 +84,61 @@ reinterpret{T}(::Type{T}, A::ReshapedArray, dims::Dims) = reinterpret(T, parent(
ind2sub_rs((d+1, out...), tail(strds), r)
end

@inline getindex(A::ReshapedArrayLF, index::Int) = (@boundscheck checkbounds(A, index); @inbounds ret = parent(A)[index]; ret)
@inline getindex(A::ReshapedArray, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_getindex(A, indexes...))
@inline getindex(A::ReshapedArray, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds ret = parent(A)[index.parentindex]; ret)
@inline function getindex(A::ReshapedArrayLF, index::Int)
@boundscheck checkbounds(A, index)
@inbounds ret = parent(A)[index]
ret
end
@inline function getindex(A::ReshapedArray, indexes::Int...)
@boundscheck checkbounds(A, indexes...)
_unsafe_getindex(A, indexes...)
end
@inline function getindex(A::ReshapedArray, index::ReshapedIndex)
@boundscheck checkbounds(parent(A), index.parentindex)
@inbounds ret = parent(A)[index.parentindex]
ret
end

@inline function _unsafe_getindex(A::ReshapedArray, indexes::Int...)
@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]
ret
end
@inline function _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...)
@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]
ret
end

@inline _unsafe_getindex(A::ReshapedArray, indexes::Int...) = (@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]; ret)
@inline _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...) = (@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]; ret)
@inline function setindex!(A::ReshapedArrayLF, val, index::Int)
@boundscheck checkbounds(A, index)
@inbounds parent(A)[index] = val
val
end
@inline function setindex!(A::ReshapedArray, val, indexes::Int...)
@boundscheck checkbounds(A, indexes...)
_unsafe_setindex!(A, val, indexes...)
end
@inline function setindex!(A::ReshapedArray, val, index::ReshapedIndex)
@boundscheck checkbounds(parent(A), index.parentindex)
@inbounds parent(A)[index.parentindex] = val
val
end

@inline function _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...)
@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val
val
end
@inline function _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...)
@inbounds parent(A)[sub2ind(size(A), indexes...)] = val
val
end

@inline setindex!(A::ReshapedArrayLF, val, index::Int) = (@boundscheck checkbounds(A, index); @inbounds parent(A)[index] = val; val)
@inline setindex!(A::ReshapedArray, val, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_setindex!(A, val, indexes...))
@inline setindex!(A::ReshapedArray, val, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds parent(A)[index.parentindex] = val; val)
# helpful error message for a common failure case
typealias ReshapedRange{T,N,A<:Range} ReshapedArray{T,N,A,Tuple{}}
setindex!(A::ReshapedRange, val, index::Int) = _rs_setindex!_err()
setindex!(A::ReshapedRange, val, indexes::Int...) = _rs_setindex!_err()
setindex!(A::ReshapedRange, val, index::ReshapedIndex) = _rs_setindex!_err()

@inline _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...) = (@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val; val)
@inline _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...) = (@inbounds parent(A)[sub2ind(size(A), indexes...)] = val; val)
_rs_setindex!_err() = error("indexed assignment fails for a reshaped range; consider calling collect")

typealias ArrayT{N, T} Array{T,N}
convert{T,S,N}(::Type{Array{T,N}}, V::ReshapedArray{S,N}) = copy!(Array(T, size(V)), V)
Expand Down
41 changes: 39 additions & 2 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,50 @@ a = reshape(b, (2, 2, 2, 2, 2))
@test a[2,2,2,2,2] == b[end]

# reshaping linearslow arrays
a = zeros(1, 5)
a = collect(reshape(1:5, 1, 5))
s = sub(a, :, [2,3,5])
@test length(reshape(s, length(s))) == 3
r = reshape(s, length(s))
@test length(r) == 3
@test r[1] == 2
@test r[3,1] == 5
@test r[Base.ReshapedIndex(CartesianIndex((1,2)))] == 3
@test parent(reshape(r, (1,3))) === r.parent === s
@test parentindexes(r) == (1:1, 1:3)
@test reshape(r, (3,)) === r
r[2] = -1
@test a[3] == -1
a = zeros(0, 5) # an empty linearslow array
s = sub(a, :, [2,3,5])
@test length(reshape(s, length(s))) == 0

@test reshape(1:5, (5,)) === 1:5
@test reshape(1:5, 5) === 1:5

# setindex! on a reshaped range
a = reshape(1:20, 5, 4)
for idx in ((3,), (2,2), (Base.ReshapedIndex(1),))
try
a[idx...] = 7
catch err
@test err.msg == "indexed assignment fails for a reshaped range; consider calling collect"
end
end

# operations with LinearFast ReshapedArray
b = collect(1:12)
a = Base.ReshapedArray(b, (4,3), ())
@test a[3,2] == 7
@test a[6] == 6
a[3,2] = -2
a[6] = -3
a[Base.ReshapedIndex(5)] = -4
@test b[5] == -4
@test b[6] == -3
@test b[7] == -2
b = reinterpret(Int, a, (3,4))
b[1] = -1
@test vec(b) == vec(a)

a = rand(1, 1, 8, 8, 1)
@test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1))
@test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8))
Expand Down
67 changes: 34 additions & 33 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,22 +324,21 @@ t1 = bitrand(n1, n2)
b2 = bitrand(countnz(t1))
@check_bit_operation setindex!(b1, b2, t1) BitMatrix

m1 = rand(1:n1)
m2 = rand(1:n2)

t1 = bitrand(n1)
b2 = bitrand(countnz(t1), m2)
k2 = randperm(m2)
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix

t2 = bitrand(n2)
b2 = bitrand(m1, countnz(t2))
k1 = randperm(m1)
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
let m1 = rand(1:n1), m2 = rand(1:n2)
t1 = bitrand(n1)
b2 = bitrand(countnz(t1), m2)
k2 = randperm(m2)
@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix
@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix

t2 = bitrand(n2)
b2 = bitrand(m1, countnz(t2))
k1 = randperm(m1)
@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix
@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix
end

timesofar("indexing")

Expand Down Expand Up @@ -1054,23 +1053,25 @@ end

## Reductions ##

b1 = bitrand(s1, s2, s3, s4)
m1 = 1
m2 = 3
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}

@check_bit_operation maximum(b1) Bool
@check_bit_operation minimum(b1) Bool
@check_bit_operation any(b1) Bool
@check_bit_operation all(b1) Bool
@check_bit_operation sum(b1) Int

b0 = falses(0)
@check_bit_operation any(b0) Bool
@check_bit_operation all(b0) Bool
@check_bit_operation sum(b0) Int
let
b1 = bitrand(s1, s2, s3, s4)
m1 = 1
m2 = 3
@check_bit_operation maximum(b1, (m1, m2)) BitArray{4}
@check_bit_operation minimum(b1, (m1, m2)) BitArray{4}
@check_bit_operation sum(b1, (m1, m2)) Array{Int,4}

@check_bit_operation maximum(b1) Bool
@check_bit_operation minimum(b1) Bool
@check_bit_operation any(b1) Bool
@check_bit_operation all(b1) Bool
@check_bit_operation sum(b1) Int

b0 = falses(0)
@check_bit_operation any(b0) Bool
@check_bit_operation all(b0) Bool
@check_bit_operation sum(b0) Int
end

timesofar("reductions")

Expand Down

0 comments on commit 9b029a3

Please sign in to comment.