Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

index_labels method and IndexLabel type indexing #328

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ end
Returns the parent array that type `T` wraps.
"""
parent_type(x) = parent_type(typeof(x))
parent_type(::Type{Symmetric{T,S}}) where {T,S} = S
parent_type(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = fieldtype(T, :data)
parent_type(::Type{<:AbstractTriangular{T,S}}) where {T,S} = S
parent_type(@nospecialize T::Type{<:PermutedDimsArray}) = fieldtype(T, :parent)
parent_type(@nospecialize T::Type{<:Adjoint}) = fieldtype(T, :parent)
Expand Down Expand Up @@ -667,6 +667,76 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int)
end
end

"""
IndexLabel(label)

A type that clearly communicates to internal methods to look up the index corresponding
to `label`.
"""
struct IndexLabel{L} <: ArrayIndex{1}
# the label whose corresponding index should be looked up
label::L
end

"""
UnlabelledIndices(indices)

A set of indices that explicitly do not have any labels and cannot be accessed with an
[`IndexLabel`](@ref).
"""
struct UnlabelledIndices{I<:AbstractUnitRange{Int}} <: AbstractUnitRange{Int}
# the wrapped unit range of integer indices; no labels are stored alongside it
indices::I
end

"""
LabelledIndices(labels)

A subtype of `AbstractUnitRange{Int}` whose indices are associated with labels (`labels`).
`eachindex(labels)` are the indices for `LabelledIndices`.
"""
struct LabelledIndices{L<:AbstractVector} <: AbstractUnitRange{Int}
# vector of labels; the positions of this vector serve as the indices
labels::L
end
# don't nest instances of `LabelledIndices`: wrapping an existing instance returns it as is
LabelledIndices(labels::LabelledIndices) = labels

# `parent` unwraps the single field held by each kind of index wrapper.
function Base.parent(x::UnlabelledIndices)
    return getfield(x, :indices)
end
function Base.parent(x::LabelledIndices)
    return getfield(x, :labels)
end

# The parent type is the declared type of the wrapped field.
function parent_type(@nospecialize T::Type{<:UnlabelledIndices})
    return fieldtype(T, :indices)
end
function parent_type(@nospecialize T::Type{<:LabelledIndices})
    return fieldtype(T, :labels)
end

# Both index wrappers are flagged as forwarding wrappers around their parent.
function is_forwarding_wrapper(@nospecialize T::Type{<:Union{UnlabelledIndices,LabelledIndices}})
    return true
end

# Size and axes are taken directly from the wrapped collection.
function Base.size(x::Union{UnlabelledIndices,LabelledIndices})
    return size(parent(x))
end
function Base.axes(x::Union{UnlabelledIndices,LabelledIndices})
    return axes(parent(x))
end

# Unlabelled indices report the first/last *value* of the wrapped range, whereas labelled
# indices report the first/last *position* of the label vector.
function Base.first(x::UnlabelledIndices)
    return first(parent(x))
end
function Base.first(x::LabelledIndices)
    return firstindex(parent(x))
end

function Base.last(x::UnlabelledIndices)
    return last(parent(x))
end
function Base.last(x::LabelledIndices)
    return lastindex(parent(x))
end

"""
getlabels(x, idx)

Given a collection of labelled indices (`x`), returns the subset of labelled indices
corresponding to the index `idx`.
"""
Base.@propagate_inbounds function getlabels(x::LabelledIndices, idx::I) where {I}
    selected = parent(x)[idx]
    # an index type with zero-dimensional shape selects a single label; anything else
    # yields a new set of labelled indices
    if ndims_shape(I) === 0
        return selected
    else
        return LabelledIndices(selected)
    end
end
function getlabels(x::UnlabelledIndices, idx::I) where {I}
    @boundscheck checkbounds(parent(x), idx)
    # there is no label to report for a single unlabelled index
    if ndims_shape(I) === 0
        return nothing
    else
        return UnlabelledIndices(eachindex(idx))
    end
end

"""
setlabels!(x, idx, vals)

Sets new labels `vals` at the indices `idx` for the collection of labelled indices `x`.
"""
Base.@propagate_inbounds function setlabels!(x::LabelledIndices, i, v)
    # `setindex!` expects the value before the index: `setindex!(collection, value, index)`.
    # `i` is the index (or indices) and `v` holds the new label(s).
    return setindex!(parent(x), v, i)
end

# Convert a tuple made only of `Int`s into a `CartesianIndex`; any other input gives
# `nothing`.
function _cartesian_index(i::Tuple{Vararg{Int}})
    return CartesianIndex(i)
end
function _cartesian_index(::Any)
    return nothing
end

Expand Down Expand Up @@ -766,6 +836,7 @@ julia> ArrayInterfaceCore.ndims_index([CartesianIndex(1, 2), CartesianIndex(1, 3
ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N
# preserve CartesianIndices{0} as they consume a dimension.
ndims_index(::Type{CartesianIndices{0,Tuple{}}}) = 1
ndims_index(@nospecialize T::Type{<:Union{Number,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 1
ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T)
ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T))
ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask))
Expand Down Expand Up @@ -793,7 +864,7 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2)
ndims_shape(T::DataType) = ndims_index(T)
ndims_shape(::Type{Colon}) = 1
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex,IndexLabel,Symbol,AbstractString,AbstractChar}}) = 0
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
ndims_shape(x) = ndims_shape(typeof(x))
Expand Down
5 changes: 3 additions & 2 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
issingular, isstructured, matrix_colors, restructure, lu_instance,
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
stride_preserving_index
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel,
LabelledIndices, getlabels, setlabels!, UnlabelledIndices, defines_strides, stride_preserving_index

# ArrayIndex subtypes and methods
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
Expand Down Expand Up @@ -35,6 +35,7 @@ using Base.Iterators: Pairs
using LinearAlgebra

import Compat
using Compat: Returns

# Step `x` up or down by one unit of its own type.
function _add1(@nospecialize x)
    return x + oneunit(x)
end
function _sub1(@nospecialize x)
    return x - oneunit(x)
end
Expand Down
19 changes: 17 additions & 2 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N
end
end


# FUTURE NOTE: we avoid `SOneTo(1)` when `axis(A, dim::Int)`. This is intended to decrease
# breaking changes for those adopting this method to situations where they clearly benefit
# from the propagation of static axes. This creates the somewhat awkward situation of
Expand All @@ -113,7 +112,7 @@ axes(A::ReshapedArray) = Base.axes(A)
@inline function axes(x::Union{MatAdjTrans,PermutedDimsArray})
    # pick the parent's axes in the order given by `to_parent_dims(x)`
    parent_axes = axes(parent(x))
    return map(GetIndex{false}(parent_axes), to_parent_dims(x))
end
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
axes(A::VecAdjTrans) = (SOneTo{1}(), getfield(axes(parent(A)), 1))

@inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x))))
@inline _sub_axes(x::SubArray, axis::SOneTo) = axis
Expand Down Expand Up @@ -248,3 +247,19 @@ lazy_axes(x::AbstractRange, ::StaticInt{1}) = Base.axes1(x)
lazy_axes(x, ::Colon) = LazyAxis{:}(x)
lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x)
@inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims)

"""
index_labels(x)
index_labels(x, dim)

Returns a tuple of labels assigned to each axis or a collection of labels corresponding to
each index along `dim` of `x`. Default is to return `UnlabelledIndices(axes(x, dim))`.
"""
function index_labels(x, dim)
    # normalise `dim` with `to_dims` before dispatching again
    return index_labels(x, to_dims(x, dim))
end

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the function live in ArrayInterfaceCore so that existing "named array" packages can overload it? BTW, it would be good to ping their authors to ensure they would all be OK with the API, otherwise it won't make a lot of sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Existing packages can overload it from ArrayInterface. If you're referring to supporting named dimensions like NamedDims.jl, then they define ArrayInterface.dimnames and to_dims maps to the appropriate dimension so they don't have to overload every method with a dim argument.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayInterface is relatively heavy, which is why ArrayInterfaceCore was created IIUC. I guess it's up to package authors to say whether a dependency on ArrayInterface is acceptable for them or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was more of an issue before we went through all the trouble to fix invalidations due to StaticInt earlier this year. That doesn't mean we can't improve the situation. We are actively trying to move matured functionality into base where appropriate (see #340). I regularly review the code here in an effort to eliminate problematic code that still exists (e.g., redundancies, generated functions, etc.). For example, once we know how this PR is going to look I can finally finish an effort to consolidate a lot of "indexing.jl" by overloading base methods instead of reimplementing a lot of what's in base already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, it's a bit of a chicken and egg issue. We use StaticSymbol for dimension names known at compile time so that we can use them as a point of reference in an inferrible way. It's pretty difficult to do this only relying on constant propagation (demonstrated that with static sizes here JuliaLang/julia#44538 (comment)).

If someone has a reliable solution I'm open to it. I've been trying to actively address this and related issues for years

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oxinabox is it still a concern depending on ArrayInterface at this point?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that ArrayInterface doesn't like Requires.jl a billion packages, I am much more comfortable depending upon it for NamedDims.jl

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion was just to put empty function definitions in ArrayInterfaceCore (like we do with DataAPI and StatsAPI). That doesn't prevent keeping fallback method definitions in ArrayInterface as this PR does. But packages that don't want to use fallback definitions are still able to overload the functions by depending only on ArrayInterfaceCore.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've been trying to avoid situations where the behavior of a method is different if ArrayInterface is loaded vs just ArrayInterfaceCore.

@inline function index_labels(x, dim::CanonicalInt)
    # dimensions beyond `ndims(x)` are reported as a single unlabelled index
    if dim > ndims(x)
        return UnlabelledIndices(SOneTo(1))
    else
        return getfield(index_labels(x), Int(dim))
    end
end
@inline function index_labels(x)
    # forwarding wrappers defer to the data they wrap; everything else is unlabelled
    if is_forwarding_wrapper(x)
        return index_labels(buffer(x))
    else
        return map(UnlabelledIndices, axes(x))
    end
end
function index_labels(axis::LazyAxis{N}) where {N}
    # NOTE(review): `N` may be `:` for a `LazyAxis{:}`; confirm that
    # `index_labels(parent, :)` resolves to a valid dimension.
    wrapped = getfield(axis, :parent)
    return (index_labels(wrapped, N),)
end
5 changes: 5 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ end
return ntuple(Compat.Returns(:_), StaticInt(ndims(T)))
end
end
# A lazy axis carries exactly one dimension name: the first name of the parent type for a
# `:` axis, otherwise the `N`th name.
function known_dimnames(::Type{<:LazyAxis{:,P}}) where {P}
    return (first(known_dimnames(P)),)
end
function known_dimnames(::Type{<:LazyAxis{N,P}}) where {N,P}
    return (getfield(known_dimnames(P), N),)
end

@inline function known_dimnames(::Type{T}) where {T}
if is_forwarding_wrapper(T)
return known_dimnames(parent_type(T))
Expand Down Expand Up @@ -207,6 +210,8 @@ end
return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x)))
end
end
# A lazy axis carries exactly one dimension name: the first name of its parent for a `:`
# axis, otherwise the `N`th name.
function dimnames(x::LazyAxis{:,P}) where {P}
    parent_names = dimnames(getfield(x, :parent))
    return (first(parent_names),)
end
function dimnames(x::LazyAxis{N,P}) where {N,P}
    parent_names = dimnames(getfield(x, :parent))
    return (getfield(parent_names, N),)
end
@inline function dimnames(x::X) where {X}
if is_forwarding_wrapper(X)
return dimnames(parent(x))
Expand Down
26 changes: 24 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,33 @@ end
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
    # strictly greater than `i.x` means starting one past it, clamped to the first index
    lower = _add1(canonicalize(i.x))
    start = max(lower, static_first(x))
    return start:static_last(x)
end
# integer indexing
function to_index(x, i::AbstractArray{<:Integer})
    return i
end
function to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}})
    return i
end
function to_index(x, @nospecialize(i::StaticInt))
    return i
end
function to_index(x, i::Integer)
    return Int(i)
end
# any index type without a dedicated method is resolved through the `IndexStyle` of `x`
@inline function to_index(x, i)
    return to_index(IndexStyle(x), x, i)
end
# label indexing
function to_index(x, i::IndexLabel)
    # NOTE(review): when the first axis of `x` is unlabelled, no more specific method is
    # visible here; confirm that path ends in an error rather than recursing.
    labels = getfield(index_labels(x), 1)
    return to_index(labels, i)
end
function to_index(x::LabelledIndices, i::IndexLabel)
    target = getfield(i, :label)
    position = findfirst(==(target), parent(x))
    # delay throwing bounds-error if we didn't find label
    if position === nothing
        return typemin(Int)
    else
        return position
    end
end
# bare symbols, strings, characters and non-integer numbers are treated as labels
function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number})
    return to_index(x, IndexLabel(i))
end
# TODO there's probably a more efficient way of doing this
function to_index(x, i::LabelledIndices)
    return to_index(getfield(index_labels(x), 1), i)
end
# Look up every requested label individually so that the result follows the order (and
# repetitions) of the request rather than the order of `x`. A label that is not present
# maps to `typemin(Int)`, matching the scalar `IndexLabel` lookup, so that the failure
# surfaces later as a bounds error instead of being silently dropped.
function to_index(x::LabelledIndices, i::LabelledIndices)
    labels = parent(x)
    return map(parent(i)) do label
        position = findfirst(==(label), labels)
        position === nothing ? typemin(Int) : position
    end
end
# arrays of explicit `IndexLabel`s are resolved one element at a time
function to_index(x, ks::AbstractArray{<:IndexLabel})
    return [to_index(x, k) for k in ks]
end
# arrays of raw label values are wrapped in `LabelledIndices` and resolved as a group
function to_index(x, i::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}})
    requested = LabelledIndices(i)
    return to_index(x, requested)
end
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
    # comparisons against a label are evaluated on the labels of the first axis
    labels = getfield(index_labels(x), 1)
    return to_index(labels, i)
end
@inline function to_index(x::LabelledIndices, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
    # Rebuild the comparison around the raw label instead of the `IndexLabel` wrapper.
    # `Base.Fix2` is constructed explicitly because `isless` has no single-argument
    # (curried) method, unlike `<`, `<=`, `>` and `>=`.
    predicate = Base.Fix2(i.f, getfield(i.x, :label))
    return findall(predicate, parent(x))
end

# integer indexing
function to_index(S::IndexStyle, x, i)
throw(ArgumentError(
"invalid index: $S does not support indices of type $(typeof(i)) for instances of type $(typeof(x))."
Expand Down
20 changes: 20 additions & 0 deletions test/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,23 @@ if isdefined(Base, :ReshapedReinterpretArray)
@inferred(ArrayInterface.axes(fa)) isa ArrayInterface.axes_types(fa)
end
end

@testset "index_labels" begin
    labels = range(-10, 10, length=100)
    pixels = [(R = rand(), G = rand(), B = rand()) for _ in 1:100]
    colors = LabelledArray(pixels, (labels,))

    # the labelled dimension exposes the labels it was constructed with
    @test parent(@inferred(ArrayInterface.index_labels(colors))[1]) == labels
    @test parent(ArrayInterface.index_labels(colors, 1)) == labels
    # trailing dimensions and the unwrapped parent are unlabelled
    @test ArrayInterface.index_labels(colors, 2) == ArrayInterface.UnlabelledIndices(axes(colors, 2))
    @test ArrayInterface.index_labels(parent(colors)) == map(ArrayInterface.UnlabelledIndices, axes(colors))

    # indexing by a label, and by a comparison against a label
    third_label = ArrayInterface.getlabels(ArrayInterface.index_labels(colors)[1], 3)
    label = ArrayInterface.IndexLabel(third_label)
    @test @inferred(ArrayInterface.getindex(colors, label)) == colors[3]
    @test @inferred(ArrayInterface.getindex(colors, <=(label))) == colors[1:3]
end

#=
ArrayInterface.to_indices(colors, (label,))
axis = ArrayInterface.lazy_axes(colors)[1]
labels = ArrayInterface.index_labels(axis)[1]
ArrayInterface.to_index(labels, label)
=#
7 changes: 7 additions & 0 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ end
r4 = reinterpret(reshape, Float64, x)
w = Wrapper(x)
dnums = ntuple(+, length(d))
lz2 = ArrayInterface.lazy_axes(x)[2]
lzslice = ArrayInterface.LazyAxis{:}(x)

@test @inferred(ArrayInterface.has_dimnames(x)) == true
@test @inferred(ArrayInterface.has_dimnames(z)) == true
@test @inferred(ArrayInterface.has_dimnames(ones(2, 2))) == false
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) == false
@test @inferred(ArrayInterface.has_dimnames(typeof(x))) == true
@test @inferred(ArrayInterface.has_dimnames(typeof(view(x, :, 1, :)))) == true
@test @inferred(ArrayInterface.dimnames(x)) === d
@test @inferred(ArrayInterface.dimnames(lz2)) === (static(:y),)
@test @inferred(ArrayInterface.dimnames(lzslice)) === (static(:x),)
@test @inferred(ArrayInterface.dimnames(w)) === d
@test @inferred(ArrayInterface.dimnames(r1)) === d
@test @inferred(ArrayInterface.dimnames(r2)) === (static(:_), d...)
Expand Down Expand Up @@ -64,6 +69,8 @@ end
# multidimensional indices
@test @inferred(ArrayInterface.known_dimnames(view(x, ones(Int, 2, 2), 1))) === (:_, :_)
@test @inferred(ArrayInterface.known_dimnames(view(x, [CartesianIndex(1,1), CartesianIndex(1,1)]))) === (:_,)
@test @inferred(ArrayInterface.known_dimnames(lz2)) === (:y,)
@test @inferred(ArrayInterface.known_dimnames(lzslice)) === (:x,)

@test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y)
@test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y)
Expand Down
11 changes: 10 additions & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L
ArrayInterface.Static.known(L)
end

# Return the value stored in the wrapper's `parent` field.
function Base.parent(x::NamedDimsWrapper)
    return x.parent
end
# Test wrapper pairing a parent array with a collection of label vectors (`labels`) that
# `index_labels` wraps in `LabelledIndices`.
struct LabelledArray{T,N,P<:AbstractArray{T,N},L} <: ArrayInterface.AbstractArray2{T,N}
    parent::P
    labels::L

    function LabelledArray(p::P, labels::L) where {P,L}
        return new{eltype(P),ndims(p),P,L}(p, labels)
    end
end
function ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray})
    return true
end
function Base.parent(x::LabelledArray)
    return getfield(x, :parent)
end
function ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}}
    return P
end
function ArrayInterface.index_labels(x::LabelledArray)
    return map(ArrayInterface.LabelledIndices, getfield(x, :labels))
end

# Dummy array type with undetermined contiguity properties
struct DummyZeros{T,N} <: AbstractArray{T,N}
Expand Down