-
Notifications
You must be signed in to change notification settings - Fork 13
/
sato.py
38 lines (30 loc) · 1.18 KB
/
sato.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import json
import sys
from pathlib import Path
from allennlp.commands import main
# Make the project-local packages importable. These entries are relative
# paths, which Python resolves against the current working directory at
# import time, so the script must be launched from the directory that
# contains scripts/ and table_embedder/.
sys.path.extend(['./scripts', './table_embedder/dataset_readers', './table_embedder/models'])
from scripts.util import Util
def cmd_builder(setting, template_path, overrides):
    """Prepare the environment and ``sys.argv`` for an ``allennlp train`` run.

    Every entry of *setting* except ``cuda_devices`` is exported as an
    environment variable (key -> value, converted with ``str``), and
    ``ALLENNLP_DEBUG`` is set to ``"TRUE"``. ``sys.argv`` is then replaced
    with the argument vector for ``allennlp train``, ready for the subsequent
    ``allennlp.commands.main()`` call.

    Args:
        setting: Flat mapping of experiment settings. Must contain ``out_dir``
            (passed as the serialization directory ``-s``). ``cuda_devices``,
            if present, is not exported; device selection is expected to go
            through *overrides* instead. *setting* itself is not modified.
        template_path: Path to the jsonnet training config template.
        overrides: Config overrides passed via ``-o``; either a JSON string or
            a JSON-serializable object (e.g. a dict), which is serialized with
            ``json.dumps`` because every ``sys.argv`` element must be a str.

    Raises:
        KeyError: if *setting* has no ``out_dir`` entry.
    """
    for name, val in setting.items():
        if name == 'cuda_devices':
            continue
        # os.environ only accepts str values; YAML-loaded settings may be
        # ints/floats/bools, which would otherwise raise TypeError.
        os.environ[name] = str(val)
    os.environ["ALLENNLP_DEBUG"] = "TRUE"

    if not isinstance(overrides, str):
        overrides = json.dumps(overrides)

    # Assemble the command into sys.argv
    sys.argv = [
        "allennlp",  # command name, not used by main
        "train",
        str(template_path),
        "-s", str(setting['out_dir']),
        "--include-package", "table_embedder",
        "-o", overrides,
    ]
if __name__ == "__main__":
    params = Util.load_yaml('exp/sato/sato.yml')

    visible = os.getenv('CUDA_VISIBLE_DEVICES')
    # Skip an empty value (e.g. CUDA_VISIBLE_DEVICES=""), which would make
    # int('') raise ValueError.
    if visible:
        # NOTE(review): this is stored at the top level of params, but only
        # params['train']['cuda_devices'] is read below, so the value is
        # currently unused. CUDA also renumbers visible devices from 0, so
        # these raw ids may not be valid device indices — confirm intent
        # before wiring this into the training settings.
        params['cuda_devices'] = [int(elem) for elem in visible.split(',') if elem.strip()]

    # Copy every 'common' setting into the training settings, overwriting
    # same-named keys.
    params['train'].update(params['common'])

    train_devices = params['train']['cuda_devices']
    # With more than one device, request distributed training through a
    # config override; otherwise pass an empty override. Either way it must
    # be a JSON string, since it becomes an element of sys.argv.
    if len(train_devices) > 1:
        overrides = json.dumps({"distributed": {"cuda_devices": train_devices}})
    else:
        overrides = json.dumps({})

    cmd_builder(params['train'], "exp/sato/sato.jsonnet", overrides)
    main()