-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcosine_annealing.py
29 lines (24 loc) · 1.07 KB
/
cosine_annealing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import math
from keras.callbacks import Callback
from keras import backend as K
class CosineAnnealingScheduler(Callback):
    """Cosine annealing learning-rate scheduler.

    At the beginning of every epoch the optimizer's learning rate is set to

        eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2

    with ``t = epoch - init_epoch``. The rate therefore equals ``eta_max``
    when ``epoch == init_epoch`` and ``eta_min`` when ``t == T_max``. The
    formula is not clamped, so for ``t > T_max`` the rate keeps following
    the cosine curve (period ``2 * T_max``).

    At the end of every epoch the optimizer's current learning rate is
    stored in the epoch logs under the key ``'lr'``.

    # Arguments
        init_epoch: epoch index at which the cosine curve starts
            (subtracted from the current epoch index).
        T_max: number of epochs taken to go from `eta_max` to `eta_min`.
            Must be non-zero.
        eta_max: learning rate at the start of the curve.
        eta_min: learning rate reached after `T_max` epochs.
        verbose: if greater than 0, print the learning rate each epoch.

    # Raises
        ValueError: if `T_max` is zero, or (when an epoch begins) if the
            model's optimizer has no `lr` attribute.
    """

    def __init__(self, init_epoch, T_max, eta_max, eta_min=0, verbose=0):
        super(CosineAnnealingScheduler, self).__init__()
        # Fail early with a clear message instead of a ZeroDivisionError
        # raised from the schedule formula in the middle of training.
        if T_max == 0:
            raise ValueError('T_max must be non-zero.')
        self.init_epoch = init_epoch
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        """Compute the annealed learning rate and assign it to the optimizer."""
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Position on the cosine curve, in units of T_max epochs.
        progress = (epoch - self.init_epoch) / self.T_max
        lr = (self.eta_min
              + (self.eta_max - self.eta_min)
              * (1 + math.cos(math.pi * progress)) / 2)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nEpoch %05d: CosineAnnealingScheduler setting learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        """Store the optimizer's current learning rate in `logs['lr']`."""
        # Only substitute a new dict when no dict was supplied at all:
        # `logs or {}` would also replace an *empty* dict passed by the
        # caller, so the 'lr' entry would be written to a throwaway object.
        if logs is None:
            logs = {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)