Skip to content

Commit

Permalink
Support pointer(A::PermDimsArray) and strides(A) if parent(A) support…
Browse files Browse the repository at this point in the history
…s them

(cherry picked from commit 2bc83ca)
ref #20385

Qualify PermutedDimsArray as it isn't exported on release-0.5
  • Loading branch information
timholy authored and tkelman committed Mar 1, 2017
1 parent a3982a7 commit dea17b4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
14 changes: 14 additions & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ Base.parent(A::PermutedDimsArray) = A.parent
Base.size{T,N,perm}(A::PermutedDimsArray{T,N,perm}) = genperm(size(parent(A)), perm)
Base.indices{T,N,perm}(A::PermutedDimsArray{T,N,perm}) = genperm(indices(parent(A)), perm)

Base.unsafe_convert{T}(::Type{Ptr{T}}, A::PermutedDimsArray{T}) = Base.unsafe_convert(Ptr{T}, parent(A))

# It's OK to return a pointer to the first element, and indeed quite
# useful for wrapping C routines that require a different storage
# order than used by Julia. But for an array with unconventional
# storage order, a linear offset is ambiguous---is it a memory offset
# or a linear index?
Base.pointer{T}(A::PermutedDimsArray{T}, i::Integer) = throw(ArgumentError("pointer(A, i) is deliberately unsupported for PermutedDimsArray"))

function Base.strides{T,N,perm}(A::PermutedDimsArray{T,N,perm})
s = strides(parent(A))
ntuple(d->s[perm[d]], Val{N})
end

@inline function Base.getindex{T,N,perm,iperm}(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N})
@boundscheck checkbounds(A, I...)
@inbounds val = getindex(A.parent, genperm(I, iperm)...)
Expand Down
6 changes: 6 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ end
@test_throws ArgumentError permutedims(s, (1,1,1))
@test_throws ArgumentError Base.PermutedDimsArrays.PermutedDimsArray(a, (1,1,1))
@test_throws ArgumentError Base.PermutedDimsArrays.PermutedDimsArray(s, (1,1,1))
cp = Base.PermutedDimsArrays.PermutedDimsArray(c, (3,2,1))
@test pointer(cp) == pointer(c)
@test_throws ArgumentError pointer(cp, 2)
@test strides(cp) == (9,3,1)
ap = Base.PermutedDimsArrays.PermutedDimsArray(collect(a), (2,1,3))
@test strides(ap) == (3,1,12)

## ipermutedims ##

Expand Down

0 comments on commit dea17b4

Please sign in to comment.