utils.py
""" utility functions"""
import re
import os
from os.path import basename
import gensim
import torch
from torch import nn


def count_data(path):
    """Count the number of data files (named {i}.json) in the given path."""
    matcher = re.compile(r'[0-9]+\.json')
    names = os.listdir(path)
    n_data = len([name for name in names if matcher.match(name)])
    return n_data
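
# Usage sketch for count_data (the directory name is hypothetical; it assumes
# files named 0.json, 1.json, ... as produced by a preprocessing step):
#
#     n_train = count_data('data/train')
#     print('found {} training examples'.format(n_train))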

# ids reserved for the special tokens at the bottom of the vocabulary
PAD = 0
UNK = 1
START = 2
END = 3


def make_vocab(wc, vocab_size):
    """Build a word -> id mapping from a word-count Counter,
    reserving ids 0-3 for the special tokens."""
    word2id = {}
    word2id['<pad>'] = PAD
    word2id['<unk>'] = UNK
    word2id['<start>'] = START
    word2id['<end>'] = END
    for i, (w, _) in enumerate(wc.most_common(vocab_size), 4):
        word2id[w] = i
    return word2id
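
# Usage sketch for make_vocab (the toy Counter is hypothetical; in practice
# `wc` would be a Counter over all tokens in the training corpus):
#
#     from collections import Counter
#     wc = Counter('the cat sat on the mat'.split())
#     word2id = make_vocab(wc, vocab_size=30000)
#     assert word2id['<pad>'] == PAD and word2id['the'] == 4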


def make_embedding(id2word, w2v_file, initializer=None):
    """Build an nn.Embedding weight from a trained gensim word2vec model;
    return the weight and the ids of out-of-vocabulary words."""
    attrs = basename(w2v_file).split('.')  # word2vec.{dim}d.{vsize}k.bin
    w2v = gensim.models.Word2Vec.load(w2v_file).wv
    vocab_size = len(id2word)
    emb_dim = int(attrs[-3][:-1])  # e.g. '128d' -> 128
    embedding = nn.Embedding(vocab_size, emb_dim).weight
    if initializer is not None:
        initializer(embedding)
    oovs = []
    with torch.no_grad():
        for i in range(len(id2word)):
            # NOTE: id2word can be a list or a dict
            if i == START:
                embedding[i, :] = torch.Tensor(w2v['<s>'])
            elif i == END:
                # r'<\s>' must match the end-of-sentence token the
                # word2vec model was trained with (note: not '</s>')
                embedding[i, :] = torch.Tensor(w2v[r'<\s>'])
            elif id2word[i] in w2v:
                embedding[i, :] = torch.Tensor(w2v[id2word[i]])
            else:
                oovs.append(i)
    return embedding, oovs
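
# Usage sketch for make_embedding (the file name is hypothetical but follows
# the word2vec.{dim}d.{vsize}k.bin pattern the function expects; `word2id`
# is the mapping returned by make_vocab above):
#
#     id2word = {i: w for w, i in word2id.items()}
#     embedding, oovs = make_embedding(id2word, 'word2vec.128d.226k.bin')
#     # embedding: nn.Embedding weight of shape (len(id2word), 128)
#     # oovs: ids whose rows were left to the (optional) initializer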