Skip to content

Commit

Permalink
Fix and generalize SortedSet tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 3, 2023
1 parent 9d47073 commit 069bc70
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
14 changes: 13 additions & 1 deletion NDTensors/src/SmallVectors/src/abstractsmallvector/deque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,20 @@ end
return vec
end

@inline function Base.deleteat!(
vec::AbstractSmallVector, indices::AbstractUnitRange{<:Integer}
)
f = first(indices)
n = length(indices)
circshift!(smallview(vec, f, lastindex(vec)), -n)
resize!(vec, length(vec) - n)
return vec
end

# Don't @inline, makes it slower.
function StaticArrays.deleteat(vec::AbstractSmallVector, index::Integer)
function StaticArrays.deleteat(
vec::AbstractSmallVector, index::Union{Integer,AbstractUnitRange{<:Integer}}
)
mvec = Base.copymutable(vec)
deleteat!(mvec, index)
return convert(similar_type(vec), mvec)
Expand Down
30 changes: 26 additions & 4 deletions NDTensors/src/SortedSets/src/sortedset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ struct SortedIndices{I,Inds<:AbstractArray{I},Order<:Ordering} <: AbstractSet{I}
) where {I,Inds<:AbstractArray{I},Order<:Ordering} = new{I,Inds,Order}(inds, order)
end

function SortedIndices{I,Inds}(
# Inner constructor
function SortedIndices{I,Inds,Order}(
a::Inds, order::Order; issorted=issorted, allunique=allunique
) where {I,Inds<:AbstractArray{I},Order<:Ordering}
if !issorted(a, order)
Expand All @@ -27,7 +28,21 @@ function SortedIndices{I,Inds}(
return _SortedIndices(a, order)
end

function SortedIndices(
@inline function SortedIndices{I,Inds,Order}(
a::AbstractArray, order::Ordering; issorted=issorted, allunique=allunique
) where {I,Inds<:AbstractArray{I},Order<:Ordering}
return SortedIndices{I,Inds,Order}(
convert(Inds, a), convert(Order, order); issorted, allunique
)
end

@inline function SortedIndices{I,Inds}(
a::AbstractArray, order::Order; issorted=issorted, allunique=allunique
) where {I,Inds<:AbstractArray{I},Order<:Ordering}
return SortedIndices{I,Inds,Order}(a, order; issorted, allunique)
end

@inline function SortedIndices(
a::Inds, order::Ordering; issorted=issorted, allunique=allunique
) where {I,Inds<:AbstractArray{I}}
return SortedIndices{I,Inds}(a, order; issorted, allunique)
Expand Down Expand Up @@ -182,8 +197,15 @@ end
# TODO: Make into `MSmallVector`?
# More generally, make a `thaw(::AbstractArray)` function to return
# a mutable version of an AbstractArray.
@inline Dictionaries.empty_type(::Type{SortedIndices{I,D}}, ::Type{I}) where {I,D} =
SortedIndices{I,empty_type(D, I)}
@inline Dictionaries.empty_type(
::Type{SortedIndices{I,D,Order}}, ::Type{I}
) where {I,D,Order} = SortedIndices{I,Dictionaries.empty_type(D, I),Order}

@inline Dictionaries.empty_type(::Type{<:AbstractVector}, ::Type{I}) where {I} = Vector{I}

function Base.empty(inds::SortedIndices{I,D}, ::Type{I}) where {I,D}
return Dictionaries.empty_type(typeof(inds), I)(D(), inds.order)
end

@inline function Base.copy(inds::SortedIndices, ::Type{I}) where {I}
if I === eltype(inds)
Expand Down
23 changes: 14 additions & 9 deletions NDTensors/src/SortedSets/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
using Test
using NDTensors.SortedSets
using NDTensors.SmallVectors

@testset "Test NDTensors.SortedSets" begin
s1 = SortedSet([1, 3, 5])
s2 = SortedSet([2, 3, 6])
for V in (Vector, MSmallVector{10}, SmallVector{10})
s1 = SortedSet(V([1, 3, 5]))
s2 = SortedSet(V([2, 3, 6]))

# Set interface
@test union(s1, s2) == SortedSet([1, 2, 3, 5, 6])
@test setdiff(s1, s2) == SortedSet([1, 5])
@test symdiff(s1, s2) == SortedSet([1, 2, 5, 6])
@test intersect(s1, s2) == SortedSet([3])
@test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5])
@test delete!(copy(s1), 3) == SortedSet([1, 5])
# Set interface
@test union(s1, s2) == SortedSet([1, 2, 3, 5, 6])
@test setdiff(s1, s2) == SortedSet([1, 5])
@test symdiff(s1, s2) == SortedSet([1, 2, 5, 6])
@test intersect(s1, s2) == SortedSet([3])
if SmallVectors.InsertStyle(V) isa IsInsertable
@test insert!(copy(s1), 4) == SortedSet([1, 3, 4, 5])
@test delete!(copy(s1), 3) == SortedSet([1, 5])
end
end
end

0 comments on commit 069bc70

Please sign in to comment.