Skip to content

Commit

Permalink
Merge pull request #9271 from JuliaLang/sjk/squeeze
Browse files Browse the repository at this point in the history
Fix #4270
  • Loading branch information
JeffBezanson committed Dec 8, 2014
2 parents e1a406e + e40318d commit d7655de
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
30 changes: 21 additions & 9 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,32 @@ reshape(a::AbstractArray, dims::Int...) = reshape(a, dims)
vec(a::AbstractArray) = reshape(a,length(a))
vec(a::AbstractVector) = a

function squeeze(A::AbstractArray, dims)
argtail(x, rest...) = rest
tail(x::Tuple) = argtail(x...)

_sub(::(), ::()) = ()
_sub(t::Tuple, ::()) = t
_sub(t::Tuple, s::Tuple) = _sub(tail(t), tail(s))

function squeeze(A::AbstractArray, dims::Dims)
for i in 1:length(dims)
1 <= dims[i] <= ndims(A) || error("squeezed dims must be in range [1, ndims(A)]")
size(A, dims[i]) == 1 || error("squeezed dims must all be size 1")
for j = 1:i-1
dims[j] == dims[i] && error("squeezed dims must be unique")
end
end
d = ()
for i in 1:ndims(A)
if in(i,dims)
if size(A,i) != 1
error("squeezed dims must all be size 1")
end
else
d = tuple(d..., size(A,i))
for i = 1:ndims(A)
if !in(i, dims)
d = tuple(d..., size(A, i))
end
end
reshape(A, d)
reshape(A, d::typeof(_sub(size(A), dims)))
end

squeeze(A::AbstractArray, dim::Integer) = squeeze(A, (int(dim),))

function copy!(dest::AbstractArray, src)
i = 1
for x in src
Expand Down
2 changes: 2 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,5 @@ const Uint128 = UInt128
@deprecate nextind(a::Any, i::Integer) i+1

@deprecate givens{T}(f::T, g::T, i1::Integer, i2::Integer, cols::Integer) givens(f, g, i1, i2)

@deprecate squeeze(X, dims) squeeze(X, tuple(dims...))
12 changes: 12 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ a = reshape(b, (2, 2, 2, 2, 2))
@test a[2,1,2,2,1] == b[14]
@test a[2,2,2,2,2] == b[end]

a = rand(1, 1, 8, 8, 1)
@test @inferred(squeeze(a, 1)) == @inferred(squeeze(a, (1,))) == reshape(a, (1, 8, 8, 1))
@test @inferred(squeeze(a, (1, 5))) == squeeze(a, (5, 1)) == reshape(a, (1, 8, 8))
@test @inferred(squeeze(a, (1, 2, 5))) == squeeze(a, (5, 2, 1)) == reshape(a, (8, 8))
@test_throws ErrorException squeeze(a, 0)
@test_throws ErrorException squeeze(a, (1, 1))
@test_throws ErrorException squeeze(a, (1, 2, 1))
@test_throws ErrorException squeeze(a, (1, 1, 2))
@test_throws ErrorException squeeze(a, 3)
@test_throws ErrorException squeeze(a, 4)
@test_throws ErrorException squeeze(a, 6)

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6]
Expand Down

0 comments on commit d7655de

Please sign in to comment.