run.py
"""
Main program
May the Force be with you.
2019.3
"""
from torch.utils.data import DataLoader
from dataset import get_dataset
from logger import get_logger
from core.models import get_model
from core.trainer import Trainer
from config import get_cfg
# prepare configuration
cfg = get_cfg()
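# Fields this script reads from cfg (per the usages below): DATASET, MODES,
# BATCHSIZE, DATALOADER_WORKERS, MODEL, LOGGER, EPOCH_TOTAL, RESUME,
# RESUME_EPOCH_ID.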
# prepare dataset
DatasetClass = get_dataset(cfg.DATASET)
dataloader_dict = dict()
for mode in cfg.MODES:
    phase_dataset = DatasetClass(cfg, mode=mode)
    dataloader_dict[mode] = DataLoader(phase_dataset, batch_size=cfg.BATCHSIZE,
                                       shuffle=(mode == 'train'),
                                       num_workers=cfg.DATALOADER_WORKERS,
                                       pin_memory=True, drop_last=True)
# prepare models
ModelClass = get_model(cfg.MODEL)
model = ModelClass(cfg)
# prepare logger
LoggerClass = get_logger(cfg.LOGGER)
logger = LoggerClass(cfg)
# register dataset, models, logger to trainer
trainer = Trainer(cfg, model, dataloader_dict, logger)
# start training
# when resuming, shift the target so the remaining epochs still run in full
epoch_total = cfg.EPOCH_TOTAL + (cfg.RESUME_EPOCH_ID if cfg.RESUME else 0)
while trainer.do_epoch() <= epoch_total:
    pass
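
# The loop above assumes Trainer.do_epoch() runs one full epoch and returns
# the index of the epoch just completed. A minimal sketch of that assumed
# contract (hypothetical; the real implementation lives in core.trainer):
#
#     class Trainer:
#         def __init__(self, cfg, model, dataloader_dict, logger):
#             # resume from the saved epoch counter, or start fresh at 0
#             self.epoch_id = cfg.RESUME_EPOCH_ID if cfg.RESUME else 0
#
#         def do_epoch(self):
#             self.epoch_id += 1
#             ...  # run each phase in dataloader_dict, log via logger
#             return self.epoch_id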