// union_find_tree.hpp
#pragma once
#include <vector>
#include <algorithm>
#include <stdexcept>
// Union Find Tree
// (aka Disjoint Set Union)
//
// Space: O(N)
//
// Verified:
// - https://onlinejudge.u-aizu.ac.jp/problems/DSL_1_A
//
// Disjoint-set forest over the elements 0 .. N-1.
// Uses union by rank together with path compression.
class UnionFindTree {
public:
    // Per-element record stored in the forest.
    struct node {
        // Parent element in the forest; a representative (root) points at itself.
        size_t leader;
        // Upper bound on the height of the tree rooted here.
        // Only read for representatives; stale for every other element.
        size_t rank;
    };

private:
    size_t N;
    std::vector<node> nodes;

public:
    // Builds N singleton sets: element i is its own leader with rank 0.
    // Time: O(N)
    UnionFindTree(size_t const N = 0)
        : N(N)
        , nodes(N)
    {
        for (size_t i = 0; i < N; ++i) {
            nodes[i] = { i, 0 };
        }
    }

    // Number of elements the structure was built with (not the number of sets).
    // Time: O(1)
    size_t size() const noexcept {
        return N;
    }

    // Returns a copy of the record of the representative of v's set, i.e.
    // `.leader` is the representative's index and `.rank` is its rank.
    // v = [0,N); throws std::out_of_range otherwise.
    // Time: O(a(N)) amortized
    node find(size_t const v) {
        throw_if_invalid_index(v);

        // Pass 1: walk up until we reach the element that is its own leader.
        size_t root = v;
        while (nodes[root].leader != root) {
            root = nodes[root].leader;
        }

        // Pass 2 (path compression): re-point every element on the walked
        // path directly at the representative.
        size_t cur = v;
        while (cur != root) {
            size_t const next = nodes[cur].leader;
            nodes[cur].leader = root;
            cur = next;
        }

        return nodes[root];
    }

    // True when u and v currently belong to the same set.
    // u = [0,N), v = [0,N); throws std::out_of_range otherwise.
    // Time: O(a(N)) amortized
    bool same(size_t const u, size_t const v) {
        // Validate both arguments up front so that nothing is compressed
        // when either index is invalid.
        throw_if_invalid_index(u);
        throw_if_invalid_index(v);
        return find(u).leader == find(v).leader;
    }

    // Merges the sets containing u and v.
    // Returns true when a merge happened, false when they already shared a set.
    // u = [0,N), v = [0,N); throws std::out_of_range otherwise.
    // Time: O(a(N)) amortized
    bool unite(size_t u, size_t v) {
        // Validate both arguments before touching the forest.
        throw_if_invalid_index(u);
        throw_if_invalid_index(v);

        u = find(u).leader;
        v = find(v).leader;
        if (u == v) {
            return false;
        }

        // Union by rank: keep u as the root with the larger (or equal) rank
        // and hang v underneath it.
        if (nodes[u].rank < nodes[v].rank) {
            std::swap(u, v);
        }
        nodes[v].leader = u;

        // Two trees of equal rank were joined, so the result is one level taller.
        if (nodes[u].rank == nodes[v].rank) {
            nodes[u].rank++;
        }
        return true;
    }

private:
    // Throws std::out_of_range unless v is a valid element index, i.e. v < N.
    void throw_if_invalid_index(size_t const v) const {
        if (v >= N) throw std::out_of_range("index out of range");
    }
};