-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinteger_set.cpp
131 lines (119 loc) · 3.67 KB
/
integer_set.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "integer_set.h"
#include <iostream>
void IntegerSet::sort(int bit1, int bit2) {
assert(bit2 > bit1);
int dbit = bit2 - bit1;
uint64_t mask1 = (1ULL << bit1);
uint64_t mask2 = (1ULL << bit2);
uint64_t mask12 = ~(mask1 | mask2);
for_each([=](uint64_t i) {
reset(i);
uint64_t b1 = i & mask1;
uint64_t b2 = i & mask2;
// bit1 and bit2 in position 1
uint64_t and12 = (b1 & (b2 >> dbit));
// bit1 or bit2 in position 2
uint64_t or12 = ((b1 << dbit) | b2);
uint64_t j = (i & mask12) | and12 | or12;
set(j);
});
}
void IntegerSet::unsort(int bit1, int bit2) {
assert(bit2 > bit1);
int dbit = bit2 - bit1;
uint64_t mask1 = (1ULL << bit1);
uint64_t mask2 = (1ULL << bit2);
uint64_t mask12 = ~(mask1 | mask2);
for_each([=](uint64_t i) {
uint64_t b1 = i & mask1;
uint64_t b2 = i & mask2;
// bit1 or bit2 in position 1
uint64_t or12 = (b1 | (b2 >> dbit));
// bit1 and bit2 in position 2
uint64_t and12 = ((b1 << dbit) & b2);
uint64_t j = (i & mask12) | or12 | and12;
set(j);
});
}
// Returns true if every integer in the set, viewed as an n-bit pattern, is
// "sorted": all of its 0 bits sit below all of its 1 bits, i.e. some run of
// zeros starting at bit 0 followed by ones up to bit n-1. Bits at or above
// position n are ignored. An empty set is considered sorted.
bool IntegerSet::is_sorted() {
  // Mask selecting the n low bits of each element. Guard n >= 64 because
  // shifting a 64-bit value by 64 or more is undefined behavior.
  const uint64_t valid = (n >= 64) ? ~0ULL : ((1ULL << n) - 1);
  bool check = true;
  for_each([&](uint64_t i) {
    // In a sorted pattern the zeros form a contiguous run starting at bit 0.
    // So the complement, restricted to the valid bits, must look like
    // 0b0..011..1, including the all-zero and all-one cases.
    // A value z has that shape exactly when z & (z + 1) == 0; unsigned
    // wrap-around covers the all-ones case.
    // This replaces __builtin_ctzll/__builtin_clzll, which are undefined for a
    // zero argument. Such an argument occurred for the all-ones element, and
    // for the all-zeros element when n == 64.
    const uint64_t zeros = ~i & valid;
    check = check && ((zeros & (zeros + 1)) == 0);
  });
  return check;
}
// Remove all the integers where the following two bits aren't in order
// TODO: Could remove entire buckets at a time instead of single bits if bit1, bit2 > 6
void IntegerSet::remove_all_where_not_sorted(int bit1, int bit2) {
assert(bit2 > bit1);
const uint64_t mask1 = (1ULL << bit1);
const uint64_t mask2 = (1ULL << bit2);
for_each([=](uint64_t i) {
// Check if bit1 is set and bit2 is not set
if ((i & mask1) && ((i & mask2) == 0)) {
reset(i);
}
});
}
// Returns true when no element of the set has bit1 set while bit2 is clear,
// i.e. every element is already ordered on the pair (bit1 < bit2).
bool IntegerSet::is_sorted(int bit1, int bit2) {
  assert(bit2 > bit1);
  const int shift = bit2 - bit1;
  // Accumulates, at every bit position p, whether some element has a 1 at p
  // and a 0 at p + shift. Only position bit1 is inspected at the end.
  uint64_t violations = 0;
  for_each([&violations, shift](uint64_t value) {
    violations |= value & ~(value >> shift);
  });
  return ((violations >> bit1) & 1ULL) == 0;
}
// Number of integers in the set: the total count of 1 bits across every
// storage word.
size_t IntegerSet::size() const {
  size_t total = 0;
  for (size_t w = 0; w < storage.size(); ++w) {
    total += static_cast<size_t>(__builtin_popcountll(storage[w]));
  }
  return total;
}
void IntegerSet::dump_binary() {
for_each([=](uint64_t id) {
for (int j = 0; j < n; j++) {
std::cout << "01"[((id >> j) & 1)];
}
std::cout << "\n";
});
}
// Prints every element of the set to stdout in decimal, one per line.
void IntegerSet::dump() {
  for_each([](uint64_t value) { std::cout << value << "\n"; });
}
// Returns true if every integer in this set is also in `other`.
// The two sets may have a different number of storage buckets. A bucket that
// lies past the end of other's storage holds no integers of `other`, so it is
// treated as empty rather than read out of bounds.
bool IntegerSet::is_subset_of(const IntegerSet &other) const {
  for (size_t i = 0; i < storage.size(); i++) {
    const uint64_t this_bucket = storage[i];
    const uint64_t other_bucket = (i < other.storage.size()) ? other.storage[i] : 0;
    if (this_bucket & ~other_bucket) {
      // There is at least one integer in this set that isn't in the other set
      return false;
    }
  }
  return true;
}
// Bucket-wise equality: true when both sets have the same number of storage
// buckets and every bucket matches.
// Sets with different bucket counts compare unequal. This avoids indexing
// past the end of other's storage. It also keeps == consistent with hash(),
// which folds in every bucket, so equal sets always hash equal.
bool IntegerSet::operator==(const IntegerSet &other) const {
  if (storage.size() != other.storage.size()) {
    return false;
  }
  for (size_t i = 0; i < storage.size(); i++) {
    if (storage[i] != other.storage[i]) {
      return false;
    }
  }
  return true;
}
// Order-dependent hash of the raw storage words. It uses the classic
// boost::hash_combine mixing step with the 32-bit golden-ratio constant
// 0x9e3779b9.
uint64_t IntegerSet::hash() const {
  uint64_t seed = 0;
  for (const uint64_t word : storage) {
    seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  return seed;
}