Add Lookahead+RAdam optimizer #416

Merged · 19 commits · Sep 23, 2019
Changes from 15 commits
11 changes: 10 additions & 1 deletion slm_lab/agent/net/net_util.py
@@ -10,6 +10,7 @@
 # register custom torch.optim
 setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam)
 setattr(torch.optim, 'GlobalRMSprop', optimizer.GlobalRMSprop)
+setattr(torch.optim, 'Lookahead', optimizer.Lookahead)
 setattr(torch.optim, 'RAdam', optimizer.RAdam)


@@ -329,9 +330,11 @@ def init_global_nets(algorithm):
         global_nets[f'global_{net_name}'] = g_net
         # if optim is Global, set to override the local optim and its scheduler
         optim = getattr(algorithm, optim_name)
-        if 'Global' in util.get_class_name(optim):
+        if hasattr(optim, 'share_memory'):
             optim.share_memory()  # make optim global
             global_nets[optim_name] = optim
+            if hasattr(optim, 'optimizer'):  # for Lookahead with an inner optimizer
+                global_nets[f'{optim_name}_optimizer'] = optim.optimizer
         lr_scheduler_name = net_name.replace('net', 'lr_scheduler')
         lr_scheduler = getattr(algorithm, lr_scheduler_name)
         global_nets[lr_scheduler_name] = lr_scheduler
@@ -346,6 +349,12 @@ def set_global_nets(algorithm, global_nets):
         setattr(algorithm, f'global_{net_name}', None)
     # set attr created in init_global_nets
     if global_nets is not None:
+        # handle inner-optimizer recovery
+        inner_opt_keys = [k for k in global_nets if k.endswith('_optimizer')]
+        for inner_opt_key in inner_opt_keys:
+            opt = global_nets[inner_opt_key.replace('_optimizer', '')]  # the optimizer which has an inner optimizer
+            setattr(opt, 'optimizer', global_nets.pop(inner_opt_key))
+        # set global nets and optims
         util.set_attr(algorithm, global_nets)
         logger.info(f'Set global_nets attr {list(global_nets.keys())} for Hogwild')

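Note on the two hunks above: together they form a round trip. init_global_nets exports a Lookahead's inner optimizer under an `{optim_name}_optimizer` key so its state survives the Hogwild multiprocessing spawn, and set_global_nets pops that key in each worker and re-attaches the inner optimizer to its wrapper. A minimal sketch of the pattern, using an illustrative FakeLookahead stand-in rather than the real SLM-Lab classes:

# Sketch of the export/recover round trip (FakeLookahead is a stand-in, not SLM-Lab API)
class FakeLookahead:
    def __init__(self, inner):
        self.optimizer = inner  # inner optimizer, e.g. a RAdam instance

opt = FakeLookahead(inner='radam-state')

# export (as in init_global_nets): record the inner optimizer under a '_optimizer' key
global_nets = {'optim': opt, 'optim_optimizer': opt.optimizer}

# recover (as in set_global_nets): pop each '_optimizer' key and re-attach it
for key in [k for k in global_nets if k.endswith('_optimizer')]:
    wrapper = global_nets[key.replace('_optimizer', '')]
    setattr(wrapper, 'optimizer', global_nets.pop(key))

assert list(global_nets) == ['optim'] and opt.optimizer == 'radam-state'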
93 changes: 75 additions & 18 deletions slm_lab/lib/optimizer.py
@@ -1,5 +1,6 @@
 # Custom PyTorch optimizer classes, to be registered in net_util.py
 from torch.optim.optimizer import Optimizer
+import itertools as it
 import math
 import torch

@@ -103,6 +104,56 @@ def step(self, closure=None):
         return loss
 
 
+class Lookahead(Optimizer):
+    '''
+    Lookahead Optimizer: k steps forward, 1 step back
+    https://arxiv.org/abs/1907.08610
+    Implementation modified from https://github.com/lonePatient/lookahead_pytorch; reference from https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
+    '''
+
+    def __init__(self, params, alpha=0.5, k=5, optimizer='RAdam', **optimizer_kwargs):
+        if not 0.0 <= alpha <= 1.0:
+            raise ValueError(f'Invalid slow update rate: {alpha}')
+        if not 1 <= k:
+            raise ValueError(f'Invalid lookahead steps: {k}')
+        # construct base optimizer
+        OptimClass = getattr(torch.optim, optimizer)
+        self.optimizer = OptimClass(params, **optimizer_kwargs)
+        # share param_groups and state with the inner optimizer
+        self.param_groups = self.optimizer.param_groups
+        self.state = self.optimizer.state
+        # create and use defaults to track params to retain them in multiprocessing spawn
+        self.defaults = self.optimizer.defaults
+        self.defaults['alpha'] = alpha
+        self.defaults['k'] = k
+        for group in self.param_groups:
+            group['step_counter'] = 0
+        # slow weights start as a detached copy of the fast weights
+        self.defaults['slow_weights'] = [[
+            p.clone().detach() for p in group['params']]
+            for group in self.param_groups]
+
+        for w in it.chain(*self.defaults['slow_weights']):
+            w.requires_grad = False
+
+    def share_memory(self):
+        self.optimizer.share_memory()
+
+    def step(self, closure=None):
+        # delegate the fast step (and the closure, if any) to the inner optimizer
+        loss = self.optimizer.step(closure)
+        for group, slow_weights in zip(self.param_groups, self.defaults['slow_weights']):
+            group['step_counter'] += 1
+            if group['step_counter'] % self.defaults['k'] != 0:
+                continue
+            # every k fast steps, pull slow weights toward fast weights by alpha, then sync back
+            for p, q in zip(group['params'], slow_weights):
+                if p.grad is None:
+                    continue
+                q.data.add_(self.defaults['alpha'], p.data - q.data)
+                p.data.copy_(q.data)
+        return loss
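Note: for readers checking the class above, the intended call pattern is sketched below. It assumes slm_lab.agent.net.net_util has been imported so the setattr registration shown earlier has run; the model, data, and hyperparameters are illustrative:

import torch
import torch.nn as nn
import slm_lab.agent.net.net_util  # noqa: F401 -- registers Lookahead/RAdam onto torch.optim

model = nn.Linear(4, 2)
# Lookahead builds the inner RAdam from **optimizer_kwargs (here just lr)
optim = torch.optim.Lookahead(model.parameters(), alpha=0.5, k=5, optimizer='RAdam', lr=3e-4)

for _ in range(10):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optim.zero_grad()  # works through the param_groups shared with the inner optimizer
    loss.backward()
    optim.step()  # fast RAdam step; every k=5 steps, slow weights pull toward fast by alpha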


 class RAdam(Optimizer):
     '''
     RAdam optimizer which stabilizes training across different learning rates.
@@ -111,21 +162,33 @@ class RAdam(Optimizer):
     '''
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        self.buffer = [[None, None, None] for ind in range(10)]
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=torch.zeros((10, 3)))
         super(RAdam, self).__init__(params, defaults)
 
+        # eagerly initialize per-param state so share_memory() can move it to shared memory
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = torch.zeros(1)
+                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
+                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()
+
     def __setstate__(self, state):
         super(RAdam, self).__setstate__(state)

-    def step(self, closure=None):
+    def share_memory(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'].share_memory_()
+                state['exp_avg'].share_memory_()
+                state['exp_avg_sq'].share_memory_()
+
+    def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()
 
         for group in self.param_groups:
 
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -134,16 +197,9 @@ def step(self, closure=None):
                     raise RuntimeError('RAdam does not support sparse gradients')
 
                 p_data_fp32 = p.data.float()
-
                 state = self.state[p]
-
-                if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
-                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                # state was created eagerly in __init__; just ensure dtype matches
+                state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
@@ -152,7 +208,7 @@ def step(self, closure=None):
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
 
                 state['step'] += 1
-                buffered = self.buffer[int(state['step'] % 10)]
+                buffered = self.defaults['buffer'][int(state['step'] % 10)]
                 if state['step'] == buffered[0]:
                     N_sma, step_size = buffered[1], buffered[2]
                 else:
@@ -164,20 +220,21 @@ def step(self, closure=None):

                     # more conservative since it's an approximated value
                     if N_sma >= 5:
-                        step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                     else:
-                        step_size = group['lr'] / (1 - beta1 ** state['step'])
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
                     buffered[2] = step_size
 
                 if group['weight_decay'] != 0:
                     p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
 
                 # more conservative since it's an approximated value
+                # the buffered step_size is lr-free; recombine with the current lr here
+                adap_lr = (-step_size * group['lr']).squeeze(dim=0).item()
                 if N_sma >= 5:
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                    p_data_fp32.addcdiv_(adap_lr, exp_avg, denom)
                 else:
-                    p_data_fp32.add_(-step_size, exp_avg)
+                    p_data_fp32.add_(adap_lr, exp_avg)
 
                 p.data.copy_(p_data_fp32)

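Note on the RAdam change above: the old code cached step_size with group['lr'] already multiplied in, which goes stale once the buffer lives in defaults and is shared while lr can differ or anneal. The diff instead caches an lr-free step size and multiplies the current lr back in at use time. A standalone toy sketch of that cache-and-recombine logic (variable names illustrative):

import math
import torch

# buffer rows cache [step, N_sma, lr-free step_size], keyed by step % 10
buffer = torch.zeros((10, 3))
beta1, beta2, lr, step = 0.9, 0.999, 3e-4, 1

row = buffer[step % 10]
if row[0] != step:  # cache miss: recompute the rectification term
    beta2_t = beta2 ** step
    N_sma_max = 2 / (1 - beta2) - 1
    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    row[0], row[1] = step, N_sma
    if N_sma >= 5:  # variance is tractable: rectified adaptive step
        row[2] = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4)
                           * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** step)
    else:  # fall back to an un-adapted, momentum-like step
        row[2] = 1.0 / (1 - beta1 ** step)

adap_lr = -(row[2] * lr).item()  # recombine with the current lr, as in the diff
print(adap_lr)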
26 changes: 17 additions & 9 deletions slm_lab/spec/benchmark/a2c/a2c_gae_roboschool.json
@@ -35,11 +35,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "Adam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "Adam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -66,7 +68,7 @@
     },
     "spec_params": {
       "env": [
-        "RoboschoolAnt-v1", "RoboschoolAtlasForwardWalk-v1", "RoboschoolHalfCheetah-v1", "RoboschoolHopper-v1", "RoboschoolInvertedDoublePendulum-v1", "RoboschoolInvertedPendulum-v1", "RoboschoolInvertedPendulumSwingup-v1", "RoboschoolReacher-v1", "RoboschoolWalker2d-v1"
+        "RoboschoolAnt-v1", "RoboschoolAtlasForwardWalk-v1", "RoboschoolHalfCheetah-v1", "RoboschoolHopper-v1", "RoboschoolInvertedDoublePendulum-v1", "RoboschoolInvertedPendulum-v1", "RoboschoolReacher-v1", "RoboschoolWalker2d-v1"
       ]
     }
   },
@@ -106,11 +108,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -172,11 +176,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -238,11 +244,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
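Note on the spec changes: the `name` key selects a class registered on torch.optim, and the remaining keys are forwarded as constructor kwargs, so `optimizer` and `lr` land in Lookahead.__init__, which passes `lr` through **optimizer_kwargs to the inner RAdam. A hedged sketch of that dispatch (build_optim is illustrative; the real helper lives in net_util):

import torch
import torch.nn as nn
import slm_lab.agent.net.net_util  # noqa: F401 -- registers Lookahead/RAdam onto torch.optim

def build_optim(params, optim_spec):
    '''Illustrative spec -> optimizer dispatch, not the exact net_util code.'''
    kwargs = dict(optim_spec)
    OptimClass = getattr(torch.optim, kwargs.pop('name'))  # e.g. Lookahead
    return OptimClass(params, **kwargs)  # remaining keys become constructor kwargs

net = nn.Linear(4, 2)
optim = build_optim(net.parameters(), {'name': 'Lookahead', 'optimizer': 'RAdam', 'lr': 3e-4})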
26 changes: 17 additions & 9 deletions slm_lab/spec/benchmark/a2c/a2c_nstep_roboschool.json
@@ -35,11 +35,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "Adam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "Adam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -66,7 +68,7 @@
     },
     "spec_params": {
       "env": [
-        "RoboschoolAnt-v1", "RoboschoolAtlasForwardWalk-v1", "RoboschoolHalfCheetah-v1", "RoboschoolHopper-v1", "RoboschoolInvertedDoublePendulum-v1", "RoboschoolInvertedPendulum-v1", "RoboschoolInvertedPendulumSwingup-v1", "RoboschoolReacher-v1", "RoboschoolWalker2d-v1"
+        "RoboschoolAnt-v1", "RoboschoolAtlasForwardWalk-v1", "RoboschoolHalfCheetah-v1", "RoboschoolHopper-v1", "RoboschoolInvertedDoublePendulum-v1", "RoboschoolInvertedPendulum-v1", "RoboschoolReacher-v1", "RoboschoolWalker2d-v1"
       ]
     }
   },
@@ -106,11 +108,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -172,11 +176,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,
@@ -238,11 +244,13 @@
"name": "MSELoss"
},
"actor_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"critic_optim_spec": {
"name": "RAdam",
"name": "Lookahead",
"optimizer": "RAdam",
"lr": 3e-4,
},
"lr_scheduler_spec": null,