diff --git a/NEWS.md b/NEWS.md index 06b4e218fde953..9396bea11240a2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -469,6 +469,20 @@ Library improvements linear-to-cartesian conversion ([#24715]) - It has a new constructor taking an array + * several missing set-like operations have been added ([#23528]): + `union`, `intersect`, `symdiff`, `setdiff` are now implemented for + all collections with arbitrarily many arguments, as well as the + mutating counterparts (`union!` etc.). The performance is also + much better in many cases. Note that this change is slightly + breaking: all the non-mutating functions always return a new + object even if only one argument is passed. Moreover the semantics + of `intersect` and `symdiff` is changed for vectors: + + `intersect` doesn't preserve the multiplicity anymore (use `filter` for + the old behavior) + + `symdiff` has been made consistent with the corresponding methods for + other containers, by taking the multiplicity of the arguments into account. + Use `unique` to get the old behavior. + * The type `LinearIndices` has been added, providing conversion from cartesian incices to linear indices using the normal indexing operation. 
([#24715]) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0b0bc3ab3afe15..2882d62faed4a6 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -580,8 +580,10 @@ julia> empty([1.0, 2.0, 3.0], String) 0-element Array{String,1} ``` """ -empty(a::AbstractVector) = empty(a, eltype(a)) -empty(a::AbstractVector, ::Type{T}) where {T} = Vector{T}() +empty(a::AbstractVector{T}, ::Type{U}=T) where {T,U} = Vector{U}() + +# like empty, but should return a mutable collection, a Vector by default +emptymutable(a::AbstractVector{T}, ::Type{U}=T) where {T,U} = Vector{U}() ## from general iterable to any array diff --git a/base/array.jl b/base/array.jl index f4f21791768c11..bea192e129ea3b 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2209,67 +2209,53 @@ function filter!(f, a::AbstractVector) return a end -function filter(f, a::Vector) - r = Vector{eltype(a)}() - for ai in a - if f(ai) - push!(r, ai) - end - end - return r -end +filter(f, a::Vector) = mapfilter(f, push!, a, similar(a, 0)) # set-like operators for vectors # These are moderately efficient, preserve order, and remove dupes. -function intersect(v1, vs...) - ret = Vector{promote_eltype(v1, vs...)}() - for v_elem in v1 - inall = true - for vsi in vs - if !in(v_elem, vsi) - inall=false; break - end - end - if inall - push!(ret, v_elem) - end +_unique_filter!(pred, update!, state) = function (x) + if pred(x, state) + update!(state, x) + true + else + false end - ret end -function union(vs...) 
- ret = Vector{promote_eltype(vs...)}() - seen = Set() - for v in vs - for v_elem in v - if !in(v_elem, seen) - push!(ret, v_elem) - push!(seen, v_elem) - end - end +_grow_filter!(seen) = _unique_filter!(∉, push!, seen) +_shrink_filter!(keep) = _unique_filter!(∈, pop!, keep) + +function _grow!(pred!, v::AbstractVector, itrs) + filter!(pred!, v) # uniquify v + foldl(v, itrs) do v, itr + mapfilter(pred!, push!, itr, v) end - ret end -# setdiff only accepts two args -function setdiff(a, b) - args_type = promote_type(eltype(a), eltype(b)) - bset = Set(b) - ret = Vector{args_type}() - seen = Set{eltype(a)}() - for a_elem in a - if !in(a_elem, seen) && !in(a_elem, bset) - push!(ret, a_elem) - push!(seen, a_elem) - end - end - ret +union!(v::AbstractVector{T}, itrs...) where {T} = + _grow!(_grow_filter!(sizehint!(Set{T}(), length(v))), v, itrs) + +symdiff!(v::AbstractVector{T}, itrs...) where {T} = + _grow!(_shrink_filter!(symdiff!(Set{T}(), v, itrs...)), v, itrs) + +function _shrink!(shrinker!, v::AbstractVector, itrs) + seen = Set{eltype(v)}() + filter!(_grow_filter!(seen), v) + shrinker!(seen, itrs...) + filter!(_in(seen), v) end -# symdiff is associative, so a relatively clean -# way to implement this is by using setdiff and union, and -# recursing. Has the advantage of keeping order, too, but -# not as fast as other methods that make a single pass and -# store counts with a Dict. -symdiff(a, b) = union(setdiff(a,b), setdiff(b,a)) -symdiff(a, b, rest...) = symdiff(a, symdiff(b, rest...)) + +intersect!(v::AbstractVector, itrs...) = _shrink!(intersect!, v, itrs) +setdiff!( v::AbstractVector, itrs...) = _shrink!(setdiff!, v, itrs) + +vectorfilter(f, v::AbstractVector) = filter(f, v) # TODO: do we want this special case? +vectorfilter(f, v) = [x for x in v if f(x)] + +function _shrink(shrinker!, itr, itrs) + keep = shrinker!(Set(itr), itrs...) + vectorfilter(_shrink_filter!(keep), itr) +end + +intersect(itr, itrs...) 
= _shrink(intersect!, itr, itrs) +setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs) diff --git a/base/bitset.jl b/base/bitset.jl index f2c07235e3ae18..234faa8a28ce21 100644 --- a/base/bitset.jl +++ b/base/bitset.jl @@ -32,7 +32,12 @@ BitSet(itr) = union!(BitSet(), itr) eltype(::Type{BitSet}) = Int similar(s::BitSet) = BitSet() + +empty(s::BitSet, ::Type{Int}=Int) = BitSet() +emptymutable(s::BitSet, ::Type{Int}=Int) = BitSet() + copy(s1::BitSet) = copy!(BitSet(), s1) +copymutable(s::BitSet) = copy(s) """ copy!(dst, src) @@ -49,8 +54,6 @@ function copy!(dest::BitSet, src::BitSet) dest end -copymutable(s::BitSet) = copy(s) - eltype(s::BitSet) = Int sizehint!(s::BitSet, n::Integer) = (sizehint!(s.bits, (n+63) >> 6); s) @@ -256,7 +259,6 @@ isempty(s::BitSet) = _check0(s.bits, 1, length(s.bits)) # Mathematical set functions: union!, intersect!, setdiff!, symdiff! union(s::BitSet, sets...) = union!(copy(s), sets...) -union!(s::BitSet, ns) = foldl(push!, s, ns) union!(s1::BitSet, s2::BitSet) = _matched_map!(|, s1, s2) intersect(s1::BitSet, s2::BitSet) = diff --git a/base/set.jl b/base/set.jl index b9caaf586d212a..f1de13afd08c7a 100644 --- a/base/set.jl +++ b/base/set.jl @@ -27,8 +27,13 @@ function Set(g::Generator) return Set{T}(g) end -similar(s::Set{T}) where {T} = Set{T}() -similar(s::Set, T::Type) = Set{T}() +similar(s::Set{T}, ::Type{U}=T) where {T,U} = Set{U}() + +empty(s::Set{T}, ::Type{U}=T) where {T,U} = Set{U}() + +# return an empty set with eltype T, which is mutable (can be grown) +# by default, a Set is returned +emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}() function show(io::IO, s::Set) print(io, "Set(") @@ -88,23 +93,30 @@ julia> union([1, 2], [2, 4]) 2 4 -julia> union([4, 2], [1, 2]) +julia> union([4, 2], 1:2) 3-element Array{Int64,1}: 4 2 1 + +julia> union(Set([1, 2]), 2:3) +Set([2, 3, 1]) ``` """ function union end -union(s::Set, sets...) = union!(Set{join_eltype(s, sets...)}(), s, sets...) 
+_in(itr) = x -> x in itr + +union(s, sets...) = union!(emptymutable(s, promote_eltype(s, sets...)), s, sets...) +union(s::AbstractSet) = copy(s) const ∪ = union """ - union!(s::AbstractSet, itrs...) + union!(s::Union{AbstractSet,AbstractVector}, itrs...) Construct the union of passed in sets and overwrite `s` with the result. +Maintain order with arrays. # Examples ```jldoctest @@ -116,6 +128,11 @@ julia> a Set([7, 4, 3, 5, 1]) ``` """ +union!(s::AbstractSet, sets...) = foldl(union!, s, sets) + +# default generic 2-args implementation with push! +union!(s::AbstractSet, itr) = foldl(push!, s, itr) + function union!(s::Set{T}, itr) where T haslength(itr) && sizehint!(s, length(itr)) for x=itr @@ -125,17 +142,13 @@ function union!(s::Set{T}, itr) where T s end -union!(s::AbstractSet, sets...) = foldl(union!, s, sets) - -join_eltype() = Bottom -join_eltype(v1, vs...) = typejoin(eltype(v1), join_eltype(vs...)) """ intersect(s, itrs...) ∩(s, itrs...) Construct the intersection of sets. -Maintain order and multiplicity of the first argument for arrays and ranges. +Maintain order with arrays. # Examples ```jldoctest @@ -144,28 +157,29 @@ julia> intersect([1, 2, 3], [3, 4, 5]) 3 julia> intersect([1, 4, 4, 5, 6], [4, 6, 6, 7, 8]) -3-element Array{Int64,1}: - 4 +2-element Array{Int64,1}: 4 6 + +julia> intersect(Set([1, 2]), BitSet([2, 3])) +Set([2]) ``` """ -function intersect end - -intersect(s) = copymutable(s) -intersect(s::AbstractSet, itr) = mapfilter(x->in(x, s), push!, itr, similar(s)) intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...) +intersect(s) = union(s) +intersect(s::AbstractSet, itr) = mapfilter(_in(s), push!, itr, emptymutable(s)) const ∩ = intersect """ - intersect!(s::AbstractSet, itrs...) + intersect!(s::Union{AbstractSet,AbstractVector}, itrs...) Intersect all passed in sets and overwrite `s` with the result. +Maintain order with arrays. 
""" -intersect!(s::AbstractSet, s2::AbstractSet) = filter!(x -> x in s2, s) -intersect!(s::AbstractSet, itr) = intersect!(s, union!(similar(s), itr)) intersect!(s::AbstractSet, itrs...) = foldl(intersect!, s, itrs) +intersect!(s::AbstractSet, s2::AbstractSet) = filter!(_in(s2), s) +intersect!(s::AbstractSet, itr) = intersect!(s, union!(emptymutable(s), itr)) """ setdiff(s, itrs...) @@ -182,12 +196,13 @@ julia> setdiff([1,2,3], [3,4,5]) ``` """ setdiff(s::AbstractSet, itrs...) = setdiff!(copymutable(s), itrs...) -setdiff(s) = copymutable(s) +setdiff(s) = union(s) """ setdiff!(s, itrs...) Remove from set `s` (in-place) each element of each iterable from `itrs`. +Maintain order with arrays. # Examples ```jldoctest @@ -207,7 +222,8 @@ setdiff!(s::AbstractSet, itr) = foldl(delete!, s, itr) symdiff(s, itrs...) Construct the symmetric difference of elements in the passed in sets. -Maintains order with arrays. +When `s` is not an `AbstractSet`, the order is maintained. +Note that in this case the multiplicity of elements matters. # Examples ```jldoctest @@ -216,15 +232,25 @@ julia> symdiff([1,2,3], [3,4,5], [4,5,6]) 1 2 6 + +julia> symdiff([1,2,1], [2, 1, 2]) +2-element Array{Int64,1}: + 1 + 2 + +julia> symdiff(unique([1,2,1]), unique([2, 1, 2])) +0-element Array{Int64,1} ``` """ -symdiff(s::AbstractSet, sets...) = symdiff!(copymutable(s), sets...) -symdiff(s) = copymutable(s) # remove when method above becomes as efficient +symdiff(s, sets...) = symdiff!(emptymutable(s, promote_eltype(s, sets...)), s, sets...) +symdiff(s) = symdiff!(copy(s)) """ - symdiff!(s::AbstractSet, itrs...) + symdiff!(s::Union{AbstractSet,AbstractVector}, itrs...) Construct the symmetric difference of the passed in sets, and overwrite `s` with the result. +When `s` is an array, the order is maintained. +Note that in this case the multiplicity of elements matters. """ symdiff!(s::AbstractSet, itrs...) 
= foldl(symdiff!, s, itrs) @@ -256,7 +282,18 @@ julia> issubset([1, 2, 3], [1, 2]) false ``` """ -issubset(l, r) = all(x -> x in r, l) +function issubset(l, r) + for elt in l + if !in(elt, r) + return false + end + end + return true +end + +# use the implementation below when it becomes as efficient +# issubset(l, r) = all(_in(r), l) + const ⊆ = issubset ⊊(l::Set, r::Set) = <(l, r) ⊈(l::Set, r::Set) = !⊆(l, r) diff --git a/test/arrayops.jl b/test/arrayops.jl index 6db4148f136948..bdd96ff1831a49 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -932,7 +932,7 @@ end @test isequal(setdiff([1,2,3,4], [7,8,9]), [1,2,3,4]) @test isequal(setdiff([1,2,3,4], Int64[]), Int64[1,2,3,4]) @test isequal(setdiff([1,2,3,4], [1,2,3,4,5]), Int64[]) - @test isequal(symdiff([1,2,3], [4,3,4]), [1,2,4]) + @test isequal(symdiff([1,2,3], [4,3,4]), [1,2]) @test isequal(symdiff(['e','c','a'], ['b','a','d']), ['e','c','b','d']) @test isequal(symdiff([1,2,3], [4,3], [5]), [1,2,4,5]) @test isequal(symdiff([1,2,3,4,5], [1,2,3], [3,4]), [3,5]) diff --git a/test/sets.jl b/test/sets.jl index 179aaecc5fb9f2..66dcd34664c8a6 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -157,14 +157,14 @@ end end @testset "union" begin - for S in (Set, BitSet) + for S in (Set, BitSet, Vector) s = ∪(S([1,2]), S([3,4])) - @test isequal(s, S([1,2,3,4])) + @test s == S([1,2,3,4]) s = union(S([5,6,7,8]), S([7,8,9])) - @test isequal(s, S([5,6,7,8,9])) + @test s == S([5,6,7,8,9]) s = S([1,3,5,7]) - union!(s,(2,3,4,5)) - @test isequal(s,S([1,2,3,4,5,7])) + union!(s, (2,3,4,5)) + @test s == S([1,3,5,7,2,4]) # order matters for Vector let s1 = S([1, 2, 3]) @test s1 !== union(s1) == s1 @test s1 !== union(s1, 2:4) == S([1,2,3,4]) @@ -173,17 +173,21 @@ end @test s1 === union!(s1, [2,3,4], S([5])) == S([1,2,3,4,5]) end end - @test typeof(union(Set([1]), BitSet())) === Set{Int} - @test typeof(union(BitSet([1]), Set())) === BitSet + @test union(Set([1]), BitSet()) isa Set{Int} + @test union(BitSet([1]), Set()) isa BitSet + 
@test union([1], BitSet()) isa Vector{Int} + # union must uniquify + @test union([1, 2, 1]) == union!([1, 2, 1]) == [1, 2] + @test union([1, 2, 1], [2, 2]) == union!([1, 2, 1], [2, 2]) == [1, 2] end @testset "intersect" begin - for S in (Set, BitSet) - s = ∩(S([1,2]), S([3,4])) - @test isequal(s, S()) + for S in (Set, BitSet, Vector) + s = S([1,2]) ∩ S([3,4]) + @test s == S() s = intersect(S([5,6,7,8]), S([7,8,9])) - @test isequal(s, S([7,8])) - @test isequal(intersect(S([2,3,1]), S([4,2,3]), S([5,4,3,2])), S([2,3])) + @test s == S([7,8]) + @test intersect(S([2,3,1]), S([4,2,3]), S([5,4,3,2])) == S([2,3]) let s1 = S([1,2,3]) @test s1 !== intersect(s1) == s1 @test s1 !== intersect(s1, 2:10) == S([2,3]) @@ -192,18 +196,22 @@ end @test s1 === intersect!(s1, [2,3,4], 3:4) == S([3]) end end - @test typeof(intersect(Set([1]), BitSet())) === Set{Int} - @test typeof(intersect(BitSet([1]), Set())) === BitSet + @test intersect(Set([1]), BitSet()) isa Set{Int} + @test intersect(BitSet([1]), Set()) isa BitSet + @test intersect([1], BitSet()) isa Vector{Int} + # intersect must uniquify + @test intersect([1, 2, 1]) == intersect!([1, 2, 1]) == [1, 2] + @test intersect([1, 2, 1], [2, 2]) == intersect!([1, 2, 1], [2, 2]) == [2] end @testset "setdiff" begin - for S in (Set, BitSet) - @test isequal(setdiff(S([1,2,3]), S()), S([1,2,3])) - @test isequal(setdiff(S([1,2,3]), S([1])), S([2,3])) - @test isequal(setdiff(S([1,2,3]), S([1,2])), S([3])) - @test isequal(setdiff(S([1,2,3]), S([1,2,3])), S()) - @test isequal(setdiff(S([1,2,3]), S([4])), S([1,2,3])) - @test isequal(setdiff(S([1,2,3]), S([4,1])), S([2,3])) + for S in (Set, BitSet, Vector) + @test setdiff(S([1,2,3]), S()) == S([1,2,3]) + @test setdiff(S([1,2,3]), S([1])) == S([2,3]) + @test setdiff(S([1,2,3]), S([1,2])) == S([3]) + @test setdiff(S([1,2,3]), S([1,2,3])) == S() + @test setdiff(S([1,2,3]), S([4])) == S([1,2,3]) + @test setdiff(S([1,2,3]), S([4,1])) == S([2,3]) let s1 = S([1, 2, 3]) @test s1 !== setdiff(s1) == s1 @test 
s1 !== setdiff(s1, 2:10) == S([1]) @@ -212,8 +220,13 @@ end @test s1 === setdiff!(s1, S([2,3,4]), S([1])) == S() end end - @test typeof(setdiff(Set([1]), BitSet())) === Set{Int} - @test typeof(setdiff(BitSet([1]), Set())) === BitSet + + @test setdiff(Set([1]), BitSet()) isa Set{Int} + @test setdiff(BitSet([1]), Set()) isa BitSet + @test setdiff([1], BitSet()) isa Vector{Int} + # setdiff must uniquify + @test setdiff([1, 2, 1]) == setdiff!([1, 2, 1]) == [1, 2] + @test setdiff([1, 2, 1], [2, 2]) == setdiff!([1, 2, 1], [2, 2]) == [1] s = Set([1,3,5,7]) setdiff!(s,(3,5)) @@ -240,7 +253,7 @@ end end @testset "issubset, symdiff" begin - for S in (Set, BitSet) + for S in (Set, BitSet, Vector) for (l,r) in ((S([1,2]), S([3,4])), (S([5,6,7,8]), S([7,8,9])), (S([1,2]), S([3,4])), @@ -255,16 +268,22 @@ end @test issubset(intersect(l,r), r) @test issubset(l, union(l,r)) @test issubset(r, union(l,r)) - @test isequal(union(intersect(l,r),symdiff(l,r)), union(l,r)) + if S === Vector + @test sort(union(intersect(l,r),symdiff(l,r))) == sort(union(l,r)) + else + @test union(intersect(l,r),symdiff(l,r)) == union(l,r) + end + end + if S !== Vector + @test ⊆(S([1]), S([1,2])) + @test ⊊(S([1]), S([1,2])) + @test !⊊(S([1]), S([1])) + @test ⊈(S([1]), S([2])) + @test ⊇(S([1,2]), S([1])) + @test ⊋(S([1,2]), S([1])) + @test !⊋(S([1]), S([1])) + @test ⊉(S([1]), S([2])) end - @test ⊆(S([1]), S([1,2])) - @test ⊊(S([1]), S([1,2])) - @test !⊊(S([1]), S([1])) - @test ⊈(S([1]), S([2])) - @test ⊇(S([1,2]), S([1])) - @test ⊋(S([1,2]), S([1])) - @test !⊋(S([1]), S([1])) - @test ⊉(S([1]), S([2])) let s1 = S([1,2,3,4]) @test s1 !== symdiff(s1) == s1 @test s1 !== symdiff(s1, S([2,4,5,6])) == S([1,3,5,6]) @@ -272,6 +291,13 @@ end @test s1 === symdiff!(s1, S([2,4,5,6]), [1,6,7]) == S([3,5,7]) end end + @test symdiff(Set([1,2,3,4]), Set([2,4,5,6])) == Set([1,3,5,6]) + @test symdiff(Set([1]), BitSet()) isa Set{Int} + @test symdiff(BitSet([1]), Set{Int}()) isa BitSet + @test symdiff([1], BitSet()) isa 
Vector{Int} + # symdiff must NOT uniquify + @test symdiff([1, 2, 1]) == symdiff!([1, 2, 1]) == [2] + @test symdiff([1, 2, 1], [2, 2]) == symdiff!([1, 2, 1], [2, 2]) == [2] end @testset "unique" begin