From 315964ae7c3f49f6b44cf0d38790f8a35ee06a49 Mon Sep 17 00:00:00 2001 From: Andy Ferris Date: Tue, 18 Dec 2018 21:37:51 +1000 Subject: [PATCH] Make `strides` into a generic trait Returns `nothing` for non-strided arrays, otherwise gives the give strides in memory. Useful as an extensible trait in generic contexts, and simpler to overload for cases of "wrapped" arrays where "stridedness" can be deferred to the parent rather than a complex (and inextensible) method signature. --- base/abstractarray.jl | 15 ++++++++++++--- base/permuteddimsarray.jl | 6 +++++- stdlib/LinearAlgebra/src/adjtrans.jl | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 6b8f6b92b0bd6..ddca3f0e9a87a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -311,7 +311,9 @@ last(a) = a[end] """ strides(A) -Return a tuple of the memory strides in each dimension. +Return a tuple of the memory strides in each dimension, for an `AbstractArray` with a +strided memory layout. For arrays with a non-strided layout (such as sparse arrays), return +`nothing`. # Examples ```jldoctest @@ -321,7 +323,7 @@ julia> strides(A) (1, 3, 12) ``` """ -function strides end +strides(::AbstractArray) = nothing """ stride(A, k::Integer) @@ -339,7 +341,14 @@ julia> stride(A,3) 12 ``` """ -stride(A::AbstractArray, k::Integer) = strides(A)[k] +function stride(A::AbstractArray, k::Integer) + str = strides(A) + if str === nothing + return nothing + else + return str[k] + end +end @inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...) size_to_strides(s, d) = (s,) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index d50cc11678e78..e10d2a2979c46 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -60,7 +60,11 @@ Base.pointer(A::PermutedDimsArray, i::Integer) = throw(ArgumentError("pointer(A, function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} s = strides(parent(A)) - ntuple(d->s[perm[d]], Val(N)) + if s === nothing + return nothing + else + return ntuple(d->s[perm[d]], Val(N)) + end end @inline function Base.getindex(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}) where {T,N,perm,iperm} diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index 8ded77a7eda74..a823dd14b69c9 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -155,6 +155,25 @@ vec(v::TransposeAbsVec) = parent(v) cmp(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = cmp(parent(A), parent(B)) isless(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = isless(parent(A), parent(B)) +# provide strides, but only for eltypes that are directly stored in memory (i.e. unaffected +# by recursive `adjoint` and `transpose`, being `Real` and `Number` respectively) +function Base.strides(a::Union{Adjoint{<:Real, <:AbstractVector}, Transpose{<:Number, <:AbstractVector}}) + str = strides(a.parent) + if str === nothing + return nothing + else + return (1, str[1]) + end +end +function Base.strides(a::Union{Adjoint{<:Real, <:AbstractMatrix}, Transpose{<:Number, <:AbstractMatrix}}) + str = strides(a.parent) + if str === nothing + return nothing + else + return (str[2], str[1]) + end +end + ### concatenation # preserve Adjoint/Transpose wrapper around vectors # to retain the associated semantics post-concatenation