Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Add SortedSets and TagSets #1204

Merged
merged 12 commits into from
Oct 4, 2023
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
5 changes: 5 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Compat
using Dictionaries
using FLoops
using Folds
using InlineStrings
using Random
using LinearAlgebra
using StaticArrays
Expand All @@ -19,6 +20,10 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
using .TagSets

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down
214 changes: 214 additions & 0 deletions NDTensors/src/SortedSets/src/SortedSets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
module SortedSets
using Dictionaries

using Base: @propagate_inbounds
using Base.Order: Ordering, Forward
using Random

import Dictionaries:
istokenizable,
tokentype,
iteratetoken,
iteratetoken_reverse,
gettoken,
gettokenvalue,
isinsertable,
gettoken!,
empty_type,
deletetoken!,
randtoken

export SortedSet
# TODO:
# Make an `AbstractSortedIndices`? Is that needed?
# Define specialized implementations for:
#
# Base.union
# Base.intersect
# Base.setdiff
# Base.symdiff
# Base.sort
#
# which can be dispatched on `SmallVector` for faster operations,
# potentially using a trait (`HasFastCopy`?).

"""
SortedIndices(iter)

Construct a `SortedIndices <: AbstractIndices` from an arbitrary Julia iterable with unique
elements. Lookup exploits the fact that the elements are stored in sorted order.

`SortedIndices` can be faster than `ArrayIndices`, which uses a naive search that may be
optimal for small collections. Larger collections are better handled by containers like `Indices`.
"""
struct SortedIndices{I,Inds<:AbstractArray{I},SortKwargs<:NamedTuple} <: AbstractIndices{I}
  # Storage holding the elements.
  inds::Inds
  # The ordering keywords (`lt`, `by`, `rev`, `order`) the storage is sorted under;
  # they are splatted into the search routines used for lookups.
  sort_kwargs::SortKwargs

  # Inner constructor: wraps `a` without copying it. `checkunique` and `checksorted`
  # toggle the assertions on the uniqueness and sortedness of `a`.
  @inline function SortedIndices{I,Inds}(
    a::Inds;
    lt=isless,
    by=identity,
    rev::Bool=false,
    order::Ordering=Forward,
    checksorted::Bool=true,
    checkunique::Bool=true,
  ) where {I,Inds<:AbstractArray{I}}
    # Uniqueness is asserted first, then the ordering.
    checkunique && @assert allunique(Iterators.map(by, a))
    checksorted && @assert issorted(a; lt, by, rev, order)
    ordering = (; lt, by, rev, order)
    return new{I,Inds,typeof(ordering)}(a, ordering)
  end
end

# `SortedSet` is the name listed in this module's `export` statement, so it must be
# defined here; it is an alias of `SortedIndices`.
const SortedSet{I,Inds,SortKwargs} = SortedIndices{I,Inds,SortKwargs}
# The original (misspelled) alias is retained so existing references keep working.
const SortSet{I,Inds,SortKwargs} = SortedSet{I,Inds,SortKwargs}

# Empty-set constructors. With no element type given, `Any` is used; with no storage
# type given, a `Vector{I}` is used.
@propagate_inbounds function SortedIndices()
  return SortedIndices{Any}(Any[])
end
@propagate_inbounds function SortedIndices{I}() where {I}
  return SortedIndices{I,Vector{I}}(I[])
end
@propagate_inbounds function SortedIndices{I,Inds}() where {I,Inds}
  return SortedIndices{I}(Inds())
end

# Construction from a general iterable: the elements are collected into an array first.
@propagate_inbounds function SortedIndices(iter)
  return SortedIndices(collect(iter))
end
@propagate_inbounds function SortedIndices{I}(iter) where {I}
  return SortedIndices{I}(collect(I, iter))
end

# Construction from an array: the array itself is used as the storage and the
# storage type parameter is taken from it.
@propagate_inbounds function SortedIndices(a::AbstractArray{I}) where {I}
  return SortedIndices{I}(a)
end
@propagate_inbounds function SortedIndices{I}(a::AbstractArray{I}) where {I}
  return SortedIndices{I,typeof(a)}(a)
end

# Conversions involving generic `AbstractIndices`. Each method fills in one missing
# type parameter and forwards to the next more specific method.
Base.convert(::Type{AbstractIndices{I}}, inds::SortedIndices) where {I} =
  convert(SortedIndices{I}, inds)
Base.convert(::Type{SortedIndices}, inds::AbstractIndices{I}) where {I} =
  convert(SortedIndices{I}, inds)
Base.convert(::Type{SortedIndices{I}}, inds::AbstractIndices) where {I} =
  convert(SortedIndices{I,Vector{I}}, inds)

# Fully specified target: the elements of `inds` are collected (in iteration order)
# with element type `I`, converted to the storage type `Inds`, and wrapped. The
# constructor runs its default uniqueness and sortedness assertions on the result.
function Base.convert(
  ::Type{SortedIndices{I,Inds}}, inds::AbstractIndices
) where {I,Inds<:AbstractArray{I}}
  storage = convert(Inds, collect(I, inds))
  return @inbounds SortedIndices{I,typeof(storage)}(storage)
end

# Converting to the element type the set already has returns the set itself.
function Base.convert(::Type{SortedIndices{I}}, inds::SortedIndices{I}) where {I}
  return inds
end

# Only the element type was requested: reuse the storage type of `inds`.
Base.convert(
  ::Type{SortedIndices{I}}, inds::SortedIndices{<:Any,Inds}
) where {I,Inds<:AbstractArray{I}} = convert(SortedIndices{I,Inds}, inds)

# Element and storage types already match: no conversion, the same object is returned.
Base.convert(
  ::Type{SortedIndices{I,Inds}}, inds::SortedIndices{I,Inds}
) where {I,Inds<:AbstractArray{I}} = inds
# Convert a `SortedIndices` to one with element type `I` and storage type `Inds`.
#
# The storage of `inds` is converted to `Inds` and wrapped in a new `SortedIndices`.
# The ordering keywords stored in `inds` are forwarded to the constructor so that the
# result keeps the ordering of the source. Without this, a set built with, e.g.,
# `rev=true` would be validated against the default ordering and fail the sortedness
# assertion. The constructor's assertions stay enabled because element conversion is
# not guaranteed here to preserve uniqueness or order.
function Base.convert(
  ::Type{SortedIndices{I,Inds}}, inds::SortedIndices
) where {I,Inds<:AbstractArray{I}}
  a = convert(Inds, parent(inds))
  return @inbounds SortedIndices{I,Inds}(a; inds.sort_kwargs...)
end

# Return the storage array wrapped by `inds` (the `inds` field).
function Base.parent(inds::SortedIndices)
  return getfield(inds, :inds)
end

# Basic interface
# Iteration is delegated to the storage array, so elements are visited in storage order.
@propagate_inbounds Base.iterate(i::SortedIndices{I}, state...) where {I} =
  iterate(parent(i), state...)

# Membership test: a sorted lookup (`insorted`) in the storage, using the ordering
# keywords the set was constructed with.
function Base.in(i::I, inds::SortedIndices{I}) where {I}
  return insorted(i, parent(inds); inds.sort_kwargs...)
end
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
# The set has a known length, which is the length of its storage array.
function Base.IteratorSize(::SortedIndices)
  return Base.HasLength()
end
function Base.length(inds::SortedIndices)
  return length(parent(inds))
end

# Token interface: tokens are `Int` linear indices into the storage array.
function istokenizable(i::SortedIndices)
  return true
end
function tokentype(::SortedIndices)
  return Int
end
# Forward token iteration walks the linear indices of the storage.
@inline function iteratetoken(inds::SortedIndices, s...)
  return iterate(LinearIndices(parent(inds)), s...)
end
# Start reverse token iteration: returns `nothing` for an empty set, otherwise the
# last linear index of the storage as both the token and the iteration state.
@inline function iteratetoken_reverse(inds::SortedIndices)
  li = LinearIndices(parent(inds))
  isempty(li) && return nothing
  t = last(li)
  return (t, t)
end
# Continue reverse token iteration from state `t`: step back by one and stop
# (return `nothing`) once the step falls before the first linear index of the storage.
@inline function iteratetoken_reverse(inds::SortedIndices, t)
  prev = t - 1
  return prev < first(LinearIndices(parent(inds))) ? nothing : (prev, prev)
end

# Look up the token of `i`. Returns `(true, token)` when `i` is present and
# `(false, 0)` otherwise. The search uses the ordering keywords stored in the set.
@inline function gettoken(inds::SortedIndices, i)
  r = searchsorted(parent(inds), i; inds.sort_kwargs...)
  @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique
  return isempty(r) ? (false, 0) : (true, convert(Int, only(r)))
end
# The element for token `x` is the storage entry at index `x`.
@propagate_inbounds function gettokenvalue(inds::SortedIndices, x::Int)
  return parent(inds)[x]
end

# Every `SortedIndices` is reported as insertable, regardless of its storage type.
function isinsertable(i::SortedIndices)
  # Need an array trait here...
  return true
end

## # For `SmallVector`
## # TODO: Make this more general, based on a trait?
## isinsertable(i::SortedIndices{<:Any,<:SmallVector}) = false

# Find the token of `i`, inserting `i` at its sorted position if it is absent.
#
# Returns `(true, token)` when `i` was already present and `(false, token)` when it
# was inserted; in both cases `token` is the storage index where `i` now lives.
# `values` is a collection of value arrays kept in step with the keys: on insertion
# each is grown by one slot *at the insertion position* so existing values stay
# aligned with their keys. The new slot's content is left for the caller to set.
@inline function gettoken!(inds::SortedIndices{I}, i::I, values=()) where {I}
  a = parent(inds)
  r = searchsorted(a, i; inds.sort_kwargs...)
  @assert 0 ≤ length(r) ≤ 1 # If > 1, means the elements are not unique
  if isempty(r)
    # For a missing element, `searchsorted` returns an empty range located at the
    # insertion point, so `first(r)` is where `i` belongs.
    pos = convert(Int, first(r))
    insert!(a, pos, i)
    foreach(values) do v
      resize!(v, length(v) + 1)
      # Shift the tail one slot to the right to open a gap at `pos`.
      ntail = length(v) - pos
      ntail > 0 && copyto!(v, pos + 1, v, pos, ntail)
    end
    # The token is the position of the inserted element, not the last index.
    return (false, pos)
  end
  return (true, convert(Int, only(r)))
end

# Remove the element at token `x` from the storage and the entry at the same
# position from every array in `values`. Returns `inds`.
@inline function deletetoken!(inds::SortedIndices, x::Int, values=())
  deleteat!(parent(inds), x)
  for v in values
    deleteat!(v, x)
  end
  return inds
end

# Remove all elements from the storage and from every array in `values`. Returns `inds`.
function Base.empty!(inds::SortedIndices, values=())
  empty!(parent(inds))
  for v in values
    empty!(v)
  end
  return inds
end

# TODO: Make into `MSmallVector`?
# The type used for an empty set with element type `I` is always `Vector`-backed,
# whatever the storage type of the original set.
function empty_type(::Type{<:SortedIndices}, ::Type{I}) where {I}
  return SortedIndices{I,Vector{I}}
end

# Copy `inds` into a new `SortedIndices` with element type `I`.
#
# The ordering keywords stored in `inds` are passed on to the new set. Previously
# they were dropped, so copying a set built with a non-default ordering (e.g.
# `rev=true`) re-validated the data against the default ordering and failed the
# sortedness assertion.
function Base.copy(inds::SortedIndices, ::Type{I}) where {I}
  if I === eltype(inds)
    a = copy(parent(inds))
    # Same elements, same ordering: the data comes from an existing set, so the
    # uniqueness and sortedness checks are skipped.
    return SortedIndices{I,typeof(a)}(
      a; inds.sort_kwargs..., checksorted=false, checkunique=false
    )
  else
    a = convert(AbstractArray{I}, parent(inds))
    # The element type changes, so the constructor's checks are kept enabled.
    return SortedIndices{I,typeof(a)}(a; inds.sort_kwargs...)
  end
end

# TODO: Can this take advantage of sorting?
# Filter the storage in place with `pred` and return `inds`.
function Base.filter!(pred, inds::SortedIndices)
  storage = parent(inds)
  filter!(pred, storage)
  return inds
end

# Draw a random token, i.e. a random key (index) of the storage array, using `rng`.
randtoken(rng::Random.AbstractRNG, inds::SortedIndices) = rand(rng, keys(parent(inds)))

# Sort `inds` in place and return it.
#
# The storage is already sorted according to the ordering keywords stored in the
# set, so a call without keywords is a no-op. Previously the storage was re-sorted
# with the *default* ordering, which broke sets constructed with a non-default
# ordering (subsequent lookups search with the stored keywords).
function Base.sort!(inds::SortedIndices; kwargs...)
  isempty(kwargs) && return inds
  # NOTE(review): explicit keywords are still applied to the storage as before. If
  # they describe an ordering different from `inds.sort_kwargs`, the storage no longer
  # matches the ordering used for lookups — confirm whether this should be rejected.
  sort!(parent(inds); kwargs...)
  return inds
end
end
114 changes: 114 additions & 0 deletions NDTensors/src/TagSets/src/TagSets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
module TagSets
using Dictionaries

## import Dictionaries:
## istokenizable,
## tokentype,
## iteratetoken,
## iteratetoken_reverse,
## gettoken,
## gettokenvalue,
## isinsertable,
## gettoken!,
## empty_type,
## deletetoken!,
## randtoken

using Base: @propagate_inbounds

## using NDTensors.SmallVectors
## using InlineStrings

## using NDTensors.SmallVectors: AbstractSmallVector, buffer

# A sorted collection of unique tags of type `T`.
# Add `skipchars` (see `skipmissing`) and `delim` for delimiter?
# https://docs.julialang.org/en/v1/base/strings/#Base.strip
# https://docs.julialang.org/en/v1/stdlib/DelimitedFiles/#Delimited-Files
# Add a `Bool` param for bounds checking/ignoring overflow/spillover?
# TODO: Make `S` a first argument, hardcode `SmallVector` storage?
# https://juliacollections.github.io/DataStructures.jl/v0.9/sorted_containers.html
# https://github.com/JeffreySarnoff/SortingNetworks.jl
# https://github.com/vvjn/MergeSorted.jl
# https://bkamins.github.io/julialang/2023/08/25/infiltrate.html
# https://github.com/Jutho/TensorKit.jl/blob/master/src/auxiliary/dicts.jl
# https://github.com/tpapp/SortedVectors.jl
# https://discourse.julialang.org/t/special-purpose-subtypes-of-arrays/20327
# https://discourse.julialang.org/t/all-the-ways-to-group-reduce-sorted-vectors-ideas/45239
# https://discourse.julialang.org/t/sorting-a-vector-of-fixed-size/71766
# `TagSet{T,D}` wraps `data`, an `AbstractIndices{T}` of type `D`, and is itself an
# `AbstractIndices{T}`. The methods defined below forward to the wrapped collection,
# which is retrieved with `parent`.
struct TagSet{T,D<:AbstractIndices{T}} <: AbstractIndices{T}
data::D
end

# Build a `TagSet` from a vector by first wrapping the vector's elements in `Indices`.
function TagSet(vec::AbstractVector)
  return TagSet(Indices(vec))
end

# Field accessors
# Return the wrapped collection (the `data` field).
function Base.parent(tags::TagSet)
  return getfield(tags, :data)
end

# AbstractIndices interface
# Iteration is forwarded to the wrapped collection.
@propagate_inbounds Base.iterate(tags::TagSet, state...) = iterate(parent(tags), state...)

# `I` is needed to avoid ambiguity error.
# Membership, iterator size and length all defer to the wrapped collection.
function Base.in(tag::I, tags::TagSet{I}) where {I}
  return in(tag, parent(tags))
end
function Base.IteratorSize(::TagSet)
  return Base.HasLength()
end
function Base.length(tags::TagSet)
  return length(parent(tags))
end

# Token interface. A `TagSet` has no tokens of its own: every token it hands out is a
# token of the wrapped collection, since `gettoken`, `gettokenvalue`, `gettoken!` and
# `deletetoken!` below all forward to the parent.
Dictionaries.istokenizable(i::TagSet) = true

# The token type therefore follows the wrapped collection instead of being fixed.
Dictionaries.tokentype(tags::TagSet) = Dictionaries.tokentype(parent(tags))

# Forward token iteration to the parent's *token* iteration. Previously this forwarded
# to `iterate`, which yields the parent's elements rather than its tokens. The call is
# module-qualified because only `using Dictionaries` is in effect in this module.
@inline function Dictionaries.iteratetoken(inds::TagSet, s...)
  return Dictionaries.iteratetoken(parent(inds), s...)
end
# Reverse token iteration is forwarded to the wrapped collection.
#
# The forwarded calls are module-qualified: this module only does `using Dictionaries`
# (the explicit `import Dictionaries: ...` list above is commented out), and defining
# a method as `Dictionaries.iteratetoken_reverse` does not bring the bare name into
# scope, so an unqualified call would hit an undefined name unless Dictionaries
# exports it.
@inline function Dictionaries.iteratetoken_reverse(inds::TagSet)
  return Dictionaries.iteratetoken_reverse(parent(inds))
end
@inline function Dictionaries.iteratetoken_reverse(inds::TagSet, t)
  return Dictionaries.iteratetoken_reverse(parent(inds), t)
end

# Token lookup is forwarded to the wrapped collection.
@inline Dictionaries.gettoken(inds::TagSet, i) = gettoken(parent(inds), i)
# The element for token `x` is read from the wrapped collection.
@propagate_inbounds function Dictionaries.gettokenvalue(inds::TagSet, x)
  return gettokenvalue(parent(inds), x)
end
mtfishman marked this conversation as resolved.
Show resolved Hide resolved

# Every `TagSet` is reported as insertable, regardless of the wrapped collection.
function Dictionaries.isinsertable(tags::TagSet)
  # Need an array trait here...
  return true
end

# Specify `I` to fix ambiguity error.
# Find-or-insert is forwarded to the wrapped collection, including the `values`
# arrays that are kept in step with it.
@inline Dictionaries.gettoken!(tags::TagSet{I}, i::I, values=()) where {I} =
  gettoken!(parent(tags), i, values)

# Delete the element at token `x` from the wrapped collection (and from `values`),
# then return the `TagSet` itself rather than the parent.
@inline function Dictionaries.deletetoken!(tags::TagSet, x, values=())
  data = parent(tags)
  deletetoken!(data, x, values)
  return tags
end

# Remove all elements from the wrapped collection and from every array in `values`.
# Returns `inds`.
#
# `values` was previously accepted but ignored, which would leave the value arrays
# populated after their keys were removed. They are now emptied as well, in the same
# way as the `SortedIndices` method.
function Base.empty!(inds::TagSet, values=())
  empty!(parent(inds))
  foreach(empty!, values)
  return inds
end

# For an unchanged element type `I`, the empty container type is the `TagSet` type itself.
function Dictionaries.empty_type(::Type{TagSet{I,D}}, ::Type{I}) where {I,D}
  return TagSet{I,D}
end

# Not defined to be part of the `AbstractIndices` interface,
# but seems to be needed.
# Filter the wrapped collection in place with `pred` and return the `TagSet`.
function Base.filter!(pred, inds::TagSet)
  data = parent(inds)
  filter!(pred, data)
  return inds
end

# Copy the wrapped collection with the requested element type and wrap the copy in
# a new `TagSet`.
function Base.copy(tags::TagSet, elt::Type)
  data = copy(parent(tags), elt)
  return TagSet(data)
end

# TagSet interface
# Return the union of `tags` and `items`.
function addtags(tags::TagSet, items)
  return union(tags, items)
end
# Return `tags` with every element of `items` removed.
function removetags(tags::TagSet, items)
  return setdiff(tags, items)
end
# Replace the tags in `rem` by the tags in `add`.
#
# The replacement only happens when removing `rem` shrinks the set by exactly
# `length(rem)` elements; otherwise `tags` is returned unchanged.
function replacetags(tags::TagSet, rem, add)
  kept = setdiff(tags, rem)
  # Not all are removed, no replacement
  length(tags) == length(kept) + length(rem) || return tags
  return union(kept, add)
end
end
Loading