-
Notifications
You must be signed in to change notification settings - Fork 2
/
prepare_glove.py
79 lines (65 loc) · 2.27 KB
/
prepare_glove.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
'''
Author: Li Wei
Email: wei008@e.ntu.edu.sg
'''
import numpy as np
import time
import json
import logging
def print_time():
    """Print the current local time as a dashed banner preceded by a blank line."""
    now = time.localtime()
    stamp = time.strftime("%Y-%m-%d %X", now)
    banner = '\n----------{}----------'.format(stamp)
    print(banner)
def load_w2v(embedding_dim, embedding_path, cpt_vocab, embedding_file='glove.6B.100d.txt'):
    """Build word/id lookup tables for a vocabulary and report embedding coverage.

    Every string in ``cpt_vocab`` is split on single spaces; the distinct
    tokens form the vocabulary.  Id 0 is reserved for ``'unk'`` and the
    remaining tokens receive ids 1..N in sorted order, so the mapping is
    identical on every run.

    An embedding matrix is also assembled (zero vector for id 0, the file's
    vector for known words, a uniform random vector in [-0.1, 0.1) for the
    rest), but it is only used for the printed statistics; it is neither
    returned nor saved.

    Args:
        embedding_dim: Length of the zero vector and of the random vectors
            generated for words missing from the embedding file.
        embedding_path: Prefix prepended verbatim to ``embedding_file``
            (include the trailing path separator).
        cpt_vocab: Iterable of space-separated strings.
        embedding_file: Name of the embedding text file; defaults to the
            file name that used to be hard-coded.

    Returns:
        Tuple ``(word_idx_rev, word_idx)``: id -> word and word -> id dicts.
    """
    print('\nload embedding...')
    print(embedding_path)

    tokens = set()
    for item in cpt_vocab:
        tokens.update(item.split(' '))
    # 'unk' is reserved for id 0; leaving it in the vocabulary would give it
    # a second id in the reverse map and an orphaned row in the matrix.
    tokens.discard('unk')
    # Iteration order of a set of strings changes between interpreter runs
    # (hash randomization), so sort to keep ids reproducible.
    words = sorted(tokens)

    word_idx = {'unk': 0}
    word_idx_rev = {0: 'unk'}
    for k, c in enumerate(words, start=1):
        word_idx[c] = k
        word_idx_rev[k] = c

    w2v = {}
    # NOTE(review): encoding kept as 'latin' to preserve existing behavior;
    # if the file is UTF-8, non-ASCII words will not match -- confirm.
    with open(embedding_path + embedding_file, 'r', encoding='latin') as input_file:
        # Stream the file instead of materialising every line.
        for line_no, line in enumerate(input_file):
            parts = line.strip().split(' ')
            # Only skip a word2vec-style "<count> <dim>" header; GloVe files
            # have none, and dropping line 0 would lose the first word.
            if line_no == 0 and len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
                continue
            if len(parts) < 2:
                # Blank or malformed line: no vector components.
                continue
            # Keep only vectors that will actually be looked up.
            if parts[0] in tokens:
                w2v[parts[0]] = parts[1:]

    embedding = [list(np.zeros(embedding_dim))]
    hit = 0
    for item in words:
        if item in w2v:
            vec = list(map(float, w2v[item]))
            hit += 1
        else:
            vec = list(np.random.rand(embedding_dim) / 5. - 0.1)
        embedding.append(vec)

    print('w2v_file: {}\nall_words: {} hit_words: {}'.format(embedding_path, len(words), hit))
    print("embedding.shape: {}:".format(len(embedding)))
    print("load embedding done!\n")
    return word_idx_rev, word_idx
def tokenize_glove(word_idx, str_input):
    """Convert a space-separated string into a list of vocabulary ids.

    The input is split on single spaces.  Each piece found in ``word_idx``
    is replaced by its id; any other piece is replaced by the id stored
    under ``'unk'`` (looked up only when an unknown piece is encountered).
    """
    return [
        word_idx[piece] if piece in word_idx else word_idx['unk']
        for piece in str_input.split(' ')
    ]
def config_logger(log_path):
    """Configure the root logger for INFO-level output plus a log file.

    ``logging.basicConfig`` is called with the shared format and INFO level
    (it has no effect when the root logger already has handlers), and a
    ``FileHandler`` writing to ``log_path`` is attached to the root logger.
    Calling this again with the same path does not attach a second handler,
    so records are not written to the file more than once.

    Args:
        log_path: Path of the log file to write to.

    Returns:
        The root logger.
    """
    # Function-scope import: the module does not import ``os`` at top level.
    import os

    log_format = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'

    logger = logging.getLogger()
    logging.basicConfig(format=log_format, level=logging.INFO)

    # FileHandler stores the absolute path in ``baseFilename``; compare on
    # that to detect a handler already attached for this file.
    target = os.path.abspath(os.fspath(log_path))
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            return logger

    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(file_handler)
    return logger