Skip to content

Commit

Permalink
Similarly add pointer support
Browse files Browse the repository at this point in the history
  • Loading branch information
mbauman committed Apr 1, 2020
1 parent 5f6f68a commit ce847fd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,16 @@ IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian()
convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent))
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))

# Strides for transposed strided arrays — but only if the elements are actually stored in memory
# Strides and pointer for transposed strided arrays — but only if the elements are actually stored in memory
Base.strides(A::Adjoint{<:Real, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Transpose{<:Any, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
# For matrices it's slightly faster to use reverse and avoid calling stride twice
Base.strides(A::Adjoint{<:Real, <:StridedMatrix}) = reverse(strides(A.parent))
Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))

Base.pointer(A::Adjoint{<:Real, <:Union{StridedVector,StridedMatrix}) = pointer(A.parent)
Base.pointer(A::Transpose{<:Any, <:Union{StridedVector,StridedMatrix}) = pointer(A.parent)

# for vectors, the semantics of the wrapped and unwrapped types differ
# so attempt to maintain both the parent and wrapper type insofar as possible
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,12 +489,25 @@ end
@test strides(t(rand(3,2))) == (3, 1)
@test strides(t(view(rand(3, 2), :))) == (6, 1)
@test strides(t(view(rand(3, 2), :, 1:2))) == (3, 1)

A = rand(3)
@test pointer(t(A)) === pointer(A)
B = rand(3,1)
@test pointer(t(B)) === pointer(B)
end
@test_throws MethodError strides(Adjoint(rand(3) .+ rand(3).*im))
@test_throws MethodError strides(Adjoint(rand(3, 2) .+ rand(3, 2).*im))
@test strides(Transpose(rand(3) .+ rand(3).*im)) == (3, 1)
@test strides(Transpose(rand(3, 2) .+ rand(3, 2).*im)) == (3, 1)

C = rand(3) .+ rand(3).*im
@test_throws MethodError pointer(Adjoint(C))
@test pointer(Transpose(C)) === pointer(C)
D = rand(3,2) .+ rand(3,2).*im
@test_throws MethodError pointer(Adjoint(D))
@test pointer(Transpose(D)) === pointer(D)
end

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
using .Main.OffsetArrays
Expand Down

0 comments on commit ce847fd

Please sign in to comment.