-
Notifications
You must be signed in to change notification settings - Fork 0
/
n_gram.py
58 lines (53 loc) · 1.89 KB
/
n_gram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#coding: utf-8;
from __future__ import division, print_function, unicode_literals
from future_builtins import *
def find_top_ngram(sequences, N, k=1):
    """Return the ``k`` most frequent word N-grams found in ``sequences``.

    Every sequence is split on whitespace and each window of ``N``
    consecutive terms is counted in a nested-dict trie: the first ``N - 1``
    terms of a window select the path through the trie and the last term
    keys the occurrence count stored at the leaf.  The finished trie is
    handed to ``dfs``, which extracts the best ``k`` entries.

    Args:
        sequences: iterable of strings to scan.
        N: number of terms per n-gram; must be at least 1.
        k: maximum number of n-grams to return (default 1).

    Returns:
        The list produced by ``dfs``: at most ``k`` ``(ngram, count)``
        tuples, where ``ngram`` is the terms joined by single spaces.
        An empty list is returned when ``k`` is smaller than 1.

    Raises:
        ValueError: if ``N`` is smaller than 1.
    """
    if N < 1:
        # With N == 0 the leaf index ``start + N - 1`` is -1 for the first
        # window, which silently wraps around to the last term (or raises
        # IndexError on an empty sequence) instead of failing clearly.
        raise ValueError("N must be a positive integer, got %r" % (N,))
    if k < 1:
        # Asking for no results; also keeps a negative k from being used as
        # a slice bound further down.
        return []

    tree = {}
    for seq in sequences:
        terms = seq.split()
        # Sequences shorter than N give an empty range and contribute nothing.
        for start in range(len(terms) - N + 1):
            node = tree
            # Walk (creating as needed) the path for the first N-1 terms.
            for term in terms[start:start + N - 1]:
                node = node.setdefault(term, {})
            last = terms[start + N - 1]
            node[last] = node.get(last, 0) + 1
    return dfs(tree, N, "", k)
def _by_count_then_name(entry):
    """Sort key for ``(ngram, count)``: highest count first, then by name."""
    return (-entry[1], entry[0])


def dfs(tree, height, prefix, k):
    """Collect the ``k`` highest-count n-grams stored below ``tree``.

    ``tree`` is a nested dict ``height`` levels deep.  While ``height`` is
    greater than 1 the values are sub-dicts keyed by the next term; at the
    last level the values are occurrence counts.

    Args:
        tree: nested dict as described above.
        height: number of levels remaining below (and including) ``tree``.
        prefix: space-joined terms already consumed on the way down;
            pass ``""`` at the root.
        k: maximum number of entries to return.

    Returns:
        A list of at most ``k`` ``(ngram, count)`` tuples ordered by
        descending count, with ties ordered by ascending n-gram string.
        Empty when ``k`` is smaller than 1 or ``tree`` is empty.
    """
    if k < 1:
        # A negative k used as a slice bound would drop items from the tail
        # instead of limiting the result.
        return []

    if height > 1:
        best = []
        for term, subtree in tree.items():
            path = " ".join([prefix, term]) if prefix else term
            candidates = dfs(subtree, height - 1, path, k)
            # Only the k best need to survive each merge, so the working
            # list never grows beyond 2 * k entries.
            best = sorted(best + candidates, key=_by_count_then_name)[:k]
        return best

    # Leaf level: values are counts.  An empty prefix (height 1 at the root)
    # must not produce a leading space in the n-gram string.
    entries = [
        (" ".join([prefix, term]) if prefix else term, count)
        for term, count in tree.items()
    ]
    # Use the same ordering as the merge above so ties are deterministic
    # rather than dependent on dict iteration order.
    return sorted(entries, key=_by_count_then_name)[:k]
if __name__ == "__main__":
    # Tiny demo corpus: print the five most frequent 3-grams, then 2-grams.
    sequences = [
        "this is an apple",
        "that is an apple",
        "this is a pen",
        "that is a pen",
        "this is a banana",
        "that is a banana",
        "there is an apple",
        "there are two bananas",
        "please give me a pen",
        "my favorite fruit is banana",
    ]
    for label, size in (("3-gram", 3), ("2-gram", 2)):
        print(label, find_top_ngram(sequences, size, k=5))