From 2fb9badcd9f2be87148fd0b472912bd123a63c02 Mon Sep 17 00:00:00 2001 From: Carlo Baldassi Date: Fri, 31 Jan 2014 10:04:14 +0100 Subject: [PATCH] bitarray indexing: sync with array + refactor * Use setindex_shape_check and DimensionMismatch instead of generic errors; this (and a couple of extra minor fixes) makes BitArrays behave like Arrays (again) * Move portions of the indexing code to multidimensional.jl * Restrict signatures so as to avoid to_index and convert, then create very generic methods which do the conversion and the dispatch: this greatly simplifies the logic and removes most of the need for disambiguation. Should also improve code generation. --- base/bitarray.jl | 267 ++++++--------------------------------- base/multidimensional.jl | 156 +++++++++++++++-------- base/operators.jl | 4 +- 3 files changed, 142 insertions(+), 285 deletions(-) diff --git a/base/bitarray.jl b/base/bitarray.jl index b1f10144d6e85..7face7de2064a 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -217,9 +217,10 @@ falses(args...) = fill!(BitArray(args...), false) trues(args...) = fill!(BitArray(args...), true) function one(x::BitMatrix) - m, n = size(x) - a = falses(size(x)) - for i = 1 : min(m,n) + m,n = size(x) + m == n || throw(DimensionMismatch("multiplicative identity defined only for square matrices")) + a = falses(n, n) + for i = 1:n a[i,i] = true end return a @@ -258,7 +259,7 @@ end function reshape{N}(B::BitArray, dims::NTuple{N,Int}) if prod(dims) != length(B) - error("new dimensions $(dims) inconsistent with the array length $(length(B))") + throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(length(B))")) end Br = BitArray{N}(ntuple(N,i->0)...) Br.chunks = B.chunks @@ -320,18 +321,7 @@ end convert{N}(::Type{BitArray{N}}, B::BitArray{N}) = B reinterpret{N}(::Type{Bool}, B::BitArray, dims::NTuple{N,Int}) = reinterpret(B, dims) -function reinterpret{N}(B::BitArray, dims::NTuple{N,Int}) - if prod(dims) != length(B) - error("new dimensions $(dims) are inconsistent with array length $(length(B))") - end - A = BitArray{N}(ntuple(N,i->0)...) - A.chunks = B.chunks - A.len = prod(dims) - if N != 1 - A.dims = dims - end - return A -end +reinterpret{N}(B::BitArray, dims::NTuple{N,Int}) = reshape(B, dims) # shorthand forms BitArray <-> Array bitunpack{N}(B::BitArray{N}) = convert(Array{Bool,N}, B) @@ -359,9 +349,7 @@ function getindex_unchecked(Bc::Vector{Uint64}, i::Int) end function getindex(B::BitArray, i::Int) - if i < 1 || i > length(B) - throw(BoundsError()) - end + 1 <= i <= length(B) || throw(BoundsError()) return getindex_unchecked(B.chunks, i) end @@ -372,58 +360,6 @@ getindex(B::BitArray) = getindex(B, 1) # 0d bitarray getindex(B::BitArray{0}) = getindex_unchecked(B.chunks, 1) -function getindex(B::BitArray, i1::Real, i2::Real) - #checkbounds(B, i0, i1) # manually inlined for performance - i1, i2 = to_index(i1, i2) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - return B[i1 + l1*(i2-1)] -end -function getindex(B::BitArray, i1::Real, i2::Real, i3::Real) - #checkbounds(B, i0, i1, i2) # manually inlined for performance - i1, i2, i3 = to_index(i1, i2, i3) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - l2 = size(B,2) - 1 <= i2 <= l2 || throw(BoundsError()) - return B[i1 + l1*((i2-1) + l2*(i3-1))] -end -function getindex(B::BitArray, i1::Real, i2::Real, i3::Real, i4::Real) - #checkbounds(B, i1, i2, i3, i4) - i1, i2, i3, i4 = to_index(i1, i2, i3, i4) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - l2 = size(B,2) - 1 <= i2 <= l2 || throw(BoundsError()) - l3 = size(B,3) - 1 <= i3 <= l3 || throw(BoundsError()) - return B[i1 + l1*((i2-1) + l2*((i3-1) + l3*(i4-1)))] -end - -function getindex(B::BitArray, I::Real...) - #checkbounds(B, I...) # inlined for performance - #I = to_index(I) # inlined for performance - ndims = length(I) - i = to_index(I[1]) - l = size(B,1) - 1 <= i <= l || throw(BoundsError()) - index = i - stride = 1 - for k = 2:ndims-1 - stride *= l - i = to_index(I[k]) - l = size(B,k) - 1 <= i <= l || throw(BoundsError()) - index += (i-1) * stride - end - stride *= l - i = to_index(I[ndims]) - index += (i-1) * stride - return B[index] -end - -# note: the Range1{Int} case is still handled by the version above -# (which is fine) function getindex{T<:Real}(B::BitArray, I::AbstractVector{T}) X = BitArray(length(I)) lB = length(B) @@ -432,9 +368,9 @@ function getindex{T<:Real}(B::BitArray, I::AbstractVector{T}) ind = 1 for i in I # faster X[ind] = B[i] - i = to_index(i) - 1 <= i <= lB || throw(BoundsError()) - setindex_unchecked(Xc, getindex_unchecked(Bc, i), ind) + j = to_index(i) + 1 <= j <= lB || throw(BoundsError()) + setindex_unchecked(Xc, getindex_unchecked(Bc, j), ind) ind += 1 end return X @@ -442,30 +378,24 @@ end # logical indexing -function getindex_bool_1d(B::BitArray, I::AbstractArray{Bool}) - n = sum(I) - X = BitArray(n) - lI = length(I) - if lI != length(B) - throw(BoundsError()) - end - Xc = X.chunks - Bc = B.chunks - ind = 1 - for i = 1:length(I) - if I[i] - # faster X[ind] = B[i] - setindex_unchecked(Xc, getindex_unchecked(Bc, i), ind) - ind += 1 +# (multiple signatures for disambiguation) +for IT in [AbstractVector{Bool}, AbstractArray{Bool}] + @eval function getindex(B::BitArray, I::$IT) + checkbounds(B, I) + n = sum(I) + X = BitArray(n) + Xc = X.chunks + Bc = B.chunks + ind = 1 + for i = 1:length(I) + if I[i] + # faster X[ind] = B[i] + setindex_unchecked(Xc, getindex_unchecked(Bc, i), ind) + ind += 1 + end end + return X end - return X -end - -# multiple signatures required for disambiguation -# (see also getindex in multidimensional.jl) -for BT in [BitVector, BitArray], IT in [Range1{Bool}, AbstractVector{Bool}, AbstractArray{Bool}] - @eval getindex(B::$BT, I::$IT) = getindex_bool_1d(B, I) end ## Indexing: setindex! ## @@ -482,133 +412,31 @@ function setindex_unchecked(Bc::Array{Uint64}, x::Bool, i::Int) end end +setindex!(B::BitArray, x::Bool) = setindex!(B, x, 1) + function setindex!(B::BitArray, x::Bool, i::Int) - if i < 1 || i > length(B) - throw(BoundsError()) - end + 1 <= i <= length(B) || throw(BoundsError()) setindex_unchecked(B.chunks, x, i) return B end -setindex!(B::BitArray, x) = setindex!(B, x, 1) - -setindex!(B::BitArray, x, i::Real) = setindex!(B, convert(Bool,x), to_index(i)) - -function setindex!(B::BitArray, x, i1::Real, i2::Real) - #checkbounds(B, i0, i1) # manually inlined for performance - i1, i2 = to_index(i1, i2) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - B[i1 + l1*(i2-1)] = x - return B -end - -function setindex!(B::BitArray, x, i1::Real, i2::Real, i3::Real) - #checkbounds(B, i1, i2, i3) # manually inlined for performance - i1, i2, i3 = to_index(i1, i2, i3) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - l2 = size(B,2) - 1 <= i2 <= l2 || throw(BoundsError()) - B[i1 + l1*((i2-1) + l2*(i3-1))] = x - return B -end - -function setindex!(B::BitArray, x, i1::Real, i2::Real, i3::Real, i4::Real) - #checkbounds(B, i1, i2, i3, i4) # manually inlined for performance - i1, i2, i3, i4 = to_index(i1, i2, i3, i4) - l1 = size(B,1) - 1 <= i1 <= l1 || throw(BoundsError()) - l2 = size(B,2) - 1 <= i2 <= l2 || throw(BoundsError()) - l3 = size(B,3) - 1 <= i3 <= l3 || throw(BoundsError()) - B[i1 + l1*((i2-1) + l2*((i3-1) + l3*(i4-1)))] = x - return B -end - -function setindex!(B::BitArray, x, i::Real, I::Real...) - #checkbounds(B, I...) # inlined for performance - #I = to_index(I) # inlined for performance - ndims = length(I) + 1 - i = to_index(i) - l = size(B,1) - 1 <= i <= l || throw(BoundsError()) - index = i - stride = 1 - for k = 2:ndims-1 - stride *= l - l = size(B,k) - i = to_index(I[k-1]) - 1 <= i <= l || throw(BoundsError()) - index += (i-1) * stride - end - stride *= l - i = to_index(I[ndims-1]) - index += (i-1) * stride - B[index] = x - return B -end - -function setindex!{T<:Real}(B::BitArray, X::AbstractArray, I::AbstractVector{T}) - if length(X) != length(I); error("argument dimensions must match"); end - count = 1 - for i in I - B[i] = X[count] - count += 1 - end - return B -end - -function setindex!(B::BitArray, X::AbstractArray, i0::Real) - if length(X) != 1 - error("argument dimensions must match") - end - return setindex!(B, X[1], i0) -end - -function setindex!(B::BitArray, X::AbstractArray, i0::Real, i1::Real) - if length(X) != 1 - error("argument dimensions must match") - end - return setindex!(B, X[1], i0, i1) -end - -function setindex!(B::BitArray, X::AbstractArray, I0::Real, I::Real...) - if length(X) != 1 - error("argument dimensions must match") - end - return setindex!(B, X[1], i0, I...) -end - -function setindex!{T<:Real}(B::BitArray, x, I::AbstractVector{T}) - x = convert(Bool, x) - for i in I - B[i] = x - end - return B -end - # logical indexing -function setindex_bool_1d(A::BitArray, x, I::AbstractArray{Bool}) - if length(I) > length(A) - throw(BoundsError()) - end +function setindex!(A::BitArray, x, I::AbstractArray{Bool}) + checkbounds(A, I) + y = convert(Bool, x) Ac = A.chunks for i = 1:length(I) if I[i] - # faster A[i] = x - setindex_unchecked(Ac, convert(Bool, x), i) + # faster A[i] = y + setindex_unchecked(Ac, y, i) end end A end -function setindex_bool_1d(A::BitArray, X::AbstractArray, I::AbstractArray{Bool}) - if length(I) > length(A) - throw(BoundsError()) - end +function setindex!(A::BitArray, X::AbstractArray, I::AbstractArray{Bool}) + checkbounds(A, I) Ac = A.chunks c = 1 for i = 1:length(I) @@ -618,27 +446,10 @@ function setindex_bool_1d(A::BitArray, X::AbstractArray, I::AbstractArray{Bool}) c += 1 end end - A -end - -# lots of definitions here are required just for disambiguation -# (see also setindex! in multidimensional.jl) -for XT in [BitArray, AbstractArray, Any] - for IT in [AbstractVector{Bool}, AbstractArray{Bool}] - @eval setindex!(A::BitArray, X::$XT, I::$IT) = setindex_bool_1d(A, X, I) - end - - for IT in [Range1{Bool}, AbstractVector{Bool}], JT in [Range1{Bool}, AbstractVector{Bool}] - @eval setindex!(A::BitMatrix, x::$XT, I::$IT, J::$JT) = (A[find(I),find(J)] = x; A) + if length(X) != c-1 + throw(DimensionMismatch("assigned $(length(X)) elements to length $(c-1) destination")) end - - for IT in [Range1{Bool}, AbstractVector{Bool}], JT in [Real, Range1] - @eval setindex!(A::BitMatrix, x::$XT, I::$IT, J::$JT) = (A[find(I),J] = x; A) - end - @eval setindex!{T<:Real}(A::BitMatrix, x::$XT, I::AbstractVector{Bool}, J::AbstractVector{T}) = (A[find(I),J] = x; A) - - @eval setindex!(A::BitMatrix, x::$XT, I::Real, J::AbstractVector{Bool}) = (A[I,find(J)] = x; A) - @eval setindex!{T<:Real}(A::BitMatrix, x::$XT, I::AbstractVector{T}, J::AbstractVector{Bool}) = (A[I,find(J)] = x; A) + A end ## Dequeue functionality ## diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 648804cd87bcb..1271a3dd9b103 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -1,7 +1,7 @@ ### From array.jl @ngenerate N function _checksize(A::AbstractArray, I::NTuple{N, Any}...) - @nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("Index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))"))) + @nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))"))) nothing end checksize(A, I) = (_checksize(A, I); return nothing) @@ -159,32 +159,43 @@ end fill!(A::AbstractArray, x) = (_fill!(A, x); return A) -### from bitarray.jl +### BitArrays -# note: we can gain some performance if the first dimension is a range; -# but we need to single-out the N=0 case due to how @ngenerate works -# case N = 0 -function getindex(B::BitArray, I0::Range1) - ndims(B) < 1 && error("wrong number of dimensions") +## getindex + +# general scalar indexing with two or more indices +# (uses linear indexing, which performs the final bounds check and +# is defined in bitarray.jl) + +@ngenerate N function getindex(B::BitArray, I_0::Int, I::NTuple{N,Int}...) + stride = 1 + index = I_0 + @nexprs N d->begin + l = size(B,d) + stride *= l + 1 <= I_{d-1} <= l || throw(BoundsError()) + index += (I_d - 1) * stride + end + return B[index] +end + +# contiguous multidimensional indexing: if the first dimension is a range, +# we can get some performance from using copy_chunks + +function getindex(B::BitArray, I0::Range1{Int}) checkbounds(B, I0) X = BitArray(length(I0)) copy_chunks(X.chunks, 1, B.chunks, first(I0), length(I0)) return X end -# TODO: extend to I:Union(Real,AbstractArray)... (i.e. not necessarily contiguous) -@ngenerate N function getindex(B::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...) - ndims(B) < N+1 && error("wrong number of dimensions") +@ngenerate N function getindex(B::BitArray, I0::Range1{Int}, I::NTuple{N,Union(Int,Range1{Int})}...) checkbounds(B, I0, I...) X = BitArray(index_shape(I0, I...)) - I0 = to_index(I0) - f0 = first(I0) l0 = length(I0) - Base.@nexprs N d->(I_d = to_index(I_d)) - gap_lst_1 = 0 @nexprs N d->(gap_lst_{d+1} = length(I_d)) stride = 1 @@ -207,9 +218,10 @@ end return X end -@ngenerate N function getindex(B::BitArray, I::NTuple{N,Union(Real,AbstractVector)}...) +# general multidimensional non-scalar indexing + +@ngenerate N function getindex(B::BitArray, I::NTuple{N,Union(Int,AbstractVector{Int})}...) checkbounds(B, I...) - @nexprs N d->(I_d = to_index(I_d)) X = BitArray(index_shape(I...)) Xc = X.chunks @@ -221,34 +233,49 @@ end return X end -# note: we can gain some performance if the first dimension is a range; -# case N = 0 -function setindex!(B::BitArray, X::BitArray, I0::Range1) - ndims(B) != 1 && error("wrong number of dimensions in assigment") - I0 = to_index(I0) +# general version with Real (or logical) indexing which dispatches on the appropriate method + +@ngenerate N function getindex(B::BitArray, I::NTuple{N,Union(Real,AbstractVector)}...) + @nexprs N d->(J_d = to_index(I_d)) + return @nref N B J +end + +## setindex! + +# general scalar indexing with two or more indices +# (uses linear indexing, which performs the final bounds check and +# is defined in bitarray.jl) + +@ngenerate N function setindex!(B::BitArray, x::Bool, I_0::Int, I::NTuple{N,Int}...) + stride = 1 + index = I_0 + @nexprs N d->begin + l = size(B,d) + stride *= l + 1 <= I_{d-1} <= l || throw(BoundsError()) + index += (I_d - 1) * stride + end + B[index] = x + return B +end + +# contiguous multidimensional indexing: if the first dimension is a range, +# we can get some performance from using copy_chunks + +function setindex!(B::BitArray, X::BitArray, I0::Range1{Int}) checkbounds(B, I0) - lI = length(I0) - length(X) != lI && error("array assignment dimensions mismatch") - lI == 0 && return B - f0 = first(I0) + setindex_shape_check(X, I0) l0 = length(I0) + l0 == 0 && return B + f0 = first(I0) copy_chunks(B.chunks, f0, X.chunks, 1, l0) return B end -# TODO: extend to I:Union(Real,AbstractArray)... (i.e. not necessarily contiguous) -@ngenerate N function setindex!(B::BitArray, X::BitArray, I0::Range1, I::NTuple{N,Union(Real,Range1)}...) - ndims(B) != N+1 && error("wrong number of dimensions in assigment") - I0 = to_index(I0) - lI = length(I0) - - @nexprs N d->begin - I_d = to_index(I_d) - lI *= length(I_d) - end - length(X) != lI && error("array assignment dimensions mismatch") +@ngenerate N function setindex!(B::BitArray, X::BitArray, I0::Range1{Int}, I::NTuple{N,Union(Int,Range1{Int})}...) checkbounds(B, I0, I...) - lI == 0 && return B + setindex_shape_check(X, I0, I...) + length(X) == 0 && return B f0 = first(I0) l0 = length(I0) @@ -275,17 +302,11 @@ end return B end -@ngenerate N function setindex!(B::BitArray, X::AbstractArray, I::NTuple{N,Union(Real,AbstractArray)}...) +# general multidimensional non-scalar indexing + +@ngenerate N function setindex!(B::BitArray, X::AbstractArray, I::NTuple{N,Union(Int,AbstractArray{Int})}...) checkbounds(B, I...) - @nexprs N d->(I_d = to_index(I_d)) - nel = 1 - @nexprs N d->(nel *= length(I_d)) - length(X) != nel && error("argument dimensions must match") - if ndims(X) > 1 - @nexprs N d->begin - size(X,d) != length(I_d) && error("argument dimensions must match") - end - end + setindex_shape_check(X, I...) refind = 1 @nloops N i d->I_d begin (@nref N B i) = X[refind] # TODO: should avoid bounds checking @@ -294,17 +315,39 @@ end return B end -@ngenerate N function setindex!(B::BitArray, x, I::NTuple{N,Union(Real,AbstractArray)}...) - x = convert(Bool, x) +@ngenerate N function setindex!(B::BitArray, x::Bool, I::NTuple{N,Union(Int,AbstractArray{Int})}...) checkbounds(B, I...) - @nexprs N d->(I_d = to_index(I_d)) - Bc = B.chunks + @nexprs N d->(length(I_d) == 0 && throw_setindex_mismatch(x, tuple(I...))) @nloops N i d->I_d begin (@nref N B i) = x # TODO: should avoid bounds checking end return B end +# general versions with Real (or logical) indexing which dispatch on the appropriate method + +# (multiple signatures for disambiguation) +for T in [Real, Union(Real, AbstractArray)] + @eval begin + @ngenerate N function setindex!(B::BitArray, x, I::NTuple{N,$T}...) + y = convert(Bool, x) + @nexprs N d->(J_d = to_index(I_d)) + (@nref N B J) = y + return B + end + @ngenerate N function setindex!(B::BitArray, X::AbstractArray, I::NTuple{N,$T}...) + @nexprs N d->(J_d = to_index(I_d)) + (@nref N B J) = X + return B + end + end +end +setindex!(B::BitArray, x) = setindex!(B, convert(Bool,x)) + + + +## findn + @ngenerate N function findn{N}(B::BitArray{N}) nnzB = nnz(B) I = ntuple(N, x->Array(Int, nnzB)) @@ -320,14 +363,16 @@ end return I end +## permutedims + for (V, PT, BT) in [((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)] - @eval begin - @ngenerate N function permutedims!{$(V...)}(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) + @eval @ngenerate N function permutedims!{$(V...)}(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) dimsB = size(B) - (length(perm) == N && isperm(perm)) || error("no valid permutation of dimensions") + length(perm) == N || error("expected permutation of size $N, but length(perm)=$(length(perm))") + isperm(perm) || error("input is not a permutation") dimsP = size(P) for i = 1:length(perm) - dimsP[i] == dimsB[perm[i]] || error("destination tensor of incorrect size") + dimsP[i] == dimsB[perm[i]] || throw(DimensionMismatch("destination tensor of incorrect size")) end #calculates all the strides @@ -355,5 +400,4 @@ for (V, PT, BT) in [((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)] return P end - end end diff --git a/base/operators.jl b/base/operators.jl index 72880bd067372..4ace0be29b9c6 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -282,8 +282,10 @@ to_index(i) = i to_index(i::Real) = convert(Int, i) to_index(i::Int) = i to_index(r::Range1{Int}) = r -to_index{T}(r::Range1{T}) = to_index(first(r)):to_index(last(r)) +to_index{T<:Real}(r::Range1{T}) = to_index(first(r)):to_index(last(r)) to_index(I::AbstractArray{Bool,1}) = find(I) +to_index(I::Range1{Bool}) = find(I) +to_index{T<:Real}(A::AbstractArray{T}) = int(A) to_index(i1, i2) = to_index(i1), to_index(i2) to_index(i1, i2, i3) = to_index(i1), to_index(i2), to_index(i3) to_index(i1, i2, i3, i4) = to_index(i1), to_index(i2), to_index(i3), to_index(i4)