diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index 52a8fe755e0ae..466e26953bd56 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -11,28 +11,28 @@ ReshapedArray{T,N}(parent::AbstractArray{T}, dims::NTuple{N,Int}, mi) = Reshaped typealias ReshapedArrayLF{T,N,P<:AbstractArray} ReshapedArray{T,N,P,Tuple{}} # Fast iteration on ReshapedArrays: use the parent iterator -immutable ReshapedRange{I,M} +immutable ReshapedArrayIterator{I,M} iter::I mi::NTuple{M,SignedMultiplicativeInverse{Int}} end -ReshapedRange(A::ReshapedArray) = reshapedrange(parent(A), A.mi) -function reshapedrange{M}(P, mi::NTuple{M}) +ReshapedArrayIterator(A::ReshapedArray) = _rs_iterator(parent(A), A.mi) +function _rs_iterator{M}(P, mi::NTuple{M}) iter = eachindex(P) - ReshapedRange{typeof(iter),M}(iter, mi) + ReshapedArrayIterator{typeof(iter),M}(iter, mi) end immutable ReshapedIndex{T} parentindex::T end -# eachindex(A::ReshapedArray) = ReshapedRange(A) # TODO: uncomment this line -start(R::ReshapedRange) = start(R.iter) -@inline done(R::ReshapedRange, i) = done(R.iter, i) -@inline function next(R::ReshapedRange, i) +# eachindex(A::ReshapedArray) = ReshapedArrayIterator(A) # TODO: uncomment this line +start(R::ReshapedArrayIterator) = start(R.iter) +@inline done(R::ReshapedArrayIterator, i) = done(R.iter, i) +@inline function next(R::ReshapedArrayIterator, i) item, inext = next(R.iter, i) ReshapedIndex(item), inext end -length(R::ReshapedRange) = length(R.iter) +length(R::ReshapedArrayIterator) = length(R.iter) function reshape(parent::AbstractArray, dims::Dims) prod(dims) == length(parent) || throw(DimensionMismatch("parent has $(length(parent)) elements, which is incompatible with size $dims")) @@ -84,19 +84,61 @@ reinterpret{T}(::Type{T}, A::ReshapedArray, dims::Dims) = reinterpret(T, parent( ind2sub_rs((d+1, out...), tail(strds), r) end -@inline getindex(A::ReshapedArrayLF, index::Int) = (@boundscheck checkbounds(A, index); @inbounds ret = parent(A)[index]; ret) -@inline getindex(A::ReshapedArray, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_getindex(A, indexes...)) -@inline getindex(A::ReshapedArray, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds ret = parent(A)[index.parentindex]; ret) +@inline function getindex(A::ReshapedArrayLF, index::Int) + @boundscheck checkbounds(A, index) + @inbounds ret = parent(A)[index] + ret +end +@inline function getindex(A::ReshapedArray, indexes::Int...) + @boundscheck checkbounds(A, indexes...) + _unsafe_getindex(A, indexes...) +end +@inline function getindex(A::ReshapedArray, index::ReshapedIndex) + @boundscheck checkbounds(parent(A), index.parentindex) + @inbounds ret = parent(A)[index.parentindex] + ret +end + +@inline function _unsafe_getindex(A::ReshapedArray, indexes::Int...) + @inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] + ret +end +@inline function _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...) + @inbounds ret = parent(A)[sub2ind(size(A), indexes...)] + ret +end -@inline _unsafe_getindex(A::ReshapedArray, indexes::Int...) = (@inbounds ret = parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...]; ret) -@inline _unsafe_getindex(A::ReshapedArrayLF, indexes::Int...) = (@inbounds ret = parent(A)[sub2ind(size(A), indexes...)]; ret) +@inline function setindex!(A::ReshapedArrayLF, val, index::Int) + @boundscheck checkbounds(A, index) + @inbounds parent(A)[index] = val + val +end +@inline function setindex!(A::ReshapedArray, val, indexes::Int...) + @boundscheck checkbounds(A, indexes...) + _unsafe_setindex!(A, val, indexes...) +end +@inline function setindex!(A::ReshapedArray, val, index::ReshapedIndex) + @boundscheck checkbounds(parent(A), index.parentindex) + @inbounds parent(A)[index.parentindex] = val + val +end + +@inline function _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...) + @inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val + val +end +@inline function _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...) + @inbounds parent(A)[sub2ind(size(A), indexes...)] = val + val +end -@inline setindex!(A::ReshapedArrayLF, val, index::Int) = (@boundscheck checkbounds(A, index); @inbounds parent(A)[index] = val; val) -@inline setindex!(A::ReshapedArray, val, indexes::Int...) = (@boundscheck checkbounds(A, indexes...); _unsafe_setindex!(A, val, indexes...)) -@inline setindex!(A::ReshapedArray, val, index::ReshapedIndex) = (@boundscheck checkbounds(parent(A), index.parentindex); @inbounds parent(A)[index.parentindex] = val; val) +# helpful error message for a common failure case +typealias ReshapedRange{T,N,A<:Range} ReshapedArray{T,N,A,Tuple{}} +setindex!(A::ReshapedRange, val, index::Int) = _rs_setindex!_err() +setindex!(A::ReshapedRange, val, indexes::Int...) = _rs_setindex!_err() +setindex!(A::ReshapedRange, val, index::ReshapedIndex) = _rs_setindex!_err() -@inline _unsafe_setindex!(A::ReshapedArray, val, indexes::Int...) = (@inbounds parent(A)[ind2sub_rs(A.mi, sub2ind(size(A), indexes...))...] = val; val) -@inline _unsafe_setindex!(A::ReshapedArrayLF, val, indexes::Int...) = (@inbounds parent(A)[sub2ind(size(A), indexes...)] = val; val) +_rs_setindex!_err() = error("indexed assignment fails for a reshaped range; consider calling collect") typealias ArrayT{N, T} Array{T,N} convert{T,S,N}(::Type{Array{T,N}}, V::ReshapedArray{S,N}) = copy!(Array(T, size(V)), V) diff --git a/test/arrayops.jl b/test/arrayops.jl index 9867e062af864..9022fc0c18ce6 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -88,13 +88,50 @@ a = reshape(b, (2, 2, 2, 2, 2)) @test a[2,2,2,2,2] == b[end] # reshaping linearslow arrays -a = zeros(1, 5) +a = collect(reshape(1:5, 1, 5)) s = sub(a, :, [2,3,5]) -@test length(reshape(s, length(s))) == 3 +r = reshape(s, length(s)) +@test length(r) == 3 +@test r[1] == 2 +@test r[3,1] == 5 +@test r[Base.ReshapedIndex(CartesianIndex((1,2)))] == 3 +@test parent(reshape(r, (1,3))) === r.parent === s +@test parentindexes(r) == (1:1, 1:3) +@test reshape(r, (3,)) === r +r[2] = -1 +@test a[3] == -1 a = zeros(0, 5) # an empty linearslow array s = sub(a, :, [2,3,5]) @test length(reshape(s, length(s))) == 0 +@test reshape(1:5, (5,)) === 1:5 +@test reshape(1:5, 5) === 1:5 + +# setindex! on a reshaped range +a = reshape(1:20, 5, 4) +for idx in ((3,), (2,2), (Base.ReshapedIndex(1),)) + try + a[idx...] = 7 + catch err + @test err.msg == "indexed assignment fails for a reshaped range; consider calling collect" + end +end + +# operations with LinearFast ReshapedArray +b = collect(1:12) +a = Base.ReshapedArray(b, (4,3), ()) +@test a[3,2] == 7 +@test a[6] == 6 +a[3,2] = -2 +a[6] = -3 +a[Base.ReshapedIndex(5)] = -4 +@test b[5] == -4 +@test b[6] == -3 +@test b[7] == -2 +b = reinterpret(Int, a, (3,4)) +b[1] = -1 +@test vec(b) == vec(a) + a = rand(1, 1, 8, 8, 1) @test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1)) @test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8)) diff --git a/test/bitarray.jl b/test/bitarray.jl index 3429fe46c51ba..e91156352385b 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -324,22 +324,21 @@ t1 = bitrand(n1, n2) b2 = bitrand(countnz(t1)) @check_bit_operation setindex!(b1, b2, t1) BitMatrix -m1 = rand(1:n1) -m2 = rand(1:n2) - -t1 = bitrand(n1) -b2 = bitrand(countnz(t1), m2) -k2 = randperm(m2) -@check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix -@check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix -@check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix - -t2 = bitrand(n2) -b2 = bitrand(m1, countnz(t2)) -k1 = randperm(m1) -@check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix -@check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix -@check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix +let m1 = rand(1:n1), m2 = rand(1:n2) + t1 = bitrand(n1) + b2 = bitrand(countnz(t1), m2) + k2 = randperm(m2) + @check_bit_operation setindex!(b1, b2, t1, 1:m2) BitMatrix + @check_bit_operation setindex!(b1, b2, t1, n2-m2+1:n2) BitMatrix + @check_bit_operation setindex!(b1, b2, t1, k2) BitMatrix + + t2 = bitrand(n2) + b2 = bitrand(m1, countnz(t2)) + k1 = randperm(m1) + @check_bit_operation setindex!(b1, b2, 1:m1, t2) BitMatrix + @check_bit_operation setindex!(b1, b2, n1-m1+1:n1, t2) BitMatrix + @check_bit_operation setindex!(b1, b2, k1, t2) BitMatrix +end timesofar("indexing") @@ -1054,23 +1053,25 @@ end ## Reductions ## -b1 = bitrand(s1, s2, s3, s4) -m1 = 1 -m2 = 3 -@check_bit_operation maximum(b1, (m1, m2)) BitArray{4} -@check_bit_operation minimum(b1, (m1, m2)) BitArray{4} -@check_bit_operation sum(b1, (m1, m2)) Array{Int,4} - -@check_bit_operation maximum(b1) Bool -@check_bit_operation minimum(b1) Bool -@check_bit_operation any(b1) Bool -@check_bit_operation all(b1) Bool -@check_bit_operation sum(b1) Int - -b0 = falses(0) -@check_bit_operation any(b0) Bool -@check_bit_operation all(b0) Bool -@check_bit_operation sum(b0) Int +let + b1 = bitrand(s1, s2, s3, s4) + m1 = 1 + m2 = 3 + @check_bit_operation maximum(b1, (m1, m2)) BitArray{4} + @check_bit_operation minimum(b1, (m1, m2)) BitArray{4} + @check_bit_operation sum(b1, (m1, m2)) Array{Int,4} + + @check_bit_operation maximum(b1) Bool + @check_bit_operation minimum(b1) Bool + @check_bit_operation any(b1) Bool + @check_bit_operation all(b1) Bool + @check_bit_operation sum(b1) Int + + b0 = falses(0) + @check_bit_operation any(b0) Bool + @check_bit_operation all(b0) Bool + @check_bit_operation sum(b0) Int +end timesofar("reductions")