-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
40 lines (31 loc) · 1.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import argparse
from train_pipeline import TrainTestPipeline
import logging
import time
from datetime import datetime
import os
def main(parser: argparse.Namespace) -> None:
    """Create a run directory, configure logging, pick a device and run the requested mode.

    Args:
        parser: The parsed command-line arguments (the ``argparse.Namespace``
            produced by ``parse_args``); its ``mode`` and ``model_name``
            attributes are read here.
    """
    directory = __create_dir()
    # All log records for this run go to <run dir>/info.log.
    logging.basicConfig(filename=f'{directory}/info.log', level=logging.INFO)
    logging.info('Started')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare the device *type*: a torch.device never compares equal to a
    # plain string, so the former `device == 'cuda'` was always False and
    # every run logged the "not cuda" warning even on a GPU.
    if device.type == 'cuda':
        logging.info(f'Device unit: {device}')
    else:
        logging.warning(f'Device unit is not cuda, using {device}')

    if parser.mode == 'train':
        # Presumably trains and then evaluates the chosen model, writing
        # artifacts into `directory` — TODO confirm in train_pipeline.
        pipeline = TrainTestPipeline(parser.mode, parser.model_name, directory, device)
        pipeline.train()
    else:
        # The CLI accepts 'inference', but no inference path exists here yet;
        # record that instead of exiting silently.
        logging.warning(f"Mode '{parser.mode}' is not implemented; nothing was run")
def __create_dir() -> str:
    """Create a fresh timestamped run directory under ``training/`` and return its path.

    The path is relative to the current working directory, which is printed
    so the user can see where the run output will land.

    Returns:
        The relative path of the new directory, of the form
        ``training/raretina-DD-MM-YYYY-HH-MM-SS``.

    Raises:
        FileExistsError: if a directory with the same timestamp already
            exists (e.g. two runs started within the same second).
    """
    pwd = os.getcwd()
    print(pwd)
    dt_string = "training/raretina-" + datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    # makedirs also creates the parent "training/" folder when it is missing,
    # where os.mkdir would raise FileNotFoundError. exist_ok is left False so a
    # clash with an existing run directory is still reported, not reused.
    os.makedirs(dt_string)
    return dt_string
if __name__ == "__main__":
    # Command-line interface: both options are mandatory and restricted to a
    # fixed set of values.
    cli = argparse.ArgumentParser()
    cli.add_argument('--mode', type=str, required=True, choices=['train', 'inference'])
    cli.add_argument(
        '--model_name',
        type=str,
        required=True,
        choices=['transunet', 'resnetunet', 'unet'],
    )
    # Keep the parser and the parsed Namespace under distinct names; main()
    # receives only the parsed arguments.
    cli_args = cli.parse_args()
    main(cli_args)