diff --git a/base/set.jl b/base/set.jl index a3e9954d214770..ab64837aa9ce1a 100644 --- a/base/set.jl +++ b/base/set.jl @@ -67,6 +67,26 @@ rehash!(s::Set) = (rehash!(s.dict); s) iterate(s::Set, i...) = iterate(KeySet(s.dict), i...) +# In case the size(s) is smaller than size(t) its more efficient to iterate through +# elements of s instead and only delete the ones also contained in t. +# The threshold for this decision boils down to a tradeoff between +# size(s) * cost(in() + delete!()) ≶ size(t) * cost(delete!()) +# Empirical observations on Ints point towards a threshold of 0.8. +# To be on the safe side (e.g. cost(in) >>> cost(delete!) ) a +# conservative threshold of 0.5 was chosen. +function setdiff!(s::Set, t::Set) + if 2 * length(s) < length(t) + for x in s + x in t && delete!(s, x) + end + else + for x in t + delete!(s, x) + end + end + return s +end + """ unique(itr) diff --git a/test/sets.jl b/test/sets.jl index 35155633abb9d5..daddf0bc7bb8bd 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -264,6 +264,14 @@ end s = Set([1,2,3,4]) setdiff!(s, Set([2,4,5,6])) @test isequal(s,Set([1,3])) + + # setdiff iterates the shorter set - make sure this algorithm works + sa, sb = Set([1,2,3,4,5,6,7]), Set([2,3,9]) + @test Set([1,4,5,6,7]) == setdiff(sa, sb) !== sa + @test Set([1,4,5,6,7]) == setdiff!(sa, sb) === sa + sa, sb = Set([1,2,3,4,5,6,7]), Set([2,3,9]) + @test Set([9]) == setdiff(sb, sa) !== sb + @test Set([9]) == setdiff!(sb, sa) === sb end @testset "ordering" begin