[NDTensors] Add SmallVectors submodule #1202

Merged on Sep 29, 2023 (13 commits)
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
@@ -19,6 +19,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("SmallVectors/src/SmallVectors.jl")
using .SmallVectors

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

73 changes: 73 additions & 0 deletions NDTensors/src/SmallVectors/README.md
@@ -0,0 +1,73 @@
# SmallVectors

## Introduction

A module that defines small vectors, both mutable and immutable, with a fixed maximum length. Externally they have a dynamic (runtime) length, but internally they are backed by a statically sized vector. Operations on them can therefore stay on the stack and run faster, while remaining more convenient than StaticArrays.jl, where the length is encoded in the type.

## Examples

For example:
```julia
using NDTensors.SmallVectors

mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10
push!(mv, 4)
mv[2] = 12
sort!(mv; rev=true)

v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10
v = SmallVectors.push(v, 4)
v = SmallVectors.setindex(v, 12, 2)
v = SmallVectors.sort(v; rev=true)
```
This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length.
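
To make the idea of a runtime length on top of a statically sized buffer concrete, here is a minimal sketch in plain Julia. `ToySmallVector`, `toy_push`, and the `maxlength` method below are hypothetical names written for illustration only; they are not the `SmallVectors` implementation, and padding unused slots with `zero(T)` is an assumption that restricts the toy to numeric element types.

```julia
# Hypothetical toy type for illustration only; not the NDTensors implementation.
struct ToySmallVector{S,T} <: AbstractVector{T}
  buffer::NTuple{S,T}  # statically sized storage; S is the maximum length
  len::Int             # runtime length, always <= S
end

function ToySmallVector{S}(items::AbstractVector{T}) where {S,T}
  n = length(items)
  n <= S || throw(ArgumentError("length $n exceeds maximum length $S"))
  # Assumption: unused slots are padded with zero(T); they are never exposed.
  buffer = ntuple(i -> i <= n ? items[i] : zero(T), Val(S))
  return ToySmallVector{S,T}(buffer, n)
end

# Minimal AbstractVector interface: only the first `len` slots are visible.
Base.size(v::ToySmallVector) = (v.len,)
Base.IndexStyle(::Type{<:ToySmallVector}) = IndexLinear()
function Base.getindex(v::ToySmallVector, i::Integer)
  @boundscheck checkbounds(v, i)
  return v.buffer[i]
end

maxlength(::ToySmallVector{S}) where {S} = S

# Non-mutating push: returns a new value instead of modifying `v`.
function toy_push(v::ToySmallVector{S,T}, item) where {S,T}
  v.len < S || throw(ArgumentError("cannot grow beyond maximum length $S"))
  buffer = Base.setindex(v.buffer, convert(T, item), v.len + 1)
  return ToySmallVector{S,T}(buffer, v.len + 1)
end

v = ToySmallVector{4}([1, 2, 3])
w = toy_push(v, 7)

# Different runtime lengths, same concrete type, so the collection is homogeneous.
vs = [ToySmallVector{4}([1]), w]
```

Because the maximum length, not the runtime length, is part of the type, `vs` has a single concrete element type even though its entries have lengths 1 and 4.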

## List of functionality

`SmallVector` and `MSmallVector` are subtypes of `AbstractVector` and therefore can be used in `Base` `AbstractVector` functions, though `SmallVector` will fail for mutating functions like `setindex!` because it is immutable.

`MSmallVector` has specialized implementations of `Base` functions that involve resizing such as:
- `resize!`
- `push!`
- `pushfirst!`
- `pop!`
- `popfirst!`
- `append!`
- `prepend!`
- `insert!`
- `deleteat!`
These are guaranteed not to reallocate memory and instead reuse the memory buffer that already exists, unlike Base's `Vector`, which may have to reallocate memory depending on the operation. However, they will error if the operation would resize the vector beyond the maximum length of the `MSmallVector`, which you can access with `SmallVectors.maxlength(v)`.
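
The behavior described for the in-place resizing functions can be sketched with a toy type in plain Julia. `ToyMSmallVector` is a hypothetical stand-in, not the `MSmallVector` implementation: it assumes a heap-allocated `Vector` that is allocated once and never resized, whereas the real type is backed by a statically sized vector. The error type thrown on overflow is also an assumption.

```julia
# Hypothetical toy type for illustration only; not the NDTensors implementation.
mutable struct ToyMSmallVector{T} <: AbstractVector{T}
  buffer::Vector{T}  # allocated once with the maximum length, never resized
  len::Int           # runtime length
  function ToyMSmallVector(items::AbstractVector{T}, maxlen::Integer) where {T}
    length(items) <= maxlen || throw(ArgumentError("too many items"))
    buffer = Vector{T}(undef, maxlen)
    copyto!(buffer, items)
    return new{T}(buffer, length(items))
  end
end

Base.size(v::ToyMSmallVector) = (v.len,)
Base.IndexStyle(::Type{<:ToyMSmallVector}) = IndexLinear()
maxlength(v::ToyMSmallVector) = length(v.buffer)

function Base.getindex(v::ToyMSmallVector, i::Integer)
  @boundscheck checkbounds(v, i)
  return v.buffer[i]
end
function Base.setindex!(v::ToyMSmallVector, item, i::Integer)
  @boundscheck checkbounds(v, i)
  v.buffer[i] = item
  return v
end

# Growing only moves the length marker; the buffer itself is reused.
function Base.push!(v::ToyMSmallVector, item)
  v.len < maxlength(v) || throw(ArgumentError("maximum length exceeded"))
  v.len += 1
  v.buffer[v.len] = item
  return v
end

# Shrinking likewise leaves the buffer untouched.
function Base.pop!(v::ToyMSmallVector)
  isempty(v) && throw(ArgumentError("vector must be non-empty"))
  item = v.buffer[v.len]
  v.len -= 1
  return item
end

mv = ToyMSmallVector([1, 2, 3], 4)
original_buffer = mv.buffer
push!(mv, 4)   # fits: length 4, maximum length 4
# A further push!(mv, 5) would throw, since it exceeds maxlength(mv).
```

After the `push!`, `mv.buffer === original_buffer` still holds, which is the toy analogue of the no-reallocation guarantee.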

In addition, `SmallVector` and `MSmallVector` implement basic non-mutating operations such as:
- `SmallVectors.setindex`
and non-mutating resizing operations:
- `SmallVectors.resize`
- `SmallVectors.push`
- `SmallVectors.pushfirst`
- `SmallVectors.pop`
- `SmallVectors.popfirst`
- `SmallVectors.append`
- `SmallVectors.prepend`
- `SmallVectors.insert`
- `SmallVectors.deleteat`
which output a new vector. In addition, they implement:
- `SmallVectors.circshift`
- `sort` (overloaded from `Base`).

Finally, it provides some new helpful functions that are not in `Base`:
- `SmallVectors.insertsorted[!]`
- `SmallVectors.insertsortedunique[!]`
- `SmallVectors.mergesorted[!]`
- `SmallVectors.mergesortedunique[!]`
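
The README does not spell out the semantics of these helpers. Assuming that `insertsorted!` inserts an item into an already sorted vector while keeping it sorted, and that the `unique` variant skips items that are already present, the idea can be sketched on a plain `Vector` using only `Base` functions. The `toy_` names are hypothetical and are not part of `SmallVectors`.

```julia
# Hypothetical helpers for illustration only; semantics are assumed from the names.

# Insert `item` into the sorted vector `v`, keeping `v` sorted.
function toy_insertsorted!(v::AbstractVector, item)
  insert!(v, searchsortedfirst(v, item), item)
  return v
end

# Same as above, but leave `v` unchanged if `item` is already present.
function toy_insertsortedunique!(v::AbstractVector, item)
  r = searchsorted(v, item)  # empty range at the insertion point if absent
  isempty(r) && insert!(v, first(r), item)
  return v
end

a = [1, 3, 5]
toy_insertsorted!(a, 4)        # a == [1, 3, 4, 5]
toy_insertsorted!(a, 3)        # a == [1, 3, 3, 4, 5]

b = [1, 3, 5]
toy_insertsortedunique!(b, 3)  # b is unchanged
toy_insertsortedunique!(b, 2)  # b == [1, 2, 3, 5]
```

Both sketches rely on binary search, so finding the insertion point costs logarithmic time in the length of the vector; the insertion itself still shifts the trailing elements.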

## TODO

Add specialized overloads for:
- `splice[!]`
- `union[!]` (`∪`)
- `intersect[!]` (`∩`)
- `setdiff[!]`
- `symdiff[!]`
- `unique[!]`

Please let us know if there are other operations that would warrant specialized implementations for `AbstractSmallVector`.
16 changes: 16 additions & 0 deletions NDTensors/src/SmallVectors/src/SmallVectors.jl
@@ -0,0 +1,16 @@
module SmallVectors
using StaticArrays

export SmallVector, MSmallVector, SubSmallVector

struct NotImplemented <: Exception
msg::String
end
NotImplemented() = NotImplemented("Not implemented.")

include("abstractsmallvector/abstractsmallvector.jl")
include("abstractsmallvector/deque.jl")
include("msmallvector/msmallvector.jl")
include("smallvector/smallvector.jl")
include("subsmallvector/subsmallvector.jl")
end
@@ -0,0 +1,26 @@
"""
A vector with a fixed maximum length, backed by a fixed size buffer.
"""
abstract type AbstractSmallVector{T} <: AbstractVector{T} end

# Required buffer interface
buffer(vec::AbstractSmallVector) = throw(NotImplemented())

similar_type(vec::AbstractSmallVector) = typeof(vec)

# Derived buffer interface (falls back to the length of the buffer)
maxlength(vec::AbstractSmallVector) = length(buffer(vec))

# Required AbstractArray interface
Base.size(vec::AbstractSmallVector) = throw(NotImplemented())

# Required AbstractArray interface (subtypes must implement these)
function Base.getindex(vec::AbstractSmallVector, index::Integer)
return throw(NotImplemented())
end
function Base.setindex!(vec::AbstractSmallVector, item, index::Integer)
return throw(NotImplemented())
end
Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear()

Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractSmallVector} = a isa T ? a : T(a)::T