-
Notifications
You must be signed in to change notification settings - Fork 312
/
SortedList.py
122 lines (95 loc) · 3.35 KB
/
SortedList.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
The "sorted list" data-structure, with amortized O(n^(1/3)) cost per insert and pop.
Example:
A = SortedList()
A.insert(30)
A.insert(50)
A.insert(20)
A.insert(30)
A.insert(30)
print(A) # prints [20, 30, 30, 30, 50]
print(A.lower_bound(30), A.upper_bound(30)) # prints 1 4
print(A[-1]) # prints 50
print(A.pop(1)) # prints 30
print(A) # prints [20, 30, 30, 50]
print(A.count(30)) # prints 2
"""
from bisect import bisect_left as lower_bound
from bisect import bisect_right as upper_bound
class FenwickTree:
    """Binary indexed (Fenwick) tree over a list of numbers, 0-indexed.

    Supports adding to one value and querying prefix sums, both by walking
    O(log n) tree nodes, plus a top-down descent (``find_kth``) that locates
    where the running prefix sum passes a target. ``bit`` is the internal
    tree array, not the original values.
    """

    def __init__(self, x):
        """Build the tree from the values ``x`` in linear time.

        ``x`` is copied, so the caller's sequence is never modified.
        """
        bit = self.bit = list(x)
        size = self.size = len(bit)
        for i in range(size):
            # Fold each node's partial sum into its parent node.
            j = i | (i + 1)
            if j < size:
                bit[j] += bit[i]

    def update(self, idx, x):
        """Add ``x`` to the value at position ``idx``.

        Raises:
            IndexError: if ``idx`` is negative. Without this guard the
                ``idx |= idx + 1`` step never advances past -1 and the
                loop would not terminate. An ``idx`` >= size is a no-op.
        """
        if idx < 0:
            raise IndexError("FenwickTree index out of range")
        while idx < self.size:
            self.bit[idx] += x
            idx |= idx + 1

    def __call__(self, end):
        """Return the sum of the first ``end`` values, i.e. ``sum(values[:end])``."""
        x = 0
        while end:
            x += self.bit[end - 1]
            # Drop the lowest set bit to jump to the preceding node.
            end &= end - 1
        return x

    def find_kth(self, k):
        """Locate where the prefix sum of the values passes ``k``.

        Returns ``(idx, rest)`` where ``idx`` is the largest index such that
        ``sum(values[:idx]) <= k`` and ``rest = k - sum(values[:idx])``.
        The binary descent relies on all values being non-negative, so that
        prefix sums are monotone.
        """
        idx = -1
        for d in reversed(range(self.size.bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < self.size and self.bit[right_idx] <= k:
                idx = right_idx
                k -= self.bit[idx]
        return idx + 1, k
class SortedList:
    """Sorted multiset stored as a list of sorted blocks.

    Internal layout:
      * ``micros``     -- list of sorted blocks; concatenated, they form the
                          whole list in order.
      * ``macro``      -- separators: ``macro[i]`` was the first element of
                          block ``i + 1`` when that block was created, and
                          every element of block ``i`` lies between
                          ``macro[i - 1]`` and ``macro[i]`` (inclusive).
      * ``micro_size`` -- current length of each block.
      * ``fenwick``    -- Fenwick tree over ``micro_size``, used to convert
                          between a global index and (block, offset).
      * ``size``       -- total number of elements.

    Blocks emptied by ``pop`` are kept (their separators stay valid), so the
    block count never shrinks.
    """

    # A block that reaches this many elements is split in two; must be >= 2.
    block_size = 700

    def __init__(self, iterable=()):
        """Create the list from any iterable of mutually comparable items."""
        iterable = sorted(iterable)
        # Blocks start with at most block_size - 1 elements, i.e. below the
        # split threshold that insert() enforces.
        step = self.block_size - 1
        self.micros = [iterable[i:i + step] for i in range(0, len(iterable), step)] or [[]]
        self.macro = [block[0] for block in self.micros[1:]]
        self.micro_size = [len(block) for block in self.micros]
        self.fenwick = FenwickTree(self.micro_size)
        self.size = len(iterable)

    def insert(self, x):
        """Insert ``x`` while keeping the list sorted."""
        # The separators pick the block whose range admits x.
        i = lower_bound(self.macro, x)
        micro = self.micros[i]
        micro.insert(upper_bound(micro, x), x)
        self.size += 1
        self.micro_size[i] += 1
        self.fenwick.update(i, 1)
        if len(micro) >= self.block_size:
            # Split the full block. Record the real half lengths so that
            # micro_size stays exact even when block_size is odd.
            half = len(micro) >> 1
            left, right = micro[:half], micro[half:]
            self.micros[i:i + 1] = left, right
            self.micro_size[i:i + 1] = len(left), len(right)
            # Block indices shifted, so rebuild the Fenwick tree.
            self.fenwick = FenwickTree(self.micro_size)
            self.macro.insert(i, right[0])

    def pop(self, k=-1):
        """Remove and return the element at index ``k`` (default: the last).

        Raises:
            IndexError: if ``k`` is out of range. The list is left unchanged,
                because the index is validated before any counter is touched.
        """
        i, j = self._find_kth(k)
        self.size -= 1
        self.micro_size[i] -= 1
        self.fenwick.update(i, -1)
        return self.micros[i].pop(j)

    def __getitem__(self, k):
        """Return the element at integer index ``k`` (negatives count from the end)."""
        i, j = self._find_kth(k)
        return self.micros[i][j]

    def count(self, x):
        """Return how many elements compare equal to ``x``."""
        return self.upper_bound(x) - self.lower_bound(x)

    def __contains__(self, x):
        return self.count(x) > 0

    def lower_bound(self, x):
        """Return the number of elements strictly less than ``x``."""
        i = lower_bound(self.macro, x)
        return self.fenwick(i) + lower_bound(self.micros[i], x)

    def upper_bound(self, x):
        """Return the number of elements less than or equal to ``x``."""
        i = upper_bound(self.macro, x)
        return self.fenwick(i) + upper_bound(self.micros[i], x)

    def _find_kth(self, k):
        """Map index ``k`` to ``(block, offset)``.

        Negative ``k`` counts from the end, as with built-in lists.

        Raises:
            IndexError: if ``k`` lies outside ``[-len(self), len(self))``.
                Without this check, indices below ``-len(self)`` silently
                resolve to a wrong element.
        """
        if k < 0:
            k += self.size
        if not 0 <= k < self.size:
            raise IndexError("SortedList index out of range")
        return self.fenwick.find_kth(k)

    def __len__(self):
        return self.size

    def __iter__(self):
        return (x for micro in self.micros for x in micro)

    def __repr__(self):
        return str(list(self))