-
Notifications
You must be signed in to change notification settings - Fork 11
/
run_training.py
55 lines (39 loc) · 1.76 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import sys
import argparse
import importlib
import multiprocessing
import cv2 as cv
import torch.backends.cudnn
# Make this script's own directory importable so the project-local packages
# imported below and at runtime (``admin``, ``train_settings``) resolve.
# ``abspath`` is applied first: ``os.path.dirname(__file__)`` can be '' or a
# relative path (e.g. on Python < 3.9, or when imported via a relative
# sys.path entry), and such an entry would silently follow the current working
# directory instead of pointing at this file's folder.
env_path = os.path.dirname(os.path.abspath(__file__))
if env_path not in sys.path:
    sys.path.append(env_path)
import admin.settings as ws_settings
def run_training(train_module, train_name, cudnn_benchmark=True):
    """Import one training configuration and hand it a fresh Settings object.

    The configuration is the module ``train_settings.<train_module>.<train_name>``;
    its ``run`` attribute is called with an ``admin.settings.Settings`` instance
    whose ``module_name``, ``script_name`` and ``project_path`` are filled in.

    args:
        train_module: Name of the sub-package inside the "train_settings/" folder.
        train_name: Name of the settings module inside that sub-package.
        cudnn_benchmark: Value written to ``torch.backends.cudnn.benchmark``
            (default is True).
    """
    # Per the original author, disabling OpenCV's own thread pool avoids
    # otherwise unexplained crashes.
    cv.setNumThreads(0)
    torch.backends.cudnn.benchmark = cudnn_benchmark

    print('Training: {} {}'.format(train_module, train_name))

    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = '{}/{}'.format(train_module, train_name)

    dotted_name = 'train_settings.{}.{}'.format(train_module, train_name)
    config_module = importlib.import_module(dotted_name)
    config_module.run(settings)
def _str2bool(value):
    """Convert a command-line string into a bool for ``argparse``.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    '0' and 'False') to True, so an explicit converter is required.

    Args:
        value: Raw argument text such as '1', '0', 'true' or 'off'.

    Returns:
        The parsed boolean. The empty string maps to False, which is what
        ``bool('')`` produced before, so existing invocations keep working.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (1/0, true/false), got {!r}'.format(value))


def main():
    """Parse the command line and start the requested training run.

    The two positional arguments select the settings module
    ``train_settings.<train_module>.<train_name>``. ``--cudnn_benchmark``
    accepts 1/0 (also true/false, yes/no, on/off) and defaults to on.
    """
    parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
    parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.')
    parser.add_argument('train_name', type=str, help='Name of the train settings file.')
    # Not type=bool: that would make '--cudnn_benchmark 0' evaluate to True,
    # so benchmarking could never be disabled from the command line.
    parser.add_argument('--cudnn_benchmark', type=_str2bool, default=True,
                        help='Set cudnn benchmark on (1) or off (0) (default is on).')
    args = parser.parse_args()
    run_training(args.train_module, args.train_name, args.cudnn_benchmark)
if __name__ == '__main__':
    # Select the 'spawn' start method for processes later created through
    # multiprocessing; force=True replaces any start method already configured.
    multiprocessing.set_start_method(method='spawn', force=True)
    main()