diff --git a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl index 905dac535..e286446fb 100644 --- a/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl +++ b/lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl @@ -119,7 +119,7 @@ end Returns the parent array that type `T` wraps. """ parent_type(x) = parent_type(typeof(x)) -parent_type(::Type{Symmetric{T,S}}) where {T,S} = S +parent_type(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = fieldtype(T, :data) parent_type(::Type{<:AbstractTriangular{T,S}}) where {T,S} = S parent_type(@nospecialize T::Type{<:PermutedDimsArray}) = fieldtype(T, :parent) parent_type(@nospecialize T::Type{<:Adjoint}) = fieldtype(T, :parent) @@ -667,6 +667,76 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) end end +""" + IndexLabel(label) + +A type that clearly communicates to internal methods to lookup the index corresponding to +for `label`. +""" +struct IndexLabel{L} <: ArrayIndex{1} + label::L +end + +""" + UnlabelledIndices(indices) + +A set of indices that explicitly do not have any labels and cannot be accessed with an +[`IndexLabel`](@ref). +""" +struct UnlabelledIndices{I<:AbstractUnitRange{Int}} <: AbstractUnitRange{Int} + indices::I +end + +""" + LabelledIndices(labels) + +A subtype of `AbstractUnitRange{Int}` whose associeated with labels (`labels`). +`eachindex(labels)` are the indices for `LabelledIndices`. +""" +struct LabelledIndices{L<:AbstractVector} <: AbstractUnitRange{Int} + labels::L +end +# don't nest instances of `LabelledIndices` +LabelledIndices(labels::LabelledIndices) = labels + +Base.parent(x::UnlabelledIndices) = getfield(x, :indices) +Base.parent(x::LabelledIndices) = getfield(x, :labels) + +parent_type(@nospecialize T::Type{<:UnlabelledIndices}) = fieldtype(T, :indices) +parent_type(@nospecialize T::Type{<:LabelledIndices}) = fieldtype(T, :labels) + +is_forwarding_wrapper(@nospecialize T::Type{<:Union{UnlabelledIndices,LabelledIndices}}) = true + +Base.size(x::Union{UnlabelledIndices,LabelledIndices}) = size(parent(x)) +Base.axes(x::Union{UnlabelledIndices,LabelledIndices}) = axes(parent(x)) + +Base.first(x::UnlabelledIndices) = first(parent(x)) +Base.first(x::LabelledIndices) = firstindex(parent(x)) + +Base.last(x::UnlabelledIndices) = last(parent(x)) +Base.last(x::LabelledIndices) = lastindex(parent(x)) + +""" + getlabels(x, idx) + +Given a collection of labelled indices (`x`), returns the subset of lablled indices +corresponding to the index `idx` are returned. +""" +Base.@propagate_inbounds function getlabels(x::LabelledIndices, idx::I) where {I} + ndims_shape(I) === 0 ? parent(x)[idx] : LabelledIndices(parent(x)[idx]) +end +function getlabels(x::UnlabelledIndices, idx::I) where {I} + @boundscheck checkbounds(parent(x), idx) + ndims_shape(I) === 0 ? nothing : UnlabelledIndices(eachindex(idx)) +end + +""" + setlabels!(x, idx, vals) + +Sets new labels `vals` at the indices `idx` for the collection of labelled indices `x`. +""" +Base.@propagate_inbounds setlabels!(x::LabelledIndices, i, v) = setindex!(parent(x), i, v) + _cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i) _cartesian_index(::Any) = nothing @@ -766,6 +836,7 @@ julia> ArrayInterfaceCore.ndims_index([CartesianIndex(1, 2), CartesianIndex(1, 3 ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N # preserve CartesianIndices{0} as they consume a dimension. ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1 +ndims_index(@nospecialize T::Type{<:Union{Number,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) @@ -793,7 +864,7 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2) ndims_shape(T::DataType) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) -ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0 +ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 0 ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 26e173974..2bae3d9ca 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,8 +6,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff issingular, isstructured, matrix_colors, restructure, lu_instance, safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type, ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims, - parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides, - stride_preserving_index + parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, + LabelledIndices, getlabels, setlabels!, UnlabelledIndices, defines_strides, stride_preserving_index # ArrayIndex subtypes and methods import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex @@ -35,6 +35,7 @@ using Base.Iterators: Pairs using LinearAlgebra import Compat +using Compat: Returns _add1(@nospecialize x) = x + oneunit(x) _sub1(@nospecialize x) = x - oneunit(x) diff --git a/src/axes.jl b/src/axes.jl index d992c512b..1672ee392 100644 --- a/src/axes.jl +++ b/src/axes.jl @@ -90,7 +90,6 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N end end - # FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)``. This is inended to decreases # breaking changes for this adopting this method to situations where they clearly benefit # from the propagation of static axes. This creates the somewhat awkward situation of @@ -113,7 +112,7 @@ axes(A::ReshapedArray) = Base.axes(A) @inline function axes(x::Union{MatAdjTrans,PermutedDimsArray}) map(GetIndex{false}(axes(parent(x))), to_parent_dims(x)) end -axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1)) +axes(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes(parent(A)), 1)) @inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x)))) @inline _sub_axes(x::SubArray, axis::SOneTo) = axis @@ -248,3 +247,19 @@ lazy_axes(x::AbstractRange, ::StaticInt{1}) = Base.axes1(x) lazy_axes(x, ::Colon) = LazyAxis{:}(x) lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x) @inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims) + +""" + index_labels(x) + index_labels(x, dim) + +Returns a tuple of labels assigned to each axis or a collection of labels corresponding to +each index along `dim` of `x`. Default is to return `UnlabelledIndices(axes(x, dim))`. +""" +index_labels(x, dim) = index_labels(x, to_dims(x, dim)) +@inline function index_labels(x, dim::CanonicalInt) + dim > ndims(x) ? UnlabelledIndices(SOneTo(1)) : getfield(index_labels(x), Int(dim)) +end +@inline function index_labels(x) + is_forwarding_wrapper(x) ? index_labels(buffer(x)) : map(UnlabelledIndices, axes(x)) +end +index_labels(axis::LazyAxis{N}) where {N} = (index_labels(getfield(axis, :parent), N),) diff --git a/src/dimensions.jl b/src/dimensions.jl index 194658852..d562be58f 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -148,6 +148,9 @@ end return ntuple(Compat.Returns(:_), StaticInt(ndims(T))) end end +known_dimnames(::Type{<:LazyAxis{:,P}}) where {P} = (first(known_dimnames(P)),) +known_dimnames(::Type{<:LazyAxis{N,P}}) where {N,P} = (getfield(known_dimnames(P), N),) + @inline function known_dimnames(::Type{T}) where {T} if is_forwarding_wrapper(T) return known_dimnames(parent_type(T)) @@ -207,6 +210,8 @@ end return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x))) end end +dimnames(x::LazyAxis{:,P}) where {P} = (first(dimnames(getfield(x, :parent))),) +dimnames(x::LazyAxis{N,P}) where {N,P} = (getfield(dimnames(getfield(x, :parent)), N),) @inline function dimnames(x::X) where {X} if is_forwarding_wrapper(X) return dimnames(parent(x)) diff --git a/src/indexing.jl b/src/indexing.jl index a4b4bbca0..88029b49b 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -173,11 +173,33 @@ end @inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}}) max(_add1(canonicalize(i.x)), static_first(x)):static_last(x) end -# integer indexing -to_index(x, i::AbstractArray{<:Integer}) = i +to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i to_index(x, @nospecialize(i::StaticInt)) = i to_index(x, i::Integer) = Int(i) @inline to_index(x, i) = to_index(IndexStyle(x), x, i) +# label indexing +to_index(x, i::IndexLabel) = to_index(getfield(index_labels(x), 1), i) +function to_index(x::LabelledIndices, i::IndexLabel) + index = findfirst(==(getfield(i, :label)), parent(x)) + # delay throwing bounds-error if we didn't find label + index === nothing ? typemin(Int) : index +end +to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) = to_index(x, IndexLabel(i)) +# TODO there's probably a more efficient way of doing this +to_index(x, i::LabelledIndices) = to_index(getfield(index_labels(x), 1), i) +to_index(x::LabelledIndices, i::LabelledIndices) = findall(in(parent(i)), parent(x)) +to_index(x, ks::AbstractArray{<:IndexLabel}) = [to_index(x, k) for k in ks] +function to_index(x, i::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}}) + to_index(x, LabelledIndices(i)) +end +@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) + to_index(getfield(index_labels(x), 1), i) +end +@inline function to_index(x::LabelledIndices, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel}) + findall(i.f(i.x.label), parent(x)) +end + +# integer indexing function to_index(S::IndexStyle, x, i) throw(ArgumentError( "invalid index: $S does not support indices of type $(typeof(i)) for instances of type $(typeof(x))." diff --git a/test/axes.jl b/test/axes.jl index 325b20c69..8a2b333e5 100644 --- a/test/axes.jl +++ b/test/axes.jl @@ -104,3 +104,23 @@ if isdefined(Base, :ReshapedReinterpretArray) @inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa) end end + +@testset "index_labels" begin + colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i ∈ 1:100], (range(-10, 10, length=100),)); + + @test parent(@inferred(ArrayInterface.index_labels(colors))[1]) == range(-10, 10, length=100) + @test parent(ArrayInterface.index_labels(colors, 1)) == range(-10, 10, length=100) + @test ArrayInterface.index_labels(colors, 2) == ArrayInterface.UnlabelledIndices(axes(colors, 2)) + @test ArrayInterface.index_labels(parent(colors)) == map(ArrayInterface.UnlabelledIndices, axes(colors)) + + label = ArrayInterface.IndexLabel(ArrayInterface.getlabels(ArrayInterface.index_labels(colors)[1], 3)) + @test @inferred(ArrayInterface.getindex(colors, label)) == colors[3] + @test @inferred(ArrayInterface.getindex(colors, <=(label))) == colors[1:3] +end + +#= +ArrayInterface.to_indices(colors, (label,)) +axis = ArrayInterface.lazy_axes(colors)[1] +labels = ArrayInterface.index_labels(axis)[1] +ArrayInterface.to_index(labels, label) +=# diff --git a/test/dimensions.jl b/test/dimensions.jl index 00876a383..48dd6b1c1 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -29,6 +29,9 @@ end r4 = reinterpret(reshape, Float64, x) w = Wrapper(x) dnums = ntuple(+, length(d)) + lz2 = ArrayInterface.lazy_axes(x)[2] + lzslice = ArrayInterface.LazyAxis{:}(x) + @test @inferred(ArrayInterface.has_dimnames(x)) == true @test @inferred(ArrayInterface.has_dimnames(z)) == true @test @inferred(ArrayInterface.has_dimnames(ones(2, 2))) == false @@ -36,6 +39,8 @@ end @test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true @test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true @test @inferred(ArrayInterface.dimnames(x)) === d + @test @inferred(ArrayInterface.dimnames(lz2)) === (static(:y),) + @test @inferred(ArrayInterface.dimnames(lzslice)) === (static(:x),) @test @inferred(ArrayInterface.dimnames(w)) === d @test @inferred(ArrayInterface.dimnames(r1)) === d @test @inferred(ArrayInterface.dimnames(r2)) === (static(:_), d...) @@ -64,6 +69,8 @@ end # multidmensional indices @test @inferred(ArrayInterface.known_dimnames(view(x, ones(Int, 2, 2), 1))) === (:_, :_) @test @inferred(ArrayInterface.known_dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (:_,) + @test @inferred(ArrayInterface.known_dimnames(lz2)) === (:y,) + @test @inferred(ArrayInterface.known_dimnames(lzslice)) === (:x,) @test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y) @test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y) diff --git a/test/setup.jl b/test/setup.jl index 4dc7869ba..6e9c99f58 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -34,7 +34,16 @@ function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L ArrayInterface.Static.known(L) end -Base.parent(x::NamedDimsWrapper) = x.parent +struct LabelledArray{T,N,P<:AbstractArray{T,N},L} <: ArrayInterface.AbstractArray2{T,N} + parent::P + labels::L + + LabelledArray(p::P, labels::L) where {P,L} = new{eltype(P),ndims(p),P,L}(p, labels) +end +ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true +Base.parent(x::LabelledArray) = getfield(x, :parent) +ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P +ArrayInterface.index_labels(x::LabelledArray) = map(ArrayInterface.LabelledIndices, getfield(x, :labels)) # Dummy array type with undetermined contiguity properties struct DummyZeros{T,N} <: AbstractArray{T,N}