-
Notifications
You must be signed in to change notification settings - Fork 2
/
vocab.py
97 lines (75 loc) · 2.89 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# coding=utf-8
from __future__ import print_function
import argparse
from collections import Counter
from itertools import chain
class VocabEntry(object):
    """Bidirectional mapping between words and integer ids.

    A fresh entry holds two reserved words, ``"<pad>"`` (id 0) and
    ``"Unknown"`` (id 1). Every further word receives the next free id in
    the order it is passed to :meth:`add`.

    Attributes:
        word2id: dict mapping each word to its id.
        id2word: dict mapping each id back to its word. Use it as
            ``entry.id2word[wid]``.
        unk_id: id returned by ``entry[word]`` when ``word`` is not in the
            vocabulary; it is the id of the reserved ``"Unknown"`` word.
    """

    def __init__(self):
        self.word2id = dict()
        self.word2id["<pad>"] = 0
        self.word2id["Unknown"] = 1
        # Derive the fallback id from the table rather than hard-coding it,
        # so out-of-vocabulary lookups always land on "Unknown" and never on
        # the id of an ordinary word.
        self.unk_id = self.word2id["Unknown"]
        self.id2word = {v: k for k, v in self.word2id.items()}

    # NOTE: there is deliberately no ``id2word`` *method*: the instance
    # attribute of the same name assigned in ``__init__`` would shadow it,
    # making it uncallable. Reverse lookups go through the dict attribute.

    def __getitem__(self, word):
        """Return the id of ``word``, or ``unk_id`` if it is not present."""
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """Return True if ``word`` has an id in this vocabulary."""
        return word in self.word2id

    def __setitem__(self, key, value):
        """Reject item assignment; new words must be added via :meth:`add`.

        Raises:
            ValueError: always.
        """
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        """Return the number of words, reserved words included."""
        return len(self.word2id)

    def __repr__(self):
        return 'Vocabulary[size=%d]' % len(self)

    def add(self, word):
        """Insert ``word`` if it is new and return its id.

        A new word is given ``len(self)`` as its id, which keeps ids
        contiguous. An existing word keeps its id.
        """
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]

    def is_unk(self, word):
        """Return True if ``word`` is not in the vocabulary."""
        return word not in self

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=0):
        """Build a vocabulary from the most frequent words of ``corpus``.

        Words are visited from most to least frequent. A word is added when
        its frequency is at least ``freq_cutoff`` and the vocabulary still
        has fewer than ``size`` words. Summary statistics are printed to
        standard output.

        Args:
            corpus: iterable of word sequences (e.g. a list of token lists).
            size: maximum total number of words in the result, counting the
                reserved words created by ``VocabEntry()``.
            freq_cutoff: minimum number of occurrences required for a word
                to be added.

        Returns:
            The populated ``VocabEntry``.
        """
        vocab_entry = VocabEntry()
        # from_iterable avoids unpacking the whole corpus into call arguments.
        word_freq = Counter(chain.from_iterable(corpus))

        non_singletons = [w for w in word_freq if word_freq[w] > 1]
        singletons = [w for w in word_freq if word_freq[w] == 1]
        print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq),
                                                                                       len(non_singletons)))
        print('singletons: %s' % singletons)

        top_k_words = sorted(word_freq, key=word_freq.get, reverse=True)[:size]
        words_not_included = []
        for word in top_k_words:
            if len(vocab_entry) >= size:
                # The vocabulary never shrinks, so nothing further can be added.
                break
            if word_freq[word] >= freq_cutoff:
                vocab_entry.add(word)
            else:
                words_not_included.append(word)

        print('word types not included: %s' % words_not_included)
        return vocab_entry
class Vocab(object):
    """Container that exposes several ``VocabEntry`` objects as attributes.

    Each keyword argument becomes an attribute of the same name, and the
    names are remembered, in the order given, in ``self.entries``.
    """

    def __init__(self, **kwargs):
        self.entries = []
        for name in kwargs:
            entry = kwargs[name]
            assert isinstance(entry, VocabEntry)
            setattr(self, name, entry)
            self.entries.append(name)

    def __repr__(self):
        # One "<name> <entry repr>words" fragment per registered entry.
        fragments = []
        for name in self.entries:
            fragments.append('%s %swords' % (name, getattr(self, name)))
        return 'Vocab(%s)' % ', '.join(fragments)
# Running this file directly as a script is not supported: the guard below
# raises NotImplementedError. Importing the module does not trigger it.
if __name__ == '__main__':
    raise NotImplementedError