-
Notifications
You must be signed in to change notification settings - Fork 0
/
rhash.cc
102 lines (90 loc) · 2.04 KB
/
rhash.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <cmath>
#include <random>
#include <string>
#include <vector>
/// Polynomial rolling hash over std::string windows, modulo a random prime.
///
/// The hash of the window s[start, start + m) is
///     H = sum_{i=0}^{m-1} (unsigned char)s[start + i] * x^(m-1-i)  (mod q)
/// where q is a prime modulus and x a base, both drawn at random when the
/// object is constructed. begin() hashes an initial window; next() slides
/// that window one character to the right in O(1).
///
/// q <= 1e9 < 2^30, so every intermediate product is below 2^60 and is
/// computed in long long (safe even where long is only 32 bits).
class rolling_hash {
public:
    /// Draws the modulus q as the largest prime <= a random value in
    /// [5e8, 1e9], and the base x uniformly from [2, q - 1]. A large modulus
    /// keeps the collision probability low; bases 0 and 1 are excluded
    /// because they make the hash degenerate (ignore all but the last
    /// character, or ignore character order).
    rolling_hash() {
        q = prev_prime(rand(500000000L, 1000000000L));
        x = rand(2, q - 1);
    }

    /// Contribution of character s[start + i] to the hash of the sz-character
    /// window that begins at s[start]: s[start + i] * x^(sz - 1 - i) mod q.
    /// Precondition: 0 <= i < sz and start + i < s.size().
    long char_hash(const std::string& s, long i, long sz, long start) {
        const long long c = static_cast<unsigned char>(s[start + i]);
        return static_cast<long>(c * pow_mod(x, sz - 1 - i) % q);
    }

    /// Hash of the window s[start, start + m), computed with Horner's rule
    /// (equal to the sum of char_hash over the window, modulo q).
    /// Precondition: m >= 0 and start + m <= s.size().
    long first_hash(const std::string& s, long m, long start = 0) {
        long long h = 0;
        for (long i = 0; i < m; ++i)
            h = (h * x + static_cast<unsigned char>(s[start + i])) % q;
        return static_cast<long>(h);
    }

    /// Starts rolling over txt with a window of sz characters beginning at
    /// start_pos and returns that window's hash.
    /// Precondition: start_pos >= 0 and start_pos + sz <= txt.size().
    long begin(const std::string& txt, long sz, long start_pos = 0) {
        index = start_pos;
        pat_sz = sz;
        // x^(sz-1): weight of the leftmost character, removed on each slide.
        high_pow = sz > 0 ? pow_mod(x, sz - 1) : 0;
        current_hash = first_hash(txt, sz, index);
        return current_hash;
    }

    /// Slides the window one character to the right and returns the new
    /// hash. Must be called with the same txt passed to begin(). When the
    /// shifted window would run past the end of txt, the stored hash is
    /// returned unchanged.
    long next(const std::string& txt) {
        ++index;
        const bool fits = pat_sz > 0 && index > 0 &&
            static_cast<long long>(index) + pat_sz <=
                static_cast<long long>(txt.size());
        if (fits) {
            const long long out = static_cast<unsigned char>(txt[index - 1]);
            const long long in =
                static_cast<unsigned char>(txt[index + pat_sz - 1]);
            // Remove the outgoing character; "+ q" keeps the value
            // non-negative before the final reduction.
            long long h = (current_hash - out * high_pow % q + q) % q;
            h = (h * x + in) % q;
            current_hash = static_cast<long>(h);
        }
        return current_hash;
    }

    /// Uniform random integer in [from, to] (the bounds are swapped if given
    /// in reverse order, instead of invoking undefined behavior).
    long rand(long from, long to) {
        if (to < from) {
            const long tmp = from;
            from = to;
            to = tmp;
        }
        std::uniform_int_distribution<long> dis(from, to);
        return dis(gen);
    }

    /// Returns a random prime in [2, max(2, n + 1)]. A candidate is drawn
    /// uniformly from that range and the largest prime not exceeding it is
    /// returned. Primes that follow large gaps are therefore more likely to
    /// be chosen. Uses trial division, which is intended for n up to ~1e12.
    long find_prime(long n) {
        const long hi = n < 1 ? 2 : n + 1;
        return prev_prime(rand(2, hi));
    }

    long index = 0;          ///< Start offset of the current window in txt.
    std::random_device rd;   ///< Entropy source used to seed gen.
    std::mt19937 gen{rd()};  ///< Random engine (declared after rd, so rd is ready).
    long q = 0;              ///< Prime modulus.
    long x = 0;              ///< Polynomial base, in [2, q - 1].

private:
    long pat_sz = 0;          ///< Window length set by begin().
    long current_hash = 0;    ///< Hash of the current window.
    long long high_pow = 0;   ///< x^(pat_sz - 1) mod q.

    /// (base^exp) mod q by binary exponentiation; exp <= 0 yields 1 mod q.
    long long pow_mod(long long base, long exp) const {
        long long result = 1 % q;
        base %= q;
        while (exp > 0) {
            if (exp & 1)
                result = result * base % q;
            base = base * base % q;
            exp >>= 1;
        }
        return result;
    }

    /// Deterministic primality test by trial division over odd divisors.
    static bool is_prime(long n) {
        if (n < 2)
            return false;
        if (n % 2 == 0)
            return n == 2;
        for (long d = 3; d <= n / d; d += 2)
            if (n % d == 0)
                return false;
        return true;
    }

    /// Largest prime <= n; returns 2 when n < 2.
    static long prev_prime(long n) {
        if (n < 2)
            return 2;
        while (!is_prime(n))
            --n;
        return n;
    }
};
/// Finds every occurrence of pat in txt using the Rabin-Karp algorithm.
///
/// @param txt  text to search.
/// @param pat  pattern to look for.
/// @return     starting offsets of all (possibly overlapping) occurrences,
///             in increasing order. An empty pattern matches at every
///             offset 0..txt.size(); a pattern longer than txt matches
///             nowhere.
///
/// Every hash hit is confirmed by a direct comparison, so collisions of the
/// randomized hash can cost time but never produce a false match.
std::vector<long>
rabin_karp(const std::string& txt, const std::string& pat)
{
    using size_type = std::string::size_type;
    std::vector<long> res;
    const size_type n = txt.size();
    const size_type m = pat.size();
    if (m > n)
        return res;  // also avoids the unsigned underflow of n - m below
    if (m == 0) {
        for (size_type i = 0; i <= n; ++i)
            res.push_back(static_cast<long>(i));
        return res;
    }

    rolling_hash rh;
    const long m_len = static_cast<long>(m);
    const long target_hash = rh.first_hash(pat, m_len);
    const size_type last = n - m;  // last valid window start (inclusive)
    long hash = rh.begin(txt, m_len);
    for (size_type i = 0; i <= last; ++i) {
        if (i > 0)
            hash = rh.next(txt);  // slide the window so it starts at i
        if (hash == target_hash && txt.compare(i, m, pat) == 0)
            res.push_back(static_cast<long>(i));
    }
    return res;
}