Skip to content

Commit

Permalink
Simplify initial index labels PR
Browse files Browse the repository at this point in the history
* removed support for nested traversal of `index_labels`. Support for
  this can be added later but it required making a lot of decisions
  about managing labels that up front that would be better addressed
  through iterative PRs
* removed `has_index_labels`. There are some odd corner cases for this
  one. Particularly for `SubArrays` where the presence of labels in the
  parent don't always have a clear way of propagating forward. Again,
  we can address this one but it will take some decisions about how
  labels are propagated.
* `UnlabelledIndices` and `LabelledIndices` are types that provide a
  more clear structure to what a label is and how they are accessed.
  • Loading branch information
Tokazama committed Sep 29, 2022
1 parent 122536f commit d7c0a27
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 155 deletions.
61 changes: 61 additions & 0 deletions lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,67 @@ struct IndexLabel{L} <: ArrayIndex{1}
label::L
end

"""
UnlabelledIndices(indices)
A set of indices that explicitly do not have any labels and cannot be accessed with an
[`IndexLabel`](@ref).
"""
struct UnlabelledIndices{I<:AbstractUnitRange{Int}} <: AbstractUnitRange{Int}
indices::I
end

"""
LabelledIndices(labels)
A subtype of `AbstractUnitRange{Int}` whose associeated with labels (`labels`).
`eachindex(labels)` are the indices for `LabelledIndices`.
"""
struct LabelledIndices{L<:AbstractVector} <: AbstractUnitRange{Int}
labels::L
end
# don't nest instances of `LabelledIndices`
LabelledIndices(labels::LabelledIndices) = labels

Base.parent(x::UnlabelledIndices) = getfield(x, :indices)
Base.parent(x::LabelledIndices) = getfield(x, :labels)

parent_type(@nospecialize T::Type{<:UnlabelledIndices}) = fieldtype(T, :indices)
parent_type(@nospecialize T::Type{<:LabelledIndices}) = fieldtype(T, :labels)

is_forwarding_wrapper(@nospecialize T::Type{<:Union{UnlabelledIndices,LabelledIndices}}) = true

Base.size(x::Union{UnlabelledIndices,LabelledIndices}) = size(parent(x))
Base.axes(x::Union{UnlabelledIndices,LabelledIndices}) = axes(parent(x))

Base.first(x::UnlabelledIndices) = first(parent(x))
Base.first(x::LabelledIndices) = firstindex(parent(x))

Base.last(x::UnlabelledIndices) = last(parent(x))
Base.last(x::LabelledIndices) = lastindex(parent(x))

"""
getlabels(x, idx)
Given a collection of labelled indices (`x`), the subset of lablled indices corresponding
to the index `idx` are returned.
"""
Base.@propagate_inbounds function getlabels(x::LabelledIndices, idx::I) where {I}
ndims_shape(I) === 0 ? parent(x)[idx] : LabelledIndices(parent(x)[idx])
end
function getlabels(x::UnlabelledIndices, idx::I) where {I}
@boundscheck checkbounds(parent(x), idx)
ndims_shape(I) === 0 ? nothing : UnlabelledIndices(eachindex(idx))
end

"""
setlabels!(x, idx, vals)
Given a collection of labelled indices (`x`), the subset of lablled indices corresponding
to the index `idx` are returned.
"""
Base.@propagate_inbounds setlabels!(x::LabelledIndices, i, v) = setindex!(parent(x), i, v)

_cartesian_index(i::Tuple{Vararg{Int}}) = CartesianIndex(i)
_cartesian_index(::Any) = nothing

Expand Down
4 changes: 2 additions & 2 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
issingular, isstructured, matrix_colors, restructure, lu_instance,
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel, defines_strides,
stride_preserving_index
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, IndexLabel,
LabelledIndices, getlabels, UnlabelledIndices, defines_strides, stride_preserving_index

# ArrayIndex subtypes and methods
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
Expand Down
102 changes: 4 additions & 98 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,112 +248,18 @@ lazy_axes(x, ::Colon) = LazyAxis{:}(x)
lazy_axes(x, ::StaticInt{dim}) where {dim} = ndims(x) < dim ? SOneTo{1}() : LazyAxis{dim}(x)
@inline lazy_axes(x, dims::Tuple) = map(Base.Fix1(lazy_axes, x), dims)

"""
has_index_labels(x) -> Bool
Returns `true` if `x` has has any index labels. If [`index_labels`](@ref) returns a tuple of
`nothing`, this will be `false`.
See also: [`index_labels`](@ref)
"""
has_index_labels(x) = _any_labels(index_labels(x))
function has_index_labels(x::Union{Base.NonReshapedReinterpretArray,Transpose,Adjoint,PermutedDimsArray,Symmetric,Hermitian})
has_index_labels(parent(x))
end
function has_index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
if has_index_labels(parent(x))
true
else
size1 = div(sizeof(S), sizeof(T))
size1 > 1 && size1 === fieldcount(S)
end
end
function has_index_labels(x::SubArray)
if has_index_labels(parent(x))
return true
else
inds = x.indices
for i in 1:nfields(inds)
has_index_labels(getfield(inds, i)) && return true
end
return false
end
end
_any_labels(@nospecialize labels::Tuple{Vararg{Nothing}}) = false
_any_labels(@nospecialize labels::Tuple{Vararg{Any}}) = true

"""
index_labels(x)
index_labels(x, dim)
Returns a tuple of labels assigned to each axis or a collection of labels corresponding to
each index along `dim` of `x`. Default is to simply return `nothing`.
See also: [`has_index_labels`](@ref)
each index along `dim` of `x`. Default is to return `UnlabelledIndices(axes(x, dim))`.
"""
index_labels(x, dim) = index_labels(x, to_dims(x, dim))
index_labels(@nospecialize x::Number) = ()
@inline function index_labels(x, dim::CanonicalInt)
dim > ndims(x) ? nothing : getfield(index_labels(x), Int(dim))
dim > ndims(x) ? UnlabelledIndices(SOneTo(1)) : getfield(index_labels(x), Int(dim))
end
@inline function index_labels(x)
if is_forwarding_wrapper(x)
index_labels(buffer(x))
else
ntuple(Returns(nothing), Val{ndims(x)}())
end
end
function index_labels(x::Union{MatAdjTrans,PermutedDimsArray})
map(GetIndex{false}(index_labels(parent(x))), to_parent_dims(x))
end
index_labels(x::VecAdjTrans) = (nothing, getfield(index_labels(parent(x)), 1))
function index_labels(x::SubArray)
labels = index_labels(parent(x))
inds = x.indices
info = IndicesInfo(x)
pdims = parentdims(info)
cdims = childdims(info)
flatten_tuples(ntuple(Val{nfields(pdims)}()) do i
pdim_i = getfield(pdims, i)
cdim_i = getfield(cdims, i)
index = getfield(inds, i)
if pdim_i isa Tuple || cdim_i isa Tuple # no direct mapping to parent axes
index_labels(index)
elseif cdim_i === 0 # integer indexing drops axes
()
elseif pdim_i === 0 # trailing dimension
nothing
elseif index isa Base.Slice # index into labels where there is direct mapping to parent axis
(getfield(labels, pdim_i),)
else
labels_i = getfield(labels, pdim_i)
labels_i === nothing ? index_labels(index) : (@inbounds(labels_i[index]),)
end
end)
end
index_labels(x::Union{LinearIndices,CartesianIndices}) = map(first index_labels, x.indices)
index_labels(x::Union{Symmetric,Hermitian}) = index_labels(parent(x))
index_labels(@nospecialize(x::LazyAxis{:})) = (nothing,)
index_labels(x::LazyAxis{N}) where {N} = (getfield(index_labels(getfield(x, :parent)), N),)
@inline @inline function index_labels(x::Base.NonReshapedReinterpretArray{T,N,S}) where {T,N,S}
if sizeof(T) === sizeof(S)
return index_labels(parent(x))
else
return (nothing, Base.tail(index_labels(parent(x)))...)
end
end
function index_labels(x::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
_reinterpret_index_labels(div(StaticInt(sizeof(S)), StaticInt(sizeof(T))), x)
end
@inline function _reinterpreted_fieldnames(@nospecialize T::Type{<:Base.ReshapedReinterpretArray})
S = eltype(parent_type(T))
isstructtype(S) ? fieldnames(S) : ()
end
function _reinterpret_index_labels(s::StaticInt{N}, x::Base.ReshapedReinterpretArray) where {N}
__reinterpret_index_labels(s, _reinterpreted_fieldnames(typeof(x)), index_labels(parent(x)))
end
@inline function __reinterpret_index_labels(::StaticInt{N}, fields::NTuple{M,Symbol}, ks::Tuple) where {N,M}
N === M ? (fields, ks...,) : (nothing, ks...,)
is_forwarding_wrapper(x) ? index_labels(buffer(x)) : map(UnlabelledIndices, axes(x))
end
_reinterpret_index_labels(::StaticInt{1}, x::Base.ReshapedReinterpretArray) = index_labels(parent(x))
_reinterpret_index_labels(::StaticInt{0}, x::Base.ReshapedReinterpretArray) = Base.tail(index_labels(parent(x)))
index_labels(axis::LazyAxis{N}) where {N} = (index_labels(getfield(axis, :parent), N),)
4 changes: 0 additions & 4 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ function _init_dimsmap(@nospecialize info::IndicesInfo)
ntuple(i -> static(getfield(cdims, i)), length(pdims))
end

parentdims(::IndicesInfo{<:Any,pdims}) where {pdims} = pdims

childdims(::IndicesInfo{<:Any,<:Any,cdims}) where {cdims} = cdims

"""
to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{StaticInt,Tuple{Vararg{StaticInt}}}}}
Expand Down
29 changes: 16 additions & 13 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,33 @@ end
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
max(canonicalize(i.x), static_first(x)):static_last(x)
end
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
findall(i.f(i.x.label), first(index_labels(x)))
end
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
end
to_index(x, i::AbstractArray{<:Union{Base.BitInteger,StaticInt}}) = i
to_index(x, @nospecialize(i::StaticInt)) = i
to_index(x, i::Integer) = Int(i)
@inline to_index(x, i) = to_index(IndexStyle(x), x, i)
# key indexing
function to_index(x, i::IndexLabel)
index = findfirst(==(getfield(i, :label)), first(index_labels(x)))
# label indexing
to_index(x, i::IndexLabel) = to_index(getfield(index_labels(x), 1), i)
function to_index(x::LabelledIndices, i::IndexLabel)
index = findfirst(==(getfield(i, :label)), parent(x))
# delay throwing bounds-error if we didn't find label
index === nothing ? offset1(x) - 1 : index
end
function to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number})
index = findfirst(==(i), getfield(index_labels(x), 1))
index === nothing ? offset1(x) - 1 : index
index === nothing ? typemin(Int) : index
end
to_index(x, i::Union{Symbol,AbstractString,AbstractChar,Number}) = to_index(x, IndexLabel(i))
# TODO there's probably a more efficient way of doing this
to_index(x, i::LabelledIndices) = to_index(getfield(index_labels(x), 1), i)
to_index(x::LabelledIndices, i::LabelledIndices) = findall(in(parent(i)), parent(x))
to_index(x, ks::AbstractArray{<:IndexLabel}) = [to_index(x, k) for k in ks]
function to_index(x, ks::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}})
[to_index(x, k) for k in ks]
function to_index(x, i::AbstractArray{<:Union{Symbol,AbstractString,AbstractChar,Number}})
to_index(x, LabelledIndices(i))
end
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
to_index(getfield(index_labels(x), 1), i)
end
@inline function to_index(x::LabelledIndices, i::Base.Fix2{<:Union{typeof(>),typeof(>=),typeof(<=),typeof(<),typeof(isless)},<:IndexLabel})
findall(i.f(i.x.label), parent(x))
end

# integer indexing
Expand Down
53 changes: 16 additions & 37 deletions test/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,41 +107,20 @@ end

@testset "index_labels" begin
colors = LabelledArray([(R = rand(), G = rand(), B = rand()) for i 1:100], (range(-10, 10, length=100),));
caxis = ArrayInterface.LazyAxis{1}(colors);
colormat = reinterpret(reshape, Float64, colors);
cmat_view1 = view(colormat, :, 4);
cmat_view2 = view(colormat, :, 4:7);
cmat_view3 = view(colormat, 2:3,:);
absym_abstr = LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],));

@test @inferred(ArrayInterface.index_labels(colors)) == (range(-10, 10, length=100),)
@test @inferred(ArrayInterface.index_labels(caxis)) == (range(-10, 10, length=100),)
@test ArrayInterface.index_labels(view(colors, :, :), 2) === nothing
@test @inferred(ArrayInterface.index_labels(LinearIndices((caxis,)))) == (range(-10, 10, length=100),)
@test @inferred(ArrayInterface.index_labels(colormat)) == ((:R, :G, :B), range(-10, 10, length=100))
@test @inferred(ArrayInterface.index_labels(colormat')) == (range(-10, 10, length=100), (:R, :G, :B))
@test @inferred(ArrayInterface.index_labels(cmat_view1)) == ((:R, :G, :B),)
@test @inferred((ArrayInterface.index_labels(cmat_view2))) == ((:R, :G, :B), -9.393939393939394:0.20202020202020202:-8.787878787878787)
@test @inferred((ArrayInterface.index_labels(view(colormat, 1, :)'))) == (nothing, range(-10, 10, length=100))
# can't infer this b/c tuple is being indexed by range
@test ArrayInterface.index_labels(cmat_view3) == ((:G, :B), -10.0:0.20202020202020202:10.0)
@test @inferred(ArrayInterface.index_labels(Symmetric(view(colormat, :, 1:3)))) == ((:R, :G, :B), -10.0:0.20202020202020202:-9.595959595959595)

@test @inferred(ArrayInterface.index_labels(reinterpret(Int8, absym_abstr))) == (nothing, ["a", "b"])
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int8, absym_abstr))) == (nothing, [:a, :b], ["a", "b"])
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Int64, LabelledArray(rand(Int32, 2,2), ([:a, :b], ["a", "b"],))))) == (["a", "b"],)
@test @inferred(ArrayInterface.index_labels(reinterpret(reshape, Float64, LabelledArray(rand(Int64, 2,2), ([:a, :b], ["a", "b"],))))) == ([:a, :b], ["a", "b"],)
@test @inferred(ArrayInterface.index_labels(reinterpret(Float64, absym_abstr))) == ([:a, :b], ["a", "b"],)

@test ArrayInterface.has_index_labels(colors)
@test ArrayInterface.has_index_labels(caxis)
@test ArrayInterface.has_index_labels(colormat)
@test ArrayInterface.has_index_labels(cmat_view1)
@test !ArrayInterface.has_index_labels(view(colors, :, :))

@test @inferred(ArrayInterface.getindex(colormat, :R, :)) == colormat[1, :]
@test @inferred(ArrayInterface.getindex(cmat_view1, :R)) == cmat_view1[1]
@test @inferred(ArrayInterface.getindex(colormat, :,ArrayInterface.IndexLabel(-9.595959595959595))) == colormat[:, 3]
@test @inferred(ArrayInterface.getindex(colormat, :,<=(ArrayInterface.IndexLabel(-9.595959595959595)))) == colormat[:, 1:3]
@test @inferred(ArrayInterface.getindex(absym_abstr, :, ["a"])) == absym_abstr[:,[1]]

@test parent(@inferred(ArrayInterface.index_labels(colors))[1]) == range(-10, 10, length=100)
@test parent(ArrayInterface.index_labels(colors, 1)) == range(-10, 10, length=100)
@test ArrayInterface.index_labels(colors, 2) == ArrayInterface.UnlabelledIndices(axes(colors, 2))
@test ArrayInterface.index_labels(parent(colors)) == map(ArrayInterface.UnlabelledIndices, axes(colors))

label = ArrayInterface.IndexLabel(ArrayInterface.getlabels(ArrayInterface.index_labels(colors)[1], 3))
@test @inferred(ArrayInterface.getindex(colors, label)) == colors[3]
@test @inferred(ArrayInterface.getindex(colors, <=(label))) == colors[1:3]
end

#=
ArrayInterface.to_indices(colors, (label,))
axis = ArrayInterface.lazy_axes(colors)[1]
labels = ArrayInterface.index_labels(axis)[1]
ArrayInterface.to_index(labels, label)
=#
2 changes: 1 addition & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
ArrayInterface.is_forwarding_wrapper(::Type{<:LabelledArray}) = true
Base.parent(x::LabelledArray) = getfield(x, :parent)
ArrayInterface.parent_type(::Type{T}) where {P,T<:LabelledArray{<:Any,<:Any,P}} = P
ArrayInterface.index_labels(x::LabelledArray) = getfield(x, :labels)
ArrayInterface.index_labels(x::LabelledArray) = map(ArrayInterface.LabelledIndices, getfield(x, :labels))

# Dummy array type with undetermined contiguity properties
struct DummyZeros{T,N} <: AbstractArray{T,N}
Expand Down

0 comments on commit d7c0a27

Please sign in to comment.