Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Add SmallVectors submodule #1202

Merged
merged 13 commits into from
Sep 29, 2023
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("SmallVectors/src/SmallVectors.jl")
using .SmallVectors

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down
18 changes: 18 additions & 0 deletions NDTensors/src/SmallVectors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SmallVectors

A module that defines small (mutable and immutable) vectors with a maximum length. Externally the have a dynamic (or in the case of immuatable vectors, runtime) length, but internally they are backed by a statically sized vector. This makes it so that operations can be performed faster because they can remain on the stack, but it provides some more convenience compared to StaticArrays.jl where the length is encoded in the type.

For example:
```julia
using NDTensors.SmallVectors
v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10
v = push(v, 4)
v = setindex(v, 4, 4)
v = sort(v; rev=true)

mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10
push!(mv, 4)
mv[2] = 12
sort!(mv; rev=true)
```
This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length.
16 changes: 16 additions & 0 deletions NDTensors/src/SmallVectors/src/SmallVectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module SmallVectors
using StaticArrays
mtfishman marked this conversation as resolved.
Show resolved Hide resolved

export SmallVector, MSmallVector, SubSmallVector
mtfishman marked this conversation as resolved.
Show resolved Hide resolved

struct NotImplemented <: Exception
msg::String
end
NotImplemented() = NotImplemented("Not implemented.")
mtfishman marked this conversation as resolved.
Show resolved Hide resolved

include("abstractsmallvector/abstractsmallvector.jl")
include("abstractsmallvector/deque.jl")
include("msmallvector/msmallvector.jl")
include("smallvector/smallvector.jl")
include("subsmallvector/subsmallvector.jl")
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
A vector with a fixed maximum length, backed by a fixed size buffer.
"""
abstract type AbstractSmallVector{T} <: AbstractVector{T} end

# Required buffer interface
buffer(vec::AbstractSmallVector) = throw(NotImplemented())

similar_type(vec::AbstractSmallVector) = typeof(vec)

# Required buffer interface
maxlength(vec::AbstractSmallVector) = length(buffer(vec))

# Required AbstractArray interface
Base.size(vec::AbstractSmallVector) = throw(NotImplemented())

# Derived AbstractArray interface
function Base.getindex(vec::AbstractSmallVector, index::Integer)
return throw(NotImplemented())
end
function Base.setindex!(vec::AbstractSmallVector, item, index::Integer)
return throw(NotImplemented())
end
Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear()
280 changes: 280 additions & 0 deletions NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# TODO: Set operations
# union, ∪, union!
# intersect, ∩, intersect!
# setdiff, setdiff!
# symdiff, symdiff!
# unique, unique!

Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented())

@inline function resize(vec::AbstractSmallVector, len)
mvec = Base.copymutable(vec)
resize!(mvec, len)
return convert(similar_type(vec), mvec)
end

@inline function Base.empty!(vec::AbstractSmallVector)
resize!(vec, 0)
return vec
end

@inline function empty(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
empty!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer)
@boundscheck checkbounds(vec, index)
mvec = Base.copymutable(vec)
@inbounds mvec[index] = item
return convert(similar_type(vec), mvec)
end

@inline function Base.push!(vec::AbstractSmallVector, item)
resize!(vec, length(vec) + 1)
@inbounds vec[length(vec)] = item
return vec
end

@inline function StaticArrays.push(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
push!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.pop!(vec::AbstractSmallVector)
resize!(vec, length(vec) - 1)
return vec
end

@inline function StaticArrays.pop(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
pop!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function Base.pushfirst!(vec::AbstractSmallVector, item)
insert!(vec, firstindex(vec), item)
return vec
end

# Don't `@inline`, makes it slower.
function StaticArrays.pushfirst(vec::AbstractSmallVector, item)
mvec = Base.copymutable(vec)
pushfirst!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.popfirst!(vec::AbstractSmallVector)
circshift!(vec, -1)
resize!(vec, length(vec) - 1)
return vec
end

# Don't `@inline`, makes it slower.
function StaticArrays.popfirst(vec::AbstractSmallVector)
mvec = Base.copymutable(vec)
popfirst!(mvec)
return convert(similar_type(vec), mvec)
end

@inline function Base.reverse!(vec::AbstractSmallVector)
start, stop = firstindex(vec), lastindex(vec)
r = stop
@inbounds for i in start:Base.midpoint(start, stop-1)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
vec[i], vec[r] = vec[r], vec[i]
r -= 1
end
return vec
end

@inline function Base.reverse!(vec::AbstractSmallVector, start::Integer, stop::Integer=lastindex(v))
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
reverse!(smallview(vec, start, stop))
return vec
end

@inline function Base.circshift!(vec::AbstractSmallVector, shift::Integer)
start, stop = firstindex(vec), lastindex(vec)
n = length(vec)
n == 0 && return vec
shift = mod(shift, n)
shift == 0 && return vec
reverse!(smallview(vec, start, stop - shift))
reverse!(smallview(vec, stop - shift + 1, stop))
reverse!(smallview(vec, start, stop))
return vec
end

@inline function Base.insert!(vec::AbstractSmallVector, index::Integer, item)
resize!(vec, length(vec) + 1)
circshift!(smallview(vec, index, lastindex(vec)), 1)
@inbounds vec[index] = item
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item)
mvec = Base.copymutable(vec)
insert!(mvec, index, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.deleteat!(vec::AbstractSmallVector, index::Integer)
circshift!(smallview(vec, index, lastindex(vec)), -1)
resize!(vec, length(vec) - 1)
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer)
mvec = Base.copymutable(vec)
deleteat!(mvec, index)
return convert(similar_type(vec), mvec)
end

# InsertionSortAlg
# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736
# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort.
# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting.
@inline function Base.sort!(vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
lo, hi = firstindex(vec), lastindex(vec)
lo_plus_1 = (lo + 1)
@inbounds for i in lo_plus_1:hi
j = i
x = vec[i]
jmax = j
for _ in jmax:-1:lo_plus_1
y = vec[j - 1]
if !(lt(by(x), by(y)) != rev)
break
end
vec[j] = y
j -= 1
end
vec[j] = x
end
return vec
end

# Don't @inline, makes it slower.
function Base.sort(vec::AbstractSmallVector; kwargs...)
mvec = Base.copymutable(vec)
sort!(mvec; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...)
insert!(vec, searchsortedfirst(vec, item; kwargs...), item)
return vec
end

function insertsorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
insertsorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function insertsortedunique!(vec::AbstractSmallVector, item; kwargs...)
r = searchsorted(vec, item; kwargs...)
if length(r) == 0
insert!(vec, first(r), item)
end
return vec
end

# Code repeated since inlining doesn't work.
function insertsortedunique(vec::AbstractSmallVector, item; kwargs...)
r = searchsorted(vec, item; kwargs...)
if length(r) == 0
vec = insert(vec, first(r), item)
end
return vec
end

@inline function mergesorted!(vec::AbstractSmallVector, item::AbstractVector; kwargs...)
for x in item
insertsorted!(vec, x; kwargs...)
end
return vec
end

function mergesorted(vec::AbstractSmallVector, item; kwargs...)
mvec = Base.copymutable(vec)
mergesorted!(mvec, item; kwargs...)
return convert(similar_type(vec), mvec)
end

@inline function mergesortedunique!(vec::AbstractSmallVector, item::AbstractVector; kwargs...)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
for x in item
insertsortedunique!(vec, x; kwargs...)
end
return vec
end

# Code repeated since inlining doesn't work.
function mergesortedunique(vec::AbstractSmallVector, item; kwargs...)
for x in item
vec = insertsortedunique(vec, x; kwargs...)
end
return vec
end

Base.@propagate_inbounds function Base.copyto!(vec::AbstractSmallVector, item::AbstractVector)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
for i in eachindex(item)
vec[i] = item[i]
end
return vec
end

# Don't @inline, makes it slower.
function Base.circshift(vec::AbstractSmallVector, shift::Integer)
mvec = Base.copymutable(vec)
circshift!(mvec, shift)
return convert(similar_type(vec), mvec)
end

@inline function Base.append!(vec::AbstractSmallVector, item::AbstractVector)
l = length(vec)
r = length(item)
resize!(vec, l + r)
@inbounds copyto!(smallview(vec, l + 1, l + r + 1), item)
return vec
end

# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function append(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
append!(mvec, item)
return convert(similar_type(vec), mvec)
end

@inline function Base.prepend!(vec::AbstractSmallVector, item::AbstractVector)
l = length(vec)
r = length(item)
resize!(vec, l + r)
circshift!(vec, length(item))
@inbounds copyto!(vec, item)
return vec
end

# Missing from `StaticArrays.jl`.
# Don't @inline, makes it slower.
function prepend(vec::AbstractSmallVector, item::AbstractVector)
mvec = Base.copymutable(vec)
prepend!(mvec, item)
return convert(similar_type(vec), mvec)
end

# Don't @inline, makes it slower.
function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector)
mvec1 = Base.copymutable(vec1)
append!(mvec1, vec2)
return convert(similar_type(vec1), mvec1)
end

# TODO: inline when defined.
function Base.splice!(a::AbstractSmallVector, args...)
return throw(NotImplemented())
end
Loading
Loading