-
Notifications
You must be signed in to change notification settings - Fork 13
/
pretrain_pred.py
51 lines (40 loc) · 1.47 KB
/
pretrain_pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import sys
import json
from pathlib import Path
from allennlp.commands import main
sys.path += ['./scripts', './table_embedder/dataset_readers', './table_embedder/models']
from scripts.util import Util
from scripts.pretrain import Pretrain
def cmd_builder(params, overrides):
    """Populate ``sys.argv`` with an ``allennlp predict`` command line.

    The argument vector is written to ``sys.argv`` so that a subsequent call
    to allennlp's ``main()`` parses it.

    Args:
        params: config mapping; reads 'model_path', 'pred_path', 'out_dir',
            'out_pred_name' and 'batch_size'.
        overrides: JSON string forwarded to allennlp via ``-o``.

    Returns:
        The argument list that was assigned to ``sys.argv``.
    """
    argv = [
        "allennlp",  # program name (argv[0]); not parsed as an argument
        "predict",
        params['model_path'],
        params['pred_path'],
        "--output-file", str(Path(params['out_dir']) / params['out_pred_name']),
        "--include-package", "table_embedder",
        "--predictor", "predictor",
        # argv entries must all be str; an int here only parsed by accident
        # (0 is falsy) and any other int would raise TypeError in argparse.
        # The index is relative to CUDA_VISIBLE_DEVICES when that is set.
        "--cuda-device", "0",
        "--batch-size", str(params['batch_size']),
        "-o", overrides,
    ]
    sys.argv = argv
    return argv
def setup_env_variables(params):
    """Export config entries as process environment variables.

    Sets ``ALLENNLP_DEBUG=TRUE``, then prints each ``name, val`` pair of
    *params* and exports it as ``os.environ[name]``. Container values
    (list/tuple/dict) are skipped because they have no single string form.
    Other non-string values are converted with ``str()``, because
    ``os.environ`` only accepts strings.

    NOTE(review): presumably these variables are consumed by the allennlp
    (jsonnet) experiment config; confirm against the config files.

    Args:
        params: flat config mapping whose keys are environment variable names.
    """
    os.environ["ALLENNLP_DEBUG"] = "TRUE"
    for name, val in params.items():
        print(name, val)
        # The original checked isinstance(name, list); keys are names, so the
        # check belongs on the value.
        if isinstance(val, (list, tuple, dict)):
            continue
        # NOTE(review): bools become "True"/"False" — confirm consumers accept that.
        os.environ[name] = val if isinstance(val, str) else str(val)
if __name__ == "__main__":
    # initialize
    # The config path may be given as the first CLI argument; otherwise the
    # original hard-coded experiment config is used. It must be read before
    # cmd_builder() overwrites sys.argv for allennlp.
    # Alternative config used previously: './exp/pretrain/26600K_freq.yml'
    config_path = sys.argv[1] if len(sys.argv) > 1 else './exp/pretrain/26600K_mix.yml'
    params = Util.load_yaml(config_path)

    # Restrict the visible GPUs. 'cuda_devices' is removed from params so it is
    # not handed to setup_env_variables() along with the other entries.
    cuda_devices = params.pop('cuda_devices', None)
    if cuda_devices is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, cuda_devices))
    setup_env_variables(params)

    # Same prediction file that cmd_builder() passes as --output-file.
    out_dir = Path(params['out_dir'])
    pred_file = out_dir / params['out_pred_name']

    # predict
    overrides = json.dumps({'dataset_reader': {'type': 'preprocess'}, 'trainer': {'opt_level': 'O0'}})
    cmd_builder(params, overrides)
    main()

    # dump pred
    Pretrain.dump_pred(pred_file, out_dir)