Skip to content

Commit

Permalink
Added learning rate scheduler in training
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishrat Badami authored and meetps committed Jul 21, 2017
1 parent 4f9fa49 commit fb3136a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
21 changes: 21 additions & 0 deletions lr_scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
    """Polynomial decay of learning rate.

    Sets every param group's lr to ``init_lr * (1 - iter/max_iter) ** power``.

    :param optimizer: torch.optim optimizer whose ``param_groups`` are updated in place
    :param init_lr: base learning rate
    :param iter: current iteration (note: name shadows the ``iter`` builtin,
        kept for backward compatibility with existing callers)
    :param lr_decay_iter: how frequently decay occurs, default is 1
    :param max_iter: maximum number of iterations
    :param power: polynomial power
    :return: the (possibly updated) optimizer, on every path
    """
    # Skip iterations that are not a decay step, or past the schedule's end.
    if iter % lr_decay_iter or iter > max_iter:
        return optimizer

    # float() guards against integer division truncating to 0 on Python 2.
    lr = init_lr * (1 - float(iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Return the optimizer on the update path too, matching the skip path.
    return optimizer


def adjust_learning_rate(optimizer, init_lr, epoch):
    """Step-decay schedule: divide the initial LR by 10 every 30 epochs.

    :param optimizer: torch.optim optimizer whose ``param_groups`` are updated in place
    :param init_lr: base learning rate
    :param epoch: current epoch (0-based)
    """
    decayed = init_lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed
22 changes: 4 additions & 18 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,7 @@
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.loss import cross_entropy2d
from ptsemseg.metrics import scores


def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
    """Polynomial decay of learning rate.

    :param init_lr is base learning rate
    :param iter is a current iteration
    :param lr_decay_iter how frequently decay occurs, default is 1
    :param max_iter is number of maximum iterations
    :param power is a polynomial power
    """
    # Only touch the LR on a decay step that is still within the schedule.
    is_decay_step = (iter % lr_decay_iter == 0) and (iter <= max_iter)
    if is_decay_step:
        new_lr = init_lr * (1 - iter / max_iter) ** power
        for group in optimizer.param_groups:
            group['lr'] = new_lr
    return optimizer

from lr_scheduling import *

def train(args):

Expand Down Expand Up @@ -74,6 +57,9 @@ def train(args):
images = Variable(images)
labels = Variable(labels)

iter = len(trainloader)*epoch + i
poly_lr_scheduler(optimizer, args.l_rate, iter)

optimizer.zero_grad()
outputs = model(images)

Expand Down

0 comments on commit fb3136a

Please sign in to comment.