Skip to content

Commit

Permalink
Fix #4270
Browse files Browse the repository at this point in the history
This also makes `squeeze` require a tuple and deprecates the version
that takes an iterator. I doubt there are many instances where the
dimensions to be squeezed aren't given at runtime.
  • Loading branch information
simonster committed Dec 8, 2014
1 parent dd03b5e commit c62abaa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
36 changes: 27 additions & 9 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,38 @@ reshape(a::AbstractArray, dims::Int...) = reshape(a, dims)
vec(a::AbstractArray) = reshape(a,length(a))
vec(a::AbstractVector) = a

function squeeze(A::AbstractArray, dims)
d = ()
for i in 1:ndims(A)
if in(i,dims)
if size(A,i) != 1
error("squeezed dims must all be size 1")
function nextunsqueezeddim(dims, i)
while true
squeeze = false
for dim in dims
if dim == i
squeeze && error("squeezed dims must be unique")
squeeze = true
end
else
d = tuple(d..., size(A,i))
end
!squeeze && return i
i += 1
end
reshape(A, d)
end

stagedfunction squeeze(A::AbstractArray, dims::Dims)
n = ndims(A)
quote
if !($(Expr(:&&, [:(1 <= dims[$i] <= $n) for i = 1:length(dims)]...)))
error("squeezed dims must be in range [1, ndims(A)]")
elseif !($(Expr(:&&, [:(size(A, dims[$i]) == 1) for i = 1:length(dims)]...)))
error("squeezed dims must all be size 1")
end

dim_1 = nextunsqueezeddim(dims, 1)
$([:($(symbol("dim_$i")) = nextunsqueezeddim(dims, $(symbol("dim_$(i-1)"))+1)) for i = 2:n-length(dims)]...)

reshape(A, tuple($([:(size(A, $(symbol("dim_$i")))) for i = 1:n-length(dims)]...)))
end
end

squeeze(A::AbstractArray, dim::Integer) = squeeze(A, (int(dim),))

function copy!(dest::AbstractArray, src)
i = 1
for x in src
Expand Down
2 changes: 2 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,5 @@ const Uint128 = UInt128
@deprecate nextind(a::Any, i::Integer) i+1

@deprecate givens{T}(f::T, g::T, i1::Integer, i2::Integer, cols::Integer) givens(f, g, i1, i2)

@deprecate squeeze(X, dims) squeeze(X, tuple(dims...))
12 changes: 12 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ a = reshape(b, (2, 2, 2, 2, 2))
@test a[2,1,2,2,1] == b[14]
@test a[2,2,2,2,2] == b[end]

a = rand(1, 1, 8, 8, 1)
@test @inferred(squeeze(a, 1) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1))
@test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8))
@test @inferred(squeeze(a, (1, 2, 5))) == squeeze(a, (5, 2, 1)) == reshape(a, (8, 8))
@test_throws ErrorException squeeze(a, 0)
@test_throws ErrorException squeeze(a, (1, 1))
@test_throws ErrorException squeeze(a, (1, 2, 1))
@test_throws ErrorException squeeze(a, (1, 1, 2))
@test_throws ErrorException squeeze(a, 3)
@test_throws ErrorException squeeze(a, 4)
@test_throws ErrorException squeeze(a, 6)

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6]
Expand Down

0 comments on commit c62abaa

Please sign in to comment.