diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index ae138d11fc..80f4a7fd8b 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -19,6 +19,8 @@ using TupleTools include("SetParameters/src/SetParameters.jl") using .SetParameters +include("SmallVectors/src/SmallVectors.jl") +using .SmallVectors using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo diff --git a/NDTensors/src/SmallVectors/README.md b/NDTensors/src/SmallVectors/README.md new file mode 100644 index 0000000000..71ee904af7 --- /dev/null +++ b/NDTensors/src/SmallVectors/README.md @@ -0,0 +1,73 @@ +# SmallVectors + +## Introduction + +A module that defines small (mutable and immutable) vectors with a maximum length. Externally they have a dynamic/runtime length, but internally they are backed by a statically sized vector. This makes it so that operations can be performed faster because they can remain on the stack, but it provides some more convenience compared to StaticArrays.jl where the length is encoded in the type. + +## Examples + +For example: +```julia +using NDTensors.SmallVectors + +mv = MSmallVector{10}([1, 2, 3]) # Mutable vector with length 3, maximum length 10 +push!(mv, 4) +mv[2] = 12 +sort!(mv; rev=true) + +v = SmallVector{10}([1, 2, 3]) # Immutable vector with length 3, maximum length 10 +v = SmallVectors.push(v, 4) +v = SmallVectors.setindex(v, 12, 2) +v = SmallVectors.sort(v; rev=true) +``` +This also has the advantage that you can efficiently store collections of `SmallVector`/`MSmallVector` that have different runtime lengths, as long as they have the same maximum length. + +## List of functionality + +`SmallVector` and `MSmallVector` are subtypes of `AbstractVector` and therefore can be used in `Base` `AbstractVector` functions, though `SmallVector` will fail for mutating functions like `setindex!` because it is immutable. + +`MSmallVector` has specialized implementations of `Base` functions that involve resizing such as: +- `resize!` +- `push!` +- `pushfirst!` +- `pop!` +- `popfirst!` +- `append!` +- `prepend!` +- `insert!` +- `deleteat!` +which are guaranteed to not realocate memory, and instead just use the memory buffer that already exists, unlike Base's `Vector` which may have to reallocate memory depending on the operation. However, they will error if they involve operations that resize beyond the maximum length of the `MSmallVector`, which you can access with `SmallVectors.maxlength(v)`. + +In addition, `SmallVector` and `MSmallVector` implement basic non-mutating operations such as: +- `SmallVectors.setindex` +, non-mutating resizing operations: +- `SmallVector.resize` +- `SmallVector.push` +- `SmallVector.pushfirst` +- `SmallVector.pop` +- `SmallVector.popfirst` +- `SmallVector.append` +- `SmallVector.prepend` +- `SmallVector.insert` +- `SmallVector.deleteat` +which output a new vector. In addition, it implements: +- `SmallVectors.circshift` +- `sort` (overloaded from `Base`). + +Finally, it provides some new helpful functions that are not in `Base`: +- `SmallVectors.insertsorted[!]` +- `SmallVectors.insertsortedunique[!]` +- `SmallVectors.mergesorted[!]` +- `SmallVectors.mergesortedunique[!]` + +## TODO + +Add specialized overloads for: +- `splice[!]` +- `union[!]` (`∪`) +- `intersect[!]` (`∩`) +- `setdiff[!]` +- `symdiff[!]` +- `unique[!]` + +Please let us know if there are other operations that would warrant specialized implmentations for `AbstractSmallVector`. diff --git a/NDTensors/src/SmallVectors/src/SmallVectors.jl b/NDTensors/src/SmallVectors/src/SmallVectors.jl new file mode 100644 index 0000000000..e3f7083795 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/SmallVectors.jl @@ -0,0 +1,16 @@ +module SmallVectors +using StaticArrays + +export SmallVector, MSmallVector, SubSmallVector + +struct NotImplemented <: Exception + msg::String +end +NotImplemented() = NotImplemented("Not implemented.") + +include("abstractsmallvector/abstractsmallvector.jl") +include("abstractsmallvector/deque.jl") +include("msmallvector/msmallvector.jl") +include("smallvector/smallvector.jl") +include("subsmallvector/subsmallvector.jl") +end diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl new file mode 100644 index 0000000000..014f0e0d9b --- /dev/null +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/abstractsmallvector.jl @@ -0,0 +1,28 @@ +""" +A vector with a fixed maximum length, backed by a fixed size buffer. +""" +abstract type AbstractSmallVector{T} <: AbstractVector{T} end + +# Required buffer interface +buffer(vec::AbstractSmallVector) = throw(NotImplemented()) + +similar_type(vec::AbstractSmallVector) = typeof(vec) + +# Required buffer interface +maxlength(vec::AbstractSmallVector) = length(buffer(vec)) + +# Required AbstractArray interface +Base.size(vec::AbstractSmallVector) = throw(NotImplemented()) + +# Derived AbstractArray interface +function Base.getindex(vec::AbstractSmallVector, index::Integer) + return throw(NotImplemented()) +end +function Base.setindex!(vec::AbstractSmallVector, item, index::Integer) + return throw(NotImplemented()) +end +Base.IndexStyle(::Type{<:AbstractSmallVector}) = IndexLinear() + +function Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractSmallVector} + return a isa T ? a : T(a)::T +end diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl new file mode 100644 index 0000000000..96c1e80640 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl @@ -0,0 +1,295 @@ +# TODO: add +# splice[!] +# union[!] (∪) +# intersect[!] (∩) +# setdiff[!] +# symdiff[!] +# unique[!] + +# unionsorted[!] +# setdiffsorted[!] +# deletesorted[!] (delete all or one?) +# deletesortedfirst[!] (delete all or one?) + +Base.resize!(vec::AbstractSmallVector, len) = throw(NotImplemented()) + +@inline function resize(vec::AbstractSmallVector, len) + mvec = Base.copymutable(vec) + resize!(mvec, len) + return convert(similar_type(vec), mvec) +end + +@inline function Base.empty!(vec::AbstractSmallVector) + resize!(vec, 0) + return vec +end + +@inline function empty(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + empty!(mvec) + return convert(similar_type(vec), mvec) +end + +@inline function StaticArrays.setindex(vec::AbstractSmallVector, item, index::Integer) + @boundscheck checkbounds(vec, index) + mvec = Base.copymutable(vec) + @inbounds mvec[index] = item + return convert(similar_type(vec), mvec) +end + +@inline function Base.push!(vec::AbstractSmallVector, item) + resize!(vec, length(vec) + 1) + @inbounds vec[length(vec)] = item + return vec +end + +@inline function StaticArrays.push(vec::AbstractSmallVector, item) + mvec = Base.copymutable(vec) + push!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.pop!(vec::AbstractSmallVector) + resize!(vec, length(vec) - 1) + return vec +end + +@inline function StaticArrays.pop(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + pop!(mvec) + return convert(similar_type(vec), mvec) +end + +@inline function Base.pushfirst!(vec::AbstractSmallVector, item) + insert!(vec, firstindex(vec), item) + return vec +end + +# Don't `@inline`, makes it slower. +function StaticArrays.pushfirst(vec::AbstractSmallVector, item) + mvec = Base.copymutable(vec) + pushfirst!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.popfirst!(vec::AbstractSmallVector) + circshift!(vec, -1) + resize!(vec, length(vec) - 1) + return vec +end + +# Don't `@inline`, makes it slower. +function StaticArrays.popfirst(vec::AbstractSmallVector) + mvec = Base.copymutable(vec) + popfirst!(mvec) + return convert(similar_type(vec), mvec) +end + +# This implementation of `midpoint` is performance-optimized but safe +# only if `lo <= hi`. +# TODO: Replace with `Base.midpoint`. +midpoint(lo::T, hi::T) where {T<:Integer} = lo + ((hi - lo) >>> 0x01) +midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...) + +@inline function Base.reverse!(vec::AbstractSmallVector) + start, stop = firstindex(vec), lastindex(vec) + r = stop + @inbounds for i in start:midpoint(start, stop - 1) + vec[i], vec[r] = vec[r], vec[i] + r -= 1 + end + return vec +end + +@inline function Base.reverse!( + vec::AbstractSmallVector, start::Integer, stop::Integer=lastindex(v) +) + reverse!(smallview(vec, start, stop)) + return vec +end + +@inline function Base.circshift!(vec::AbstractSmallVector, shift::Integer) + start, stop = firstindex(vec), lastindex(vec) + n = length(vec) + n == 0 && return vec + shift = mod(shift, n) + shift == 0 && return vec + reverse!(smallview(vec, start, stop - shift)) + reverse!(smallview(vec, stop - shift + 1, stop)) + reverse!(smallview(vec, start, stop)) + return vec +end + +@inline function Base.insert!(vec::AbstractSmallVector, index::Integer, item) + resize!(vec, length(vec) + 1) + circshift!(smallview(vec, index, lastindex(vec)), 1) + @inbounds vec[index] = item + return vec +end + +# Don't @inline, makes it slower. +function StaticArrays.insert(vec::AbstractSmallVector, index::Integer, item) + mvec = Base.copymutable(vec) + insert!(mvec, index, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.deleteat!(vec::AbstractSmallVector, index::Integer) + circshift!(smallview(vec, index, lastindex(vec)), -1) + resize!(vec, length(vec) - 1) + return vec +end + +# Don't @inline, makes it slower. +function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer) + mvec = Base.copymutable(vec) + deleteat!(mvec, index) + return convert(similar_type(vec), mvec) +end + +# InsertionSortAlg +# https://github.com/JuliaLang/julia/blob/bed2cd540a11544ed4be381d471bbf590f0b745e/base/sort.jl#L722-L736 +# https://en.wikipedia.org/wiki/Insertion_sort#:~:text=Insertion%20sort%20is%20a%20simple,%2C%20heapsort%2C%20or%20merge%20sort. +# Alternatively could use `TupleTools.jl` or `StaticArrays.jl` for out-of-place sorting. +@inline function Base.sort!( + vec::AbstractSmallVector; lt=isless, by=identity, rev::Bool=false +) + lo, hi = firstindex(vec), lastindex(vec) + lo_plus_1 = (lo + 1) + @inbounds for i in lo_plus_1:hi + j = i + x = vec[i] + jmax = j + for _ in jmax:-1:lo_plus_1 + y = vec[j - 1] + if !(lt(by(x), by(y)) != rev) + break + end + vec[j] = y + j -= 1 + end + vec[j] = x + end + return vec +end + +# Don't @inline, makes it slower. +function Base.sort(vec::AbstractSmallVector; kwargs...) + mvec = Base.copymutable(vec) + sort!(mvec; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function insertsorted!(vec::AbstractSmallVector, item; kwargs...) + insert!(vec, searchsortedfirst(vec, item; kwargs...), item) + return vec +end + +function insertsorted(vec::AbstractSmallVector, item; kwargs...) + mvec = Base.copymutable(vec) + insertsorted!(mvec, item; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function insertsortedunique!(vec::AbstractSmallVector, item; kwargs...) + r = searchsorted(vec, item; kwargs...) + if length(r) == 0 + insert!(vec, first(r), item) + end + return vec +end + +# Code repeated since inlining doesn't work. +function insertsortedunique(vec::AbstractSmallVector, item; kwargs...) + r = searchsorted(vec, item; kwargs...) + if length(r) == 0 + vec = insert(vec, first(r), item) + end + return vec +end + +@inline function mergesorted!(vec::AbstractSmallVector, item::AbstractVector; kwargs...) + for x in item + insertsorted!(vec, x; kwargs...) + end + return vec +end + +function mergesorted(vec::AbstractSmallVector, item; kwargs...) + mvec = Base.copymutable(vec) + mergesorted!(mvec, item; kwargs...) + return convert(similar_type(vec), mvec) +end + +@inline function mergesortedunique!( + vec::AbstractSmallVector, item::AbstractVector; kwargs... +) + for x in item + insertsortedunique!(vec, x; kwargs...) + end + return vec +end + +# Code repeated since inlining doesn't work. +function mergesortedunique(vec::AbstractSmallVector, item; kwargs...) + for x in item + vec = insertsortedunique(vec, x; kwargs...) + end + return vec +end + +Base.@propagate_inbounds function Base.copyto!( + vec::AbstractSmallVector, item::AbstractVector +) + for i in eachindex(item) + vec[i] = item[i] + end + return vec +end + +# Don't @inline, makes it slower. +function Base.circshift(vec::AbstractSmallVector, shift::Integer) + mvec = Base.copymutable(vec) + circshift!(mvec, shift) + return convert(similar_type(vec), mvec) +end + +@inline function Base.append!(vec::AbstractSmallVector, item::AbstractVector) + l = length(vec) + r = length(item) + resize!(vec, l + r) + @inbounds copyto!(smallview(vec, l + 1, l + r + 1), item) + return vec +end + +# Missing from `StaticArrays.jl`. +# Don't @inline, makes it slower. +function append(vec::AbstractSmallVector, item::AbstractVector) + mvec = Base.copymutable(vec) + append!(mvec, item) + return convert(similar_type(vec), mvec) +end + +@inline function Base.prepend!(vec::AbstractSmallVector, item::AbstractVector) + l = length(vec) + r = length(item) + resize!(vec, l + r) + circshift!(vec, length(item)) + @inbounds copyto!(vec, item) + return vec +end + +# Missing from `StaticArrays.jl`. +# Don't @inline, makes it slower. +function prepend(vec::AbstractSmallVector, item::AbstractVector) + mvec = Base.copymutable(vec) + prepend!(mvec, item) + return convert(similar_type(vec), mvec) +end + +# Don't @inline, makes it slower. +function Base.vcat(vec1::AbstractSmallVector, vec2::AbstractVector) + mvec1 = Base.copymutable(vec1) + append!(mvec1, vec2) + return convert(similar_type(vec1), mvec1) +end diff --git a/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl new file mode 100644 index 0000000000..434d1e4d48 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/msmallvector/msmallvector.jl @@ -0,0 +1,75 @@ +""" +MSmallVector + +TODO: Make `buffer` field `const` (new in Julia 1.8) +""" +mutable struct MSmallVector{S,T} <: AbstractSmallVector{T} + buffer::MVector{S,T} + length::Int +end + +# Constructors +function MSmallVector{S}(buffer::AbstractVector, len::Int) where {S} + return MSmallVector{S,eltype(buffer)}(buffer, len) +end +function MSmallVector(buffer::AbstractVector, len::Int) + return MSmallVector{length(buffer),eltype(buffer)}(buffer, len) +end + +""" +`MSmallVector` constructor, uses `MVector` as a buffer. +```julia +MSmallVector{10}([1, 2, 3]) +MSmallVector{10}(SA[1, 2, 3]) +``` +""" +function MSmallVector{S,T}(vec::AbstractVector) where {S,T} + buffer = MVector{S,T}(undef) + copyto!(buffer, vec) + return MSmallVector(buffer, length(vec)) +end + +# Derive the buffer length. +MSmallVector(vec::AbstractSmallVector) = MSmallVector{length(buffer(vec))}(vec) + +function MSmallVector{S}(vec::AbstractVector) where {S} + return MSmallVector{S,eltype(vec)}(vec) +end + +function MSmallVector{S,T}(::UndefInitializer, dims::Tuple{Integer}) where {S,T} + return MSmallVector{S,T}(undef, prod(dims)) +end +function MSmallVector{S,T}(::UndefInitializer, length::Integer) where {S,T} + return MSmallVector{S,T}(MVector{S,T}(undef), length) +end + +# Buffer interface +buffer(vec::MSmallVector) = vec.buffer + +# Accessors +Base.size(vec::MSmallVector) = (vec.length,) + +# Required Base overloads +@inline function Base.getindex(vec::MSmallVector, index::Integer) + @boundscheck checkbounds(vec, index) + return @inbounds buffer(vec)[index] +end + +@inline function Base.setindex!(vec::MSmallVector, item, index::Integer) + @boundscheck checkbounds(vec, index) + @inbounds buffer(vec)[index] = item + return vec +end + +@inline function Base.resize!(vec::MSmallVector, len::Integer) + len < 0 && throw(ArgumentError("New length must be ≥ 0.")) + len > maxlength(vec) && + throw(ArgumentError("New length $len must be ≤ the maximum length $(maxlength(vec)).")) + vec.length = len + return vec +end + +# `similar` creates a `MSmallVector` by default. +function Base.similar(vec::AbstractSmallVector, elt::Type, dims::Dims) + return MSmallVector{length(buffer(vec)),elt}(undef, dims) +end diff --git a/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl new file mode 100644 index 0000000000..3976480f47 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/smallvector/smallvector.jl @@ -0,0 +1,71 @@ +""" +SmallVector +""" +struct SmallVector{S,T} <: AbstractSmallVector{T} + buffer::SVector{S,T} + length::Int +end + +# Accessors +# TODO: Use `Accessors.jl`. +@inline setbuffer(vec::SmallVector, buffer) = SmallVector(buffer, vec.length) +@inline setlength(vec::SmallVector, length) = SmallVector(vec.buffer, length) + +# Constructors +function SmallVector{S}(buffer::AbstractVector, len::Int) where {S} + return SmallVector{S,eltype(buffer)}(buffer, len) +end +function SmallVector(buffer::AbstractVector, len::Int) + return SmallVector{length(buffer),eltype(buffer)}(buffer, len) +end + +""" +`SmallVector` constructor, uses `SVector` as buffer storage. +```julia +SmallVector{10}([1, 2, 3]) +SmallVector{10}(SA[1, 2, 3]) +``` +""" +function SmallVector{S,T}(vec::AbstractVector) where {S,T} + return SmallVector{S,T}(MSmallVector{S,T}(vec)) +end +# Special optimization codepath for `MSmallVector` +# to avoid a copy. +function SmallVector{S,T}(vec::MSmallVector) where {S,T} + return SmallVector{S,T}(buffer(vec), length(vec)) +end + +function SmallVector{S}(vec::AbstractVector) where {S} + return SmallVector{S,eltype(vec)}(vec) +end + +# Specialized constructor +function MSmallVector{S,T}(vec::SmallVector) where {S,T} + return MSmallVector{S,T}(buffer(vec), length(vec)) +end + +# Derive the buffer length. +SmallVector(vec::AbstractSmallVector) = SmallVector{length(buffer(vec))}(vec) + +# Empty constructor +(smallvector_type::Type{SmallVector{S,T}} where {S,T})() = smallvector_type(undef, 0) +function SmallVector{S,T}(::UndefInitializer, length::Integer) where {S,T} + return SmallVector{S,T}(SVector{S,T}(MVector{S,T}(undef)), length) +end + +# Buffer interface +buffer(vec::SmallVector) = vec.buffer + +# AbstractArray interface +Base.size(vec::SmallVector) = (vec.length,) + +# Base overloads +@inline function Base.getindex(vec::SmallVector, index::Integer) + @boundscheck checkbounds(vec, index) + return @inbounds buffer(vec)[index] +end + +Base.copy(vec::SmallVector) = vec + +# Optimization, default uses `similar`. +Base.copymutable(vec::SmallVector) = MSmallVector(vec) diff --git a/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl b/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl new file mode 100644 index 0000000000..922eaa3521 --- /dev/null +++ b/NDTensors/src/SmallVectors/src/subsmallvector/subsmallvector.jl @@ -0,0 +1,80 @@ +abstract type AbstractSubSmallVector{T} <: AbstractSmallVector{T} end + +""" +SubSmallVector +""" +struct SubSmallVector{T,P} <: AbstractSubSmallVector{T} + parent::P + start::Int + stop::Int +end + +mutable struct SubMSmallVector{T,P<:AbstractVector{T}} <: AbstractSubSmallVector{T} + parent::P + start::Int + stop::Int +end + +# TODO: Use Accessors.jl +Base.parent(vec::SubSmallVector) = vec.parent +Base.parent(vec::SubMSmallVector) = vec.parent + +# buffer interface +buffer(vec::AbstractSubSmallVector) = buffer(parent(vec)) + +function smallview(vec::SmallVector, start::Integer, stop::Integer) + return SubSmallVector(vec, start, stop) +end +function smallview(vec::MSmallVector, start::Integer, stop::Integer) + return SubMSmallVector(vec, start, stop) +end + +function smallview(vec::SubSmallVector, start::Integer, stop::Integer) + return SubSmallVector(parent(vec), vec.start + start - 1, vec.start + stop - 1) +end +function smallview(vec::SubMSmallVector, start::Integer, stop::Integer) + return SubMSmallVector(parent(vec), vec.start + start - 1, vec.start + stop - 1) +end + +# Constructors +function SubSmallVector(vec::AbstractVector, start::Integer, stop::Integer) + return SubSmallVector{eltype(vec),typeof(vec)}(vec, start, stop) +end +function SubMSmallVector(vec::AbstractVector, start::Integer, stop::Integer) + return SubMSmallVector{eltype(vec),typeof(vec)}(vec, start, stop) +end + +# Accessors +Base.size(vec::AbstractSubSmallVector) = (vec.stop - vec.start + 1,) + +Base.@propagate_inbounds function Base.getindex(vec::AbstractSubSmallVector, index::Integer) + return parent(vec)[index + vec.start - 1] +end + +Base.@propagate_inbounds function Base.setindex!( + vec::AbstractSubSmallVector, item, index::Integer +) + buffer(vec)[index + vec.start - 1] = item + return vec +end + +function SubSmallVector{T,P}(vec::SubMSmallVector) where {T,P} + return SubSmallVector{T,P}(P(parent(vec)), vec.start, vec.stop) +end + +function Base.convert(smalltype::Type{<:SubSmallVector}, vec::SubMSmallVector) + return smalltype(vec) +end + +@inline function Base.resize!(vec::SubMSmallVector, len::Integer) + len < 0 && throw(ArgumentError("New length must be ≥ 0.")) + len > maxlength(vec) - vec.start + 1 && + throw(ArgumentError("New length $len must be ≤ the maximum length $(maxlength(vec)).")) + vec.stop = vec.start + len - 1 + return vec +end + +# Optimization, default uses `similar`. +function Base.copymutable(vec::SubSmallVector) + return SubMSmallVector(Base.copymutable(parent(vec)), vec.start, vec.stop) +end diff --git a/NDTensors/src/SmallVectors/test/runtests.jl b/NDTensors/src/SmallVectors/test/runtests.jl new file mode 100644 index 0000000000..4b1daac154 --- /dev/null +++ b/NDTensors/src/SmallVectors/test/runtests.jl @@ -0,0 +1,155 @@ +using NDTensors.SmallVectors +using Test + +using NDTensors.SmallVectors: + setindex, + resize, + push, + pushfirst, + pop, + popfirst, + append, + prepend, + insert, + deleteat, + circshift, + insertsorted, + insertsorted!, + insertsortedunique, + insertsortedunique!, + mergesorted, + mergesorted!, + mergesortedunique, + mergesortedunique! + +function test_smallvectors() + return @testset "SmallVectors" begin + x = SmallVector{10}([1, 3, 5]) + mx = MSmallVector(x) + + @test x isa SmallVector{10,Int} + @test mx isa MSmallVector{10,Int} + @test eltype(x) === Int + @test eltype(mx) === Int + + # TODO: Test construction has zero allocations. + # TODO: Extend construction to arbitrary collections, like tuple. + + # conversion + @test @inferred(SmallVector(x)) == x + @test @allocated(SmallVector(x)) == 0 + @test @inferred(SmallVector(mx)) == x + @test @allocated(SmallVector(mx)) == 0 + + # length + @test @inferred(length(x)) == 3 + @test @allocated(length(x)) == 0 + @test @inferred(length(SmallVectors.buffer(x))) == 10 + @test @allocated(length(SmallVectors.buffer(x))) == 0 + + nalloc_limit = 128 + + item = 115 + no_broken = (false, false, false, false) + for ( + f!, + f, + ans, + args, + nalloc, + f!_impl_broken, + f!_noalloc_broken, + f_impl_broken, + f_noalloc_broken, + ) in [ + (:push!, :push, [1, 3, 5, item], (item,), nalloc_limit, no_broken...), + (:append!, :append, [1, 3, 5, item], ([item],), nalloc_limit, no_broken...), + (:prepend!, :prepend, [item, 1, 3, 5], ([item],), nalloc_limit, no_broken...), + (:pushfirst!, :pushfirst, [item, 1, 3, 5], (item,), nalloc_limit, no_broken...), + (:setindex!, :setindex, [1, item, 5], (item, 2), nalloc_limit, no_broken...), + (:pop!, :pop, [1, 3], (), nalloc_limit, no_broken...), + (:popfirst!, :popfirst, [3, 5], (), nalloc_limit, no_broken...), + (:insert!, :insert, [1, item, 3, 5], (2, item), nalloc_limit, no_broken...), + (:deleteat!, :deleteat, [1, 5], (2,), nalloc_limit, no_broken...), + (:circshift!, :circshift, [5, 1, 3], (1,), nalloc_limit, no_broken...), + (:sort!, :sort, [1, 3, 5], (), nalloc_limit, no_broken...), + (:insertsorted!, :insertsorted, [1, 2, 3, 5], (2,), nalloc_limit, no_broken...), + (:insertsorted!, :insertsorted, [1, 3, 3, 5], (3,), nalloc_limit, no_broken...), + ( + :insertsortedunique!, + :insertsortedunique, + [1, 2, 3, 5], + (2,), + nalloc_limit, + no_broken..., + ), + ( + :insertsortedunique!, + :insertsortedunique, + [1, 3, 5], + (3,), + nalloc_limit, + no_broken..., + ), + (:mergesorted!, :mergesorted, [1, 2, 3, 3, 5], ([2, 3],), nalloc_limit, no_broken...), + ( + :mergesortedunique!, + :mergesortedunique, + [1, 2, 3, 5], + ([2, 3],), + nalloc_limit, + no_broken..., + ), + ] + mx_tmp = copy(mx) + @eval begin + if VERSION < v"1.7" + # broken kwarg wasn't added to @test yet + if $f!_impl_broken + @test_broken @inferred($f!(copy($mx), $args...)) == $ans + else + @test @inferred($f!(copy($mx), $args...)) == $ans + end + if $f!_noalloc_broken + @test_broken @allocated($f!($mx_tmp, $args...)) ≤ $nalloc + else + @test @allocated($f!($mx_tmp, $args...)) ≤ $nalloc + end + if $f_impl_broken + @test_broken @inferred($f($x, $args...)) == $ans + else + @test @inferred($f($x, $args...)) == $ans + end + if $f_noalloc_broken + @test_broken @allocated($f($x, $args...)) ≤ $nalloc + else + @test @allocated($f($x, $args...)) ≤ $nalloc + end + else + @test @inferred($f!(copy($mx), $args...)) == $ans broken = $f!_impl_broken + @test @allocated($f!($mx_tmp, $args...)) ≤ $nalloc broken = $f!_noalloc_broken + @test @inferred($f($x, $args...)) == $ans broken = $f_impl_broken + @test @allocated($f($x, $args...)) ≤ $nalloc broken = $f_noalloc_broken + end + end + end + + # Separated out since for some reason it breaks the `@inferred` + # check when `kwargs` are interpolated into `@eval`. + ans, kwargs = [5, 3, 1], (; rev=true) + mx_tmp = copy(mx) + @test @inferred(sort!(copy(mx); kwargs...)) == ans + @test @allocated(sort!(mx_tmp; kwargs...)) == 0 + @test @inferred(sort(x; kwargs...)) == ans + @test @allocated(sort(x; kwargs...)) ≤ nalloc_limit + + ans, args = [1, 3, 5, item], ([item],) + @test @inferred(vcat(x, args...)) == ans + @test @allocated(vcat(x, args...)) ≤ nalloc_limit + end +end + +# TODO: switch to: +# @testset "SmallVectors" test_smallvectors() +# (new in Julia 1.9) +test_smallvectors() diff --git a/NDTensors/test/SmallVectors.jl b/NDTensors/test/SmallVectors.jl new file mode 100644 index 0000000000..62b552dc72 --- /dev/null +++ b/NDTensors/test/SmallVectors.jl @@ -0,0 +1,4 @@ +using Test +using NDTensors + +include(joinpath(pkgdir(NDTensors), "src", "SmallVectors", "test", "runtests.jl")) diff --git a/NDTensors/test/runtests.jl b/NDTensors/test/runtests.jl index 90aeeea118..274e2303ed 100644 --- a/NDTensors/test/runtests.jl +++ b/NDTensors/test/runtests.jl @@ -20,6 +20,7 @@ end @safetestset "NDTensors" begin @testset "$filename" for filename in [ "SetParameters.jl", + "SmallVectors.jl", "linearalgebra.jl", "dense.jl", "blocksparse.jl",