Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mbauman committed Jun 24, 2020
1 parent 31e9517 commit 65f4046
Showing 1 changed file with 138 additions and 0 deletions.
138 changes: 138 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,3 +978,141 @@ end
@test Core.sizeof(arrayOfUInt48) == 24
end
end

struct Strider{T,N} <: AbstractArray{T,N}
data::Vector{T}
offset::Int
strides::NTuple{N,Int}
size::NTuple{N,Int}
end
function Strider{T}(strides::NTuple{N}, size::NTuple{N}) where {T,N}
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
data = Array{T}(undef, sum(abs.(strides) .* (size .- 1)) + 1)
return Strider{T, N, Vector{T}}(data, offset, strides, size)
end
function Strider(vec::AbstractArray{T}, strides::NTuple{N}, size::NTuple{N}) where {T,N}
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
@assert length(vec) >= sum(abs.(strides) .* (size .- 1)) + 1
return Strider{T, N}(vec, offset, strides, size)
end
Base.size(S::Strider) = S.size
function Base.getindex(S::Strider{<:Any,N}, I::Vararg{Int,N}) where {N}
return S.data[sum(S.strides .* (I .- 1)) + S.offset]
end
Base.strides(S::Strider) = S.strides
Base.elsize(::Type{<:Strider{T}}) where {T} = Base.elsize(Vector{T})
Base.unsafe_convert(::Type{Ptr{T}}, S::Strider{T}) where {T} = pointer(S.data, S.offset)

@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13))
A = collect(reshape(1:prod(sz), sz))
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
@test pointer(A, i) == pointer(S, i)
@test A[i] == S[i]
end
for idxs in ((1:sz[1], 1:sz[2], 1:sz[3]),
(1:sz[1], 2:2:sz[2], sz[3]:-1:1),
(2:2:sz[1]-1, sz[2]:-1:1, sz[3]:-2:2),
(sz[1]:-1:1, sz[2]:-1:1, sz[3]:-1:1),
(sz[1]-1:-3:1, sz[2]:-2:3, 1:sz[3]),)
Av = view(A, idxs...)
Sv = view(S, idxs...)
Ss = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
@test pointer(Av) == pointer(Sv) == pointer(Ss)
for i in 1:length(Av)
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
@test Av[i] == Sv[i] == Ss[i]
end
for perm in ((3, 2, 1), (2, 1, 3), (3, 1, 2))
P = permutedims(A, perm)
Ap = Base.PermutedDimsArray(A, perm)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 3}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
@test pointer(Ap) == pointer(Sp) == pointer(Ps)
for i in 1:length(Ap)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i]
end
Pv = view(P, idxs[collect(perm)]...)
Apv = view(Ap, idxs[collect(perm)]...)
Spv = view(Sp, idxs[collect(perm)]...)
Pvs = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
@test pointer(Apv) == pointer(Spv) == pointer(Pvs)
for i in 1:length(Apv)
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i)
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i]
end
Vp = permutedims(Av, perm)
Avp = Base.PermutedDimsArray(Av, perm)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp)
for i in 1:length(Avp)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Avp, i) == pointer(Svp, i)
@test Vp[i] == Avp[i] == Svp[i]
end
end
end
end

@testset "simple 2d strided views, permutes, transposes" for sz in ((5, 3), (7, 11))
A = collect(reshape(1:prod(sz), sz))
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
@test pointer(A, i) == pointer(S, i)
@test A[i] == S[i]
end
for idxs in ((1:sz[1], 1:sz[2]),
(1:sz[1], 2:2:sz[2]),
(2:2:sz[1]-1, sz[2]:-1:1),
(sz[1]:-1:1, sz[2]:-1:1),
(sz[1]-1:-3:1, sz[2]:-2:3),)
Av = view(A, idxs...)
Sv = view(S, idxs...)
Ss = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
@test pointer(Av) == pointer(Sv) == pointer(Ss)
for i in 1:length(Av)
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
@test Av[i] == Sv[i] == Ss[i]
end
perm = (2, 1)
P = permutedims(A, perm)
Ap = Base.PermutedDimsArray(A, perm)
At = transpose(A)
Aa = adjoint(A)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
@test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa)
for i in 1:length(Ap)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
end
Pv = view(P, idxs[collect(perm)]...)
Apv = view(Ap, idxs[collect(perm)]...)
Atv = view(At, idxs[collect(perm)]...)
Ata = view(Aa, idxs[collect(perm)]...)
Spv = view(Sp, idxs[collect(perm)]...)
Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
@test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata)
for i in 1:length(Apv)
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i)
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
end
Vp = permutedims(Av, perm)
Avp = Base.PermutedDimsArray(Av, perm)
Avt = transpose(Av)
Ava = adjoint(Av)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava)
for i in 1:length(Avp)
# This is intentionally disabled due to ambiguity
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i)
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
end
end
end

0 comments on commit 65f4046

Please sign in to comment.