-
Notifications
You must be signed in to change notification settings - Fork 2
/
unionfind.py
304 lines (267 loc) · 9 KB
/
unionfind.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""
A union-find disjoint set data structure.
https://github.com/deehzee/unionfind
"""
# 2to3 sanity
from __future__ import (
absolute_import, division, print_function, unicode_literals,
)
# Third-party libraries
import numpy as np
class UnionFind(object):
    r"""Union-find disjoint sets data structure.

    Union-find is a data structure that maintains disjoint set
    (called connected components or components in short) membership,
    and makes it easier to merge (union) two components, and to find
    if two elements are connected (i.e., belong to the same
    component).

    This implements the "weighted-quick-union-with-path-compression"
    union-find algorithm.  Only works if elements are immutable
    (hashable) objects, because they are used as dictionary keys.

    Worst case for union and find: :math:`O(N + M \log^* N)`, with
    :math:`N` elements and :math:`M` unions.  The function
    :math:`\log^*` is the number of times needed to take :math:`\log`
    of a number until reaching 1.  In practice the total running time
    is nearly linear, i.e. the amortized cost of each operation is
    nearly constant [1]_.

    Terms
    -----
    Component
        Elements belonging to the same disjoint set.
    Connected
        Two elements are connected if they belong to the same component.
    Union
        The operation where two components are merged into one.
    Root
        An internal representative of a disjoint set.
    Find
        The operation to find the root of a disjoint set.

    Parameters
    ----------
    elements : NoneType or container, optional, default: None
        The initial list of elements.

    Attributes
    ----------
    n_elts : int
        Number of elements.
    n_comps : int
        Number of disjoint sets or components.

    Implements
    ----------
    __len__
        Calling ``len(uf)`` (where ``uf`` is an instance of ``UnionFind``)
        returns the number of elements.
    __contains__
        For ``uf`` an instance of ``UnionFind`` and ``x`` an immutable object,
        ``x in uf`` returns ``True`` if ``x`` is an element in ``uf``.
    __getitem__
        For ``uf`` an instance of ``UnionFind`` and ``i`` an integer,
        ``res = uf[i]`` returns the element stored in the ``i``-th index.
        If ``i`` is not a valid index an ``IndexError`` is raised.
    __setitem__
        For ``uf`` an instance of ``UnionFind``, ``i`` an integer and ``x``
        an immutable object, ``uf[i] = x`` changes the element stored at the
        ``i``-th index; the new element takes over the component membership
        of the one it replaces.  If ``i`` is not a valid index an
        ``IndexError`` is raised.  If ``x`` is already stored at a different
        index a ``ValueError`` is raised.

    .. [1] http://algs4.cs.princeton.edu/lectures/

    """

    def __init__(self, elements=None):
        self.n_elts = 0  # current num of elements
        self.n_comps = 0  # the number of disjoint sets or components
        self._next = 0  # next available id
        self._elts = []  # the elements
        self._indx = {}  # dict mapping elt -> index in _elts
        self._par = []  # parent: for the internal tree structure
        self._siz = []  # size of the component - correct only for roots

        if elements is None:
            elements = []
        for elt in elements:
            self.add(elt)

    def __repr__(self):
        return (
            '<UnionFind:\n\telts={},\n\tsiz={},\n\tpar={},\nn_elts={},n_comps={}>'
            .format(
                self._elts,
                self._siz,
                self._par,
                self.n_elts,
                self.n_comps,
            ))

    def __len__(self):
        return self.n_elts

    def __contains__(self, x):
        return x in self._indx

    def __getitem__(self, index):
        # Negative indices are rejected on purpose (no wrap-around).
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        return self._elts[index]

    def __setitem__(self, index, x):
        if index < 0 or index >= self._next:
            raise IndexError('index {} is out of bound'.format(index))
        if x in self._indx:
            if self._indx[x] != index:
                # Storing x twice would make the elt -> index map ambiguous.
                raise ValueError('{} is already an element'.format(x))
            # Equal element in the same slot: just refresh the stored object.
            self._elts[index] = x
            return
        # Keep the elt -> index map in sync with the element list so that
        # membership tests, find() and union() see the replacement.
        self._indx.pop(self._elts[index], None)
        self._indx[x] = index
        self._elts[index] = x

    def add(self, x):
        """Add a single disjoint element.

        Adding an element that is already present is a no-op.

        Parameters
        ----------
        x : immutable object

        Returns
        -------
        None

        """
        if x in self:
            return
        self._elts.append(x)
        self._indx[x] = self._next
        self._par.append(self._next)  # a new element is its own root
        self._siz.append(1)
        self._next += 1
        self.n_elts += 1
        self.n_comps += 1

    def find(self, x):
        """Find the root of the disjoint set containing the given element.

        Parameters
        ----------
        x : immutable object

        Returns
        -------
        int
            The (index of the) root.

        Raises
        ------
        ValueError
            If the given element is not found.

        """
        if x not in self._indx:
            raise ValueError('{} is not an element'.format(x))
        p = self._indx[x]
        while p != self._par[p]:
            # path compression: re-point p at its grandparent, then
            # continue the walk from p's previous parent
            q = self._par[p]
            self._par[p] = self._par[q]
            p = q
        return p

    def connected(self, x, y):
        """Return whether the two given elements belong to the same component.

        Parameters
        ----------
        x : immutable object
        y : immutable object

        Returns
        -------
        bool
            True if x and y are connected, false otherwise.

        Raises
        ------
        ValueError
            If either element is not found.

        """
        return self.find(x) == self.find(y)

    def union(self, x, y):
        """Merge the components of the two given elements into one.

        Elements that are not yet in the collection are added first.

        Parameters
        ----------
        x : immutable object
        y : immutable object

        Returns
        -------
        None

        """
        # Initialize if they are not already in the collection
        for elt in [x, y]:
            if elt not in self:
                self.add(elt)

        xroot = self.find(x)
        yroot = self.find(y)
        if xroot == yroot:
            return
        # Weighted union: hang the smaller tree under the larger root.
        if self._siz[xroot] < self._siz[yroot]:
            self._par[xroot] = yroot
            self._siz[yroot] += self._siz[xroot]
        else:
            self._par[yroot] = xroot
            self._siz[xroot] += self._siz[yroot]
        self.n_comps -= 1

    def _group_by_root(self):
        """Return a dict mapping each root index to the set of its elements."""
        groups = {}
        for elt in self._elts:
            groups.setdefault(self.find(elt), set()).add(elt)
        return groups

    def component(self, x):
        """Find the connected component containing the given element.

        Parameters
        ----------
        x : immutable object

        Returns
        -------
        set
            The elements (the original objects) connected to ``x``,
            including ``x`` itself.

        Raises
        ------
        ValueError
            If the given element is not found.

        """
        if x not in self:
            raise ValueError('{} is not an element'.format(x))
        root = self.find(x)
        return {elt for elt in self._elts if self.find(elt) == root}

    def components(self):
        """Return the list of connected components.

        Returns
        -------
        list
            A list of sets.  Empty if there are no elements.  The order
            of the sets is not specified.

        """
        return list(self._group_by_root().values())

    def component_mapping(self):
        """Return a dict mapping elements to their components.

        The returned dict has the following semantics:

            `elt -> component containing elt`

        If x, y belong to the same component, the comp(x) and comp(y)
        are the same objects (i.e., share the same reference).  Changing
        comp(x) will reflect in comp(y).  This is done to reduce
        memory.

        But this behaviour should not be relied on.  There may be
        inconsistency arising from such assumptions or lack thereof.

        If you want to do any operation on these sets, use caution.
        For example, instead of

        ::

            s = uf.component_mapping()[item]
            s.add(stuff)
            # This will have side effect in other sets

        do

        ::

            s = set(uf.component_mapping()[item]) # or
            s = uf.component_mapping()[item].copy()
            s.add(stuff)

        or

        ::

            s = uf.component_mapping()[item]
            s = s | {stuff}  # Now s is different

        Returns
        -------
        dict
            A dict with the semantics: `elt -> component containing elt`.

        """
        comps = {}
        for comp in self._group_by_root().values():
            # Every member of a component maps to the very same set object.
            # Use ``set(comp)`` here instead if independent copies per key
            # are ever wanted.
            for elt in comp:
                comps[elt] = comp
        return comps
if __name__ == "__main__":
    # Small demonstration: build ten singleton components, merge a few of
    # them, then show the resulting components and one connectivity query.
    demo = UnionFind(list('abcdefghij'))
    print(demo)
    for left, right in [('a', 'b'), ('b', 'i'), ('c', 'd'), ('c', 'b')]:
        demo.union(left, right)
    print(demo.components())
    print(demo.connected('c', 'd'))