-
Notifications
You must be signed in to change notification settings - Fork 2
/
earlystopping.py
executable file
·109 lines (96 loc) · 4.1 KB
/
earlystopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from os import path, makedirs, walk ,remove, scandir, unlink
from numpy import inf
from torch import save as t_save
from lib.utils import sort_human, BOLD, CLR
class EarlyStopping:
    """Checkpoint a model during training and flag when training should stop.

    Calling the instance with a validation loss saves checkpoints below
    ``{log_path}/save_model/{exp_tag}`` and, for end-of-epoch calls
    (``step == 0``), counts consecutive calls that did not improve on the
    best loss seen so far.  Once that count reaches ``patience`` the
    ``early_stop`` attribute is set to True.
    """

    def __init__(self, log_path, patience=7, model=None, verbose=False, exp_tag=""):
        """Create the checkpoint directories and reset the stopping state.

        Args:
            log_path (str): Root directory; checkpoints are written to
                ``{log_path}/save_model/{exp_tag}`` and its ``best/``
                sub-directory.
            patience (int): How many consecutive non-improving calls
                (with ``step == 0``) to tolerate before setting
                ``early_stop``.  Default: 7
            model: Optional object exposing ``encoder_params``,
                ``decoder_params``, ``n_frames_input`` and
                ``n_frames_output``.  These are stored as a tuple under the
                ``'meta'`` key and merged into every checkpoint passed to
                ``__call__``.  Default: None (no meta information).
            verbose (bool): If True, prints a message whenever a new
                record validation loss is saved.  Default: False
            exp_tag (str): Name of the experiment sub-directory.
                Default: ""
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = inf
        self.global_min_loss = inf
        self.save_path = f"{log_path}/save_model/{exp_tag}"
        # Creating "best/" also creates save_path itself; exist_ok avoids the
        # check-then-create race of a separate isdir() test.
        makedirs(f"{self.save_path}/best/", exist_ok=True)
        if model is not None:
            self.meta_info = {
                'meta': (
                    model.encoder_params,
                    model.decoder_params,
                    model.n_frames_input,
                    model.n_frames_output,
                )
            }
        else:
            self.meta_info = {}

    def __str__(self):
        """Return every instance attribute as ``name=value``, one per line."""
        return '\n'.join(f"{k}={v}" for k, v in vars(self).items())

    def __call__(self, val_loss, model, epoch, step=0):
        """Record a validation result and save the matching checkpoints.

        Args:
            val_loss: Validation loss of the current evaluation.
            model: Object to checkpoint.  It must provide ``update()``
                because the stored meta information is merged into it
                in place (presumably a checkpoint dict - confirm against
                the caller).
            epoch: Epoch number, used in messages and file names.
            step: Intra-epoch step.  A non-zero value only writes an
                intermediate checkpoint and leaves the early-stopping
                state untouched.  Default: 0
        """
        model.update(self.meta_info)
        if step != 0:
            self.save_checkpoint(val_loss, model, epoch, step)
        else:
            score = -val_loss
            # NaN compares False against everything, so without this guard a
            # NaN loss would be taken for an improvement and would poison
            # best_score, disabling early stopping for the rest of the run.
            is_nan = bool(score != score)
            if is_nan or (self.best_score is not None and score < self.best_score):
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
                    print(f"{BOLD}[*] early stopping at epoch {epoch} !{CLR}")
                else:
                    print(f"[*] early stopping counter: {BOLD}{self.counter}/{self.patience}{CLR}")
            else:
                # First valid result, or a loss at least as good as the best.
                self.best_score = score
                self.save_checkpoint(val_loss, model, epoch, step)
                self.counter = 0
        # Always keep a copy of the most recent state, improved or not.
        t_save(model, f"{self.save_path}/LAST_checkpoint_{epoch}_{step}_{val_loss:.6f}.pth.tar")

    def del_old_models(self, keep=10):
        """Prune checkpoint files once more than ``keep`` have accumulated.

        Only files directly inside ``save_path`` are considered (the
        ``best/`` sub-directory is not touched).  When there are more than
        ``keep`` of them, the first ``keep // 2`` files in ``sort_human``
        order are removed.

        Args:
            keep (int): File-count threshold that triggers pruning.
                Default: 10
        """
        # The default keeps a missing directory from raising StopIteration.
        _, _, files = next(walk(self.save_path), (None, None, []))
        if len(files) > keep:
            for old_model in sort_human(files)[:keep // 2]:
                remove(path.join(self.save_path, old_model))

    def save_checkpoint(self, val_loss, model, epoch, step=0):
        """Write ``model`` to disk under a name derived from the arguments.

        With a non-zero ``step`` an intermediate ``IEcheckpoint_*`` file is
        written after pruning old files.  Otherwise the checkpoint goes to
        ``best/`` (replacing whatever was there) when ``val_loss`` beats the
        lowest loss saved so far, and to ``save_path`` itself when it does
        not.

        Args:
            val_loss: Validation loss, embedded in the file name.
            model: Object handed to ``torch.save``.
            epoch: Epoch number, embedded in the file name.
            step: Intra-epoch step, embedded in the file name.  Default: 0
        """
        if step != 0:
            save_flag = "IE"
            print(f"[$] saveing model at step: {step} in epoch {epoch}")
            self.del_old_models()
            t_save(model, f"{self.save_path}/{save_flag}checkpoint_{epoch}_{step}_{val_loss}.pth.tar")
        else:
            if val_loss < self.global_min_loss:
                if self.verbose:
                    print(f"[*] validation loss record {BOLD}{val_loss}{CLR} in epoch: {BOLD}{epoch}{CLR}@{step}")
                self.global_min_loss = val_loss
                save_flag = "best/"
                # Only one best checkpoint is kept: clear the directory first.
                # The with-block closes the scandir iterator deterministically.
                with scandir(f"{self.save_path}/{save_flag}") as entries:
                    for entry in entries:
                        unlink(entry.path)
            else:
                save_flag = ""
            t_save(model, f"{self.save_path}/{save_flag}checkpoint_{epoch}_{step}_{val_loss}.pth.tar")
        self.val_loss_min = val_loss