-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
42 lines (36 loc) · 1.18 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from utils import get_latest_ckpt, get_min_loss_ckpt
class CelebAModel(Trainer):
    """Placeholder subclass of ``pytorch_lightning.Trainer``; adds no behavior yet.

    NOTE(review): despite the "Model" in its name, this derives from
    ``Trainer`` rather than a model class, and ``get_trainer`` in this module
    builds a plain ``Trainer`` instead of this class. TODO: implement the
    intended customisation or remove the class.
    """
def get_trainer(gpus, path, config, resume_mode='latest', debug=False):
    """Build a PyTorch Lightning ``Trainer`` with checkpointing and early stopping.

    Args:
        gpus: Forwarded unchanged as ``Trainer(gpus=...)``.
        path: Directory that holds checkpoint files. It is also used as
            ``default_root_dir`` and searched by the resume helper.
        config: Mapping of extra keyword arguments forwarded to ``Trainer``.
            ``None`` is treated as an empty mapping. Its keys must not repeat
            the arguments set explicitly below (``gpus``,
            ``default_root_dir``, ``fast_dev_run``, ...); otherwise Python
            raises ``TypeError`` for a duplicate keyword argument.
        resume_mode: ``'latest'`` resumes from ``get_latest_ckpt(path)``.
            ``'min_loss'`` resumes from ``get_min_loss_ckpt(path)``.
        debug: Forwarded as ``fast_dev_run``.

    Returns:
        A configured ``pytorch_lightning.Trainer``.

    Raises:
        ValueError: If ``resume_mode`` is neither ``'latest'`` nor
            ``'min_loss'``.
    """
    # Validate the resume mode before building any callbacks so that a bad
    # value fails immediately.
    if resume_mode == 'latest':
        resume_func = get_latest_ckpt
    elif resume_mode == 'min_loss':
        resume_func = get_min_loss_ckpt
    else:
        raise ValueError("%s not supported" % resume_mode)

    # NOTE(review): ``filepath``/``prefix`` on ModelCheckpoint and the
    # ``checkpoint_callback=<callback>`` / ``early_stop_callback`` Trainer
    # arguments belong to an older (pre-1.0) Lightning API. Confirm the
    # pinned pytorch_lightning version before upgrading.
    checkpoint_callback = ModelCheckpoint(
        # Lightning fills in {epoch} and {val_loss} when it saves.
        filepath=os.path.join(path, "{epoch}-{val_loss:.2f}"),
        # save_top_k is an integer count. The original passed True, which
        # equals 1, so 1 is spelled out explicitly here.
        save_top_k=1,
        save_last=True,
        verbose=True,
        # NOTE(review): the best checkpoint is selected by val_acc, but the
        # file name records val_loss and early stopping watches val_loss.
        # Confirm that this mix is intended, especially for the 'min_loss'
        # resume mode.
        monitor='val_acc',
        mode='max',
        prefix='')
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=3,
        # strict=False: do not fail if val_loss is missing from the logged
        # metrics.
        strict=False,
        verbose=False,
        mode='min')

    trainer_kwargs = dict(config) if config is not None else {}
    return Trainer(checkpoint_callback=checkpoint_callback,
                   early_stop_callback=early_stop,
                   gpus=gpus,
                   # Presumably None when no checkpoint exists yet, which
                   # means a fresh start -- confirm against utils.
                   resume_from_checkpoint=resume_func(path),
                   default_root_dir=path,
                   fast_dev_run=debug,
                   terminate_on_nan=True,
                   **trainer_kwargs)