diff --git a/train.py b/train.py
index 05e98916..03c1b570 100644
--- a/train.py
+++ b/train.py
@@ -15,6 +15,24 @@ from ptsemseg.loss import cross_entropy2d
 from ptsemseg.metrics import scores
 
 
+def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
+    """Polynomial decay of the learning rate.
+
+    :param init_lr: base learning rate
+    :param iter: current iteration
+    :param lr_decay_iter: how frequently decay occurs, default is 1
+    :param max_iter: maximum number of iterations
+    :param power: polynomial power
+    """
+    # Skip updates between decay steps; stop decaying once iter passes max_iter.
+    if iter % lr_decay_iter or iter > max_iter:
+        return optimizer
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = init_lr * (1 - iter / max_iter) ** power
+    return optimizer
+
+
 def train(args):
 
     # Setup Dataloader
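
For reviewers, a minimal sketch of how the new helper might be driven from a training loop. The model, optimizer setup, and loop body below are illustrative assumptions, not part of this patch; only poly_lr_scheduler and its defaults come from the diff above.

    # Hypothetical usage sketch -- assumes poly_lr_scheduler is importable
    # from train.py; the Linear model and SGD settings are placeholders.
    import torch
    from train import poly_lr_scheduler

    model = torch.nn.Linear(10, 2)  # stand-in for the real segmentation model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.99)

    for i in range(30000):
        # Re-anneal the learning rate in place before each optimizer step:
        # lr = init_lr * (1 - i / max_iter) ** power
        poly_lr_scheduler(optimizer, init_lr=1e-4, iter=i, max_iter=30000)
        optimizer.zero_grad()
        # ... forward pass, compute loss, loss.backward() ...
        optimizer.step()

One caveat worth flagging: with the default lr_decay_iter=1 the `iter % lr_decay_iter` guard is always falsy, so the rate is recomputed every iteration, and under Python 2 `iter / max_iter` would be integer division (always 0 until iter reaches max_iter), so `float(max_iter)` would be needed there.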