From 069bc70c076087e31b69b6b26aa73dc4dcccdd2f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 3 Oct 2023 17:48:42 -0400 Subject: [PATCH] Fix and generalize SortedSet tests --- .../src/abstractsmallvector/deque.jl | 14 ++++++++- NDTensors/src/SortedSets/src/sortedset.jl | 30 ++++++++++++++++--- NDTensors/src/SortedSets/test/runtests.jl | 23 ++++++++------ 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl index 96c1e80640..79fdd5cdda 100644 --- a/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl +++ b/NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl @@ -140,8 +140,20 @@ end return vec end +@inline function Base.deleteat!( + vec::AbstractSmallVector, indices::AbstractUnitRange{<:Integer} +) + f = first(indices) + n = length(indices) + circshift!(smallview(vec, f, lastindex(vec)), -n) + resize!(vec, length(vec) - n) + return vec +end + # Don't @inline, makes it slower. -function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer) +function StaticArrays.deleteat( + vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}} +) mvec = Base.copymutable(vec) deleteat!(mvec, index) return convert(similar_type(vec), mvec) diff --git a/NDTensors/src/SortedSets/src/sortedset.jl b/NDTensors/src/SortedSets/src/sortedset.jl index e1107f41cd..43e2db1c8c 100644 --- a/NDTensors/src/SortedSets/src/sortedset.jl +++ b/NDTensors/src/SortedSets/src/sortedset.jl @@ -15,7 +15,8 @@ struct SortedIndices{I,Inds<:AbstractArray{I},Order<:Ordering} <: AbstractSet{I} ) where {I,Inds<:AbstractArray{I},Order<:Ordering} = new{I,Inds,Order}(inds, order) end -function SortedIndices{I,Inds}( +# Inner constructor +function SortedIndices{I,Inds,Order}( a::Inds, order::Order; issorted=issorted, allunique=allunique ) where {I,Inds<:AbstractArray{I},Order<:Ordering} if !issorted(a, order) @@ -27,7 +28,21 @@ function SortedIndices{I,Inds}( return _SortedIndices(a, order) end -function SortedIndices( +@inline function SortedIndices{I,Inds,Order}( + a::AbstractArray, order::Ordering; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I},Order<:Ordering} + return SortedIndices{I,Inds,Order}( + convert(Inds, a), convert(Order, order); issorted, allunique + ) +end + +@inline function SortedIndices{I,Inds}( + a::AbstractArray, order::Order; issorted=issorted, allunique=allunique +) where {I,Inds<:AbstractArray{I},Order<:Ordering} + return SortedIndices{I,Inds,Order}(a, order; issorted, allunique) +end + +@inline function SortedIndices( a::Inds, order::Ordering; issorted=issorted, allunique=allunique ) where {I,Inds<:AbstractArray{I}} return SortedIndices{I,Inds}(a, order; issorted, allunique) @@ -182,8 +197,15 @@ end # TODO: Make into `MSmallVector`? # More generally, make a `thaw(::AbstractArray)` function to return # a mutable version of an AbstractArray. -@inline Dictionaries.empty_type(::Type{SortedIndices{I,D}}, ::Type{I}) where {I,D} = - SortedIndices{I,empty_type(D, I)} +@inline Dictionaries.empty_type( + ::Type{SortedIndices{I,D,Order}}, ::Type{I} +) where {I,D,Order} = SortedIndices{I,Dictionaries.empty_type(D, I),Order} + +@inline Dictionaries.empty_type(::Type{<:AbstractVector}, ::Type{I}) where {I} = Vector{I} + +function Base.empty(inds::SortedIndices{I,D}, ::Type{I}) where {I,D} + return Dictionaries.empty_type(typeof(inds), I)(D(), inds.order) +end @inline function Base.copy(inds::SortedIndices, ::Type{I}) where {I} if I === eltype(inds) diff --git a/NDTensors/src/SortedSets/test/runtests.jl b/NDTensors/src/SortedSets/test/runtests.jl index 900b57bed4..24500f474d 100644 --- a/NDTensors/src/SortedSets/test/runtests.jl +++ b/NDTensors/src/SortedSets/test/runtests.jl @@ -1,15 +1,20 @@ using Test using NDTensors.SortedSets +using NDTensors.SmallVectors @testset "Test NDTensors.SortedSets" begin - s1 = SortedSet([1, 3, 5]) - s2 = SortedSet([2, 3, 6]) + for V in (Vector, MSmallVector{10}, SmallVector{10}) + s1 = SortedSet(V([1, 3, 5])) + s2 = SortedSet(V([2, 3, 6])) - # Set interface - @test union(s1, s2) == SortedSet([1, 2, 3, 5, 6]) - @test setdiff(s1, s2) == SortedSet([1, 5]) - @test symdiff(s1, s2) == SortedSet([1, 2, 5, 6]) - @test intersect(s1, s2) == SortedSet([3]) - @test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5]) - @test delete!(copy(s1), 3) == SortedSet([1, 5]) + # Set interface + @test union(s1, s2) == SortedSet([1, 2, 3, 5, 6]) + @test setdiff(s1, s2) == SortedSet([1, 5]) + @test symdiff(s1, s2) == SortedSet([1, 2, 5, 6]) + @test intersect(s1, s2) == SortedSet([3]) + if SmallVectors.InsertStyle(V) isa IsInsertable + @test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5]) + @test delete!(copy(s1), 3) == SortedSet([1, 5]) + end + end end