-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
74 lines (57 loc) · 2.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from data_loader.covid_data_loader import *
from models.covid import *
from trainers.covid_trainer import *
from testers.covid_tester import *
from utils.config import process_config
from utils.utils import get_args
from utils.gpus import set_gpus
import tensorflow as tf
import os
def train(config):
    """Build the data pipeline and model, train it, then run the tester.

    Args:
        config: Processed configuration object (as returned by
            ``process_config``). It is handed unchanged to the data loader,
            the model, the trainer and the tester.
    """
    # --- data -------------------------------------------------------------
    print('Create the data generator.')
    loader = COVIDDataLoader(config)
    # get_train_data() must yield exactly two items (presumably the training
    # and validation splits); unpacking enforces that.
    train_split, val_split = loader.get_train_data()
    test_split = loader.get_test_data()

    # --- model ------------------------------------------------------------
    print('Create the model.')
    covid_model = COVID_Model(config)

    # --- training ---------------------------------------------------------
    print('Create the trainer.')
    model_trainer = COVIDModelTrainer(
        covid_model.model,
        (train_split, val_split),
        config,
    )
    print('Start training the model.')
    model_trainer.train()

    # --- testing ----------------------------------------------------------
    # The same underlying ``.model`` that was just trained is evaluated.
    print('Create the tester.')
    model_tester = COVIDModelTester(covid_model.model, test_split, config)
    print('Test the model.')
    model_tester.test()
def evaluate(config):
    """Restore a trained model from a checkpoint and run the tester on it.

    Args:
        config: Processed configuration object. Besides being passed to the
            data loader, model and tester, its ``tester.checkpoint_path``
            attribute is used as the checkpoint to load weights from.
    """
    print('Create the data generator.')
    loader = COVIDDataLoader(config)
    test_split = loader.get_test_data()

    print('Create the model.')
    covid_model = COVID_Model(config)

    # Announce first, then load, so the message appears even if loading fails.
    print("Loading checkpoint's weights")
    covid_model.load(config.tester.checkpoint_path)

    print('Create the tester.')
    model_tester = COVIDModelTester(covid_model.model, test_split, config)
    print('Test the model.')
    model_tester.test()
def main():
    """Parse the run arguments, select GPUs and dispatch on ``config.mode``.

    The configuration file path is taken from the command-line arguments
    (via ``get_args``) and processed with ``process_config``. The value of
    ``config.mode`` then selects the pipeline: ``"train"`` runs ``train`` and
    ``"eval"`` runs ``evaluate``.

    Exits the process with a non-zero status if the arguments or the
    configuration cannot be processed, or if ``config.mode`` is not one of
    the supported values.
    """
    import sys

    # Capture the config path from the run arguments, then process the JSON
    # configuration file.
    try:
        args = get_args()
        config = process_config(args.config, dirs=True, config_copy=True)
    except Exception as e:
        print(e)
        print("missing or invalid arguments")
        # Previously exit(0): a failed run must not report success to the
        # shell/scheduler. sys.exit is used because the bare ``exit`` builtin
        # is only injected by the ``site`` module and may be absent.
        sys.exit(1)

    # NOTE: utils.gpus.set_gpus(config) was used here before; it is currently
    # disabled in favour of setting CUDA_VISIBLE_DEVICES directly.
    # os.environ only accepts str values, so a numeric GPU id from the config
    # file would otherwise raise TypeError. Presumably this has to be set
    # before TensorFlow initialises CUDA (the device queries below do that).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.devices.gpu.id)
    print('Physical GPU devices: {}'.format(len(tf.config.experimental.list_physical_devices('GPU'))))
    print('Logical GPU devices: {}'.format(len(tf.config.experimental.list_logical_devices('GPU'))))

    if config.mode == "train":
        train(config)
    elif config.mode == "eval":
        evaluate(config)
    else:
        # Previously an unknown mode silently did nothing; fail loudly instead.
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(
            "Unknown mode {!r}; expected 'train' or 'eval'.".format(config.mode)
        )
if __name__ == "__main__":
    # Script entry point: run the pipeline only when this file is executed
    # directly, not when it is imported as a module.
    main()