Skip to content

Commit

Permalink
Strides for Adjoint & Transpose (#710)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Jul 15, 2020
1 parent e667af7 commit c3984f3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ changes in `julia`.

## Supported features

* `strides` is defined for Adjoint and Transpose ([#35929]). (since Compat 3.14)

* `Compat.get_num_threads()` adds the functionality of `LinearAlgebra.BLAS.get_num_threads()`, and has matching `Compat.set_num_threads(n)` ([#36360]). (since Compat 3.13.0)

* `@inferred [AllowedType] f(x)` is defined ([#27516]). (since Compat 3.12.0)
Expand Down Expand Up @@ -177,3 +179,4 @@ Note that you should specify the correct minimum version for `Compat` in the
[#35577]: https://github.com/JuliaLang/julia/pull/35577
[#27516]: https://github.com/JuliaLang/julia/pull/27516
[#36360]: https://github.com/JuliaLang/julia/pull/36360
[#35929]: https://github.com/JuliaLang/julia/pull/35929
25 changes: 25 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,31 @@ if VERSION < v"1.5.0-DEV.681"
Base.union(r::Base.OneTo, s::Base.OneTo) = Base.OneTo(max(r.stop,s.stop))
end

# https://github.com/JuliaLang/julia/pull/35929
# and also https://github.com/JuliaLang/julia/pull/29135 -> Julia 1.5
if VERSION < v"1.5.0-rc1.13" || v"1.6.0-" < VERSION < v"1.6.0-DEV.323"

# Compat.stride not Base.stride, so as not to overwrite the method, and not to create ambiguities:
function stride(A::AbstractArray, k::Integer)
st = strides(A)
k ndims(A) && return st[k]
return sum(st .* size(A))
end
stride(A,k) = Base.stride(A,k) # Fall-through for other methods.

# These were first defined for Adjoint{...,StridedVector} etc in #29135
Base.strides(A::Adjoint{<:Real, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Transpose{<:Any, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
Base.strides(A::Adjoint{<:Real, <:AbstractMatrix}) = reverse(strides(A.parent))
Base.strides(A::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(A.parent))
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)

Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)

end

# https://github.com/JuliaLang/julia/pull/27516
if VERSION < v"1.2.0-DEV.77"
import Test: @inferred
Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,33 @@ end
@test union(Base.OneTo(3), Base.OneTo(4)) === Base.OneTo(4)
end

# https://github.com/JuliaLang/julia/pull/35929
# https://github.com/JuliaLang/julia/pull/29135
@testset "strided transposes" begin
for t in (Adjoint, Transpose)
@test strides(t(rand(3))) == (3, 1)
@test strides(t(rand(3,2))) == (3, 1)
@test strides(t(view(rand(3, 2), :))) == (6, 1)
@test strides(t(view(rand(3, 2), :, 1:2))) == (3, 1)

A = rand(3)
@test pointer(t(A)) === pointer(A)
B = rand(3,1)
@test pointer(t(B)) === pointer(B)
end
@test_throws MethodError strides(Adjoint(rand(3) .+ rand(3).*im))
@test_throws MethodError strides(Adjoint(rand(3, 2) .+ rand(3, 2).*im))
@test strides(Transpose(rand(3) .+ rand(3).*im)) == (3, 1)
@test strides(Transpose(rand(3, 2) .+ rand(3, 2).*im)) == (3, 1)

C = rand(3) .+ rand(3).*im
@test_throws ErrorException pointer(Adjoint(C))
@test pointer(Transpose(C)) === pointer(C)
D = rand(3,2) .+ rand(3,2).*im
@test_throws ErrorException pointer(Adjoint(D))
@test pointer(Transpose(D)) === pointer(D)
end

# https://github.com/JuliaLang/julia/pull/27516
@testset "two arg @inferred" begin
g(a) = a < 10 ? missing : 1
Expand Down

0 comments on commit c3984f3

Please sign in to comment.