from torch.optim import Optimizer


class OneCycleLR:
    """Sets the learning rate of each parameter group according to the one-cycle
    learning rate policy proposed in https://arxiv.org/pdf/1708.07120.pdf.

    It is recommended that you set max_lr to the learning rate that achieves
    the lowest loss in the learning rate range test, and set min_lr to 1/10th of max_lr.

    The learning rate changes as min_lr -> max_lr -> min_lr -> final_lr,
    where final_lr = min_lr * reduce_factor.

    Note: Currently only supports one parameter group.

    Args:
        optimizer: (Optimizer) the optimizer against which this scheduler is applied
        num_steps: (int) total number of steps/iterations
        lr_range: (tuple) min and max values of the learning rate
        momentum_range: (tuple) min and max values of momentum
        annihilation_frac: (float) fraction of steps spent annihilating the learning rate
        reduce_factor: (float) factor by which the learning rate is reduced during annihilation
        last_step: (int) the last completed step. Set to -1 to start training from the beginning

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycleLR(optimizer, num_steps=num_steps, lr_range=(0.1, 1.))
        >>> for epoch in range(epochs):
        >>>     for step in train_dataloader:
        >>>         train(...)
        >>>         scheduler.step()

    Useful resources:
        https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6
        https://medium.com/vitalify-asia/whats-up-with-deep-learning-optimizers-since-adam-5c1d862b9db0
    """
    def __init__(self,
                 optimizer: Optimizer,
                 num_steps: int,
                 lr_range: tuple = (0.1, 1.),
                 momentum_range: tuple = (0.85, 0.95),
                 annihilation_frac: float = 0.1,
                 reduce_factor: float = 0.01,
                 last_step: int = -1):
        # Sanity check
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.num_steps = num_steps

        self.min_lr, self.max_lr = lr_range[0], lr_range[1]
        assert self.min_lr < self.max_lr, \
            "Argument lr_range must be (min_lr, max_lr), where min_lr < max_lr"

        self.min_momentum, self.max_momentum = momentum_range[0], momentum_range[1]
        assert self.min_momentum < self.max_momentum, \
            "Argument momentum_range must be (min_momentum, max_momentum), where min_momentum < max_momentum"

        self.num_cycle_steps = int(num_steps * (1. - annihilation_frac))  # Total number of steps in the cycle
        self.final_lr = self.min_lr * reduce_factor

        self.last_step = last_step
        if self.last_step == -1:
            # Apply the initial learning rate and momentum before the first training step
            self.step()
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer. (Borrowed from the _LRScheduler class in torch.optim.lr_scheduler.)
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state. (Borrowed from the _LRScheduler class in torch.optim.lr_scheduler.)

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)
    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

    def get_momentum(self):
        return self.optimizer.param_groups[0]['momentum']
    def step(self):
        """Conducts one step of learning rate and momentum update."""
        current_step = self.last_step + 1
        self.last_step = current_step

        if current_step <= self.num_cycle_steps // 2:
            # Scale-up phase: lr rises from min_lr to max_lr while momentum falls
            scale = current_step / (self.num_cycle_steps // 2)
            lr = self.min_lr + (self.max_lr - self.min_lr) * scale
            momentum = self.max_momentum - (self.max_momentum - self.min_momentum) * scale
        elif current_step <= self.num_cycle_steps:
            # Scale-down phase: lr falls back to min_lr while momentum rises
            scale = (current_step - self.num_cycle_steps // 2) / (self.num_cycle_steps - self.num_cycle_steps // 2)
            lr = self.max_lr - (self.max_lr - self.min_lr) * scale
            momentum = self.min_momentum + (self.max_momentum - self.min_momentum) * scale
        elif current_step <= self.num_steps:
            # Annihilation phase: only the learning rate changes, decaying from min_lr to final_lr
            scale = (current_step - self.num_cycle_steps) / (self.num_steps - self.num_cycle_steps)
            lr = self.min_lr - (self.min_lr - self.final_lr) * scale
            momentum = None
        else:
            # Exceeded the given num_steps: do nothing
            return

        self.optimizer.param_groups[0]['lr'] = lr
        # Explicit None check: a bare `if momentum:` would skip a legitimate momentum of 0.0
        if momentum is not None:
            self.optimizer.param_groups[0]['momentum'] = momentum
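

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal demonstration of the schedule's shape, assuming a hypothetical tiny
# linear model. The model, num_steps value, and printout below are assumptions
# for illustration; any optimizer whose param group carries a 'momentum' entry
# (e.g. SGD with momentum) would work the same way.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(10, 2)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    num_steps = 100
    scheduler = OneCycleLR(optimizer, num_steps=num_steps, lr_range=(0.1, 1.))

    for step in range(1, num_steps + 1):
        # train(...) would go here; we only inspect the schedule itself
        scheduler.step()
        if step % 25 == 0:
            print(step, scheduler.get_lr(), scheduler.get_momentum())

    # The scheduler can be checkpointed and restored like torch's own schedulers
    state = scheduler.state_dict()
    scheduler.load_state_dict(state)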