""" Download files from web. If Gdrive downloading stacked, try to remove the gdown cache `rm -rf ~/.cache/gdown` """
import os
import gzip
import tarfile
import zipfile

import requests
import gdown
from gensim.models import KeyedVectors
from gensim.models import fasttext

def get_embedding_interface(model_name):
    """ Return a function computing the offset vector (a - b) of a word pair, plus the embedding dimension. """
    if model_name in ['fasttext', 'fasttext_cc', 'w2v', 'glove']:
        model = get_word_embedding_model(model_name)

        def get_embedding(a, b):
            # Fall back to a scalar 0 for out-of-vocabulary words so the
            # subtraction below still broadcasts against the known vector.
            try:
                v_a = model[a]
            except KeyError:
                v_a = 0
            try:
                v_b = model[b]
            except KeyError:
                v_b = 0
            if isinstance(v_a, int) and isinstance(v_b, int):
                # both words are out of vocabulary
                return 0
            return (v_a - v_b).tolist()
    else:
        raise ValueError(f'unknown model {model_name}')
    # reached only when model_name was valid; otherwise the else raised above
    return get_embedding, model.vector_size
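
# Example usage (a minimal sketch; assumes the fastText vectors download and
# load successfully, and that both query words are in the vocabulary):
#
#   get_embedding, dim = get_embedding_interface('fasttext')
#   offset = get_embedding('king', 'queen')  # 300-dim offset vector as a list
#   assert len(offset) == dim
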
def get_word_embedding_model(model_name: str = 'fasttext'):
    """ Download (if needed) and load a word embedding model into gensim. """
    os.makedirs('./cache', exist_ok=True)
    if model_name == 'w2v':
        # Google News word2vec, 300d
        path = './cache/GoogleNews-vectors-negative300.bin'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url='https://huggingface.co/relbert/word_embedding_models/resolve/main/GoogleNews-vectors-negative300.bin',
                 cache_dir='./cache')
        model = KeyedVectors.load_word2vec_format(path, binary=True)
    elif model_name == 'fasttext_cc':
        # fastText trained on Common Crawl, with subword information
        path = './cache/crawl-300d-2M-subword.bin'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url='https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M-subword.zip',
                 cache_dir='./cache')
        model = fasttext.load_facebook_model(path)
    elif model_name == 'fasttext':
        # fastText trained on Wikipedia and news data
        path = './cache/wiki-news-300d-1M.vec'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url='https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip',
                 cache_dir='./cache')
        model = KeyedVectors.load_word2vec_format(path)
    elif model_name == 'glove':
        # GloVe (840B tokens) converted to word2vec binary format
        path = './cache/glove.840B.300d.gensim.bin'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url='https://huggingface.co/relbert/word_embedding_models/resolve/main/glove.840B.300d.gensim.bin',
                 cache_dir='./cache')
        model = KeyedVectors.load_word2vec_format(path, binary=True)
    elif model_name == 'pair2vec':
        path = './cache/pair2vec.fasttext.bin'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url='https://github.com/asahi417/AnalogyTools/releases/download/0.0.0/pair2vec.fasttext.bin.tar.gz',
                 cache_dir='./cache')
        model = KeyedVectors.load_word2vec_format(path, binary=True)
    else:
        # fall back to a released model of the same name
        path = f'./cache/{model_name}.bin'
        if not os.path.exists(path):
            print(f'downloading {model_name}')
            wget(url=f'https://github.com/asahi417/AnalogyTools/releases/download/0.0.0/{model_name}.bin.tar.gz',
                 cache_dir='./cache')
        model = KeyedVectors.load_word2vec_format(path, binary=True)
    return model
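
# Example usage (a minimal sketch; the first call downloads the vectors into
# ./cache, which can take a while for the larger models):
#
#   model = get_word_embedding_model('glove')
#   print(model.vector_size)               # 300
#   print(model.most_similar('cat')[:3])
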
def wget(url, cache_dir: str, gdrive_filename: str = None):
    """ Download a file and uncompress it into `cache_dir`. """
    path = _wget(url, cache_dir, gdrive_filename=gdrive_filename)
    if path.endswith('.tar.gz') or path.endswith('.tgz') or path.endswith('.tar'):
        mode = 'r' if path.endswith('.tar') else 'r:gz'
        with tarfile.open(path, mode) as tar:
            tar.extractall(cache_dir)
        os.remove(path)
    elif path.endswith('.gz'):
        with gzip.open(path, 'rb') as f:
            with open(path.replace('.gz', ''), 'wb') as f_write:
                f_write.write(f.read())
        os.remove(path)
    elif path.endswith('.zip'):
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(cache_dir)
        os.remove(path)
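
# Example usage (a minimal sketch): download and auto-extract a tar.gz release
# asset into the cache; the archive itself is deleted after extraction.
#
#   wget(url='https://github.com/asahi417/AnalogyTools/releases/download/0.0.0/pair2vec.fasttext.bin.tar.gz',
#        cache_dir='./cache')
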
def _wget(url: str, cache_dir, gdrive_filename: str = None):
    """ Fetch a single file from the web and return its local path. """
    os.makedirs(cache_dir, exist_ok=True)
    if url.startswith('https://drive.google.com'):
        # Google Drive URLs do not expose a filename, so one must be provided.
        assert gdrive_filename is not None, 'please provide a filename for gdrive download'
        return gdown.download(url, f'{cache_dir}/{gdrive_filename}', quiet=False)
    filename = os.path.basename(url)
    r = requests.get(url)
    r.raise_for_status()  # fail early instead of writing an error page to disk
    with open(f'{cache_dir}/{filename}', 'wb') as f:
        f.write(r.content)
    return f'{cache_dir}/{filename}'
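
# Example usage (a minimal sketch; the Drive URL below is a hypothetical
# placeholder, since Drive links carry no filename of their own and so need
# an explicit gdrive_filename):
#
#   _wget('https://drive.google.com/uc?id=<FILE_ID>', './cache',
#         gdrive_filename='data.tar.gz')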