
Commit 7feab40

Merge pull request #340 from kengz/globalopt

A3C distributed modes

kengz authored May 18, 2019, 2 parents fb254c8 + fe78439
Showing 48 changed files with 592 additions and 279 deletions.
run_lab.py (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ def run_spec(spec, lab_mode):

 def read_spec_and_run(spec_file, spec_name, lab_mode):
     '''Read a spec and run it in lab mode'''
-    logger.info(f'Running lab: spec_file {spec_file} spec_name {spec_name} in mode: {lab_mode}')
+    logger.info(f'Running lab spec_file:{spec_file} spec_name:{spec_name} in mode:{lab_mode}')
     if lab_mode in TRAIN_MODES:
         spec = spec_util.get(spec_file, spec_name)
     else:  # eval mode
slm_lab/agent/algorithm/actor_critic.py (41 changes: 19 additions & 22 deletions)
@@ -146,27 +146,24 @@ def init_nets(self, global_nets=None):
         if critic_net_spec['use_same_optim']:
             critic_net_spec = actor_net_spec

-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body, add_critic=self.shared)
-            # main actor network, may contain out_dim self.shared == True
-            NetClass = getattr(net, actor_net_spec['type'])
-            self.net = NetClass(actor_net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-            if not self.shared:  # add separate network for critic
-                critic_out_dim = 1
-                CriticNetClass = getattr(net, critic_net_spec['type'])
-                self.critic = CriticNetClass(critic_net_spec, in_dim, critic_out_dim)
-                self.net_names.append('critic')
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body, add_critic=self.shared)
+        # main actor network, may contain out_dim self.shared == True
+        NetClass = getattr(net, actor_net_spec['type'])
+        self.net = NetClass(actor_net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        if not self.shared:  # add separate network for critic
+            critic_out_dim = 1
+            CriticNetClass = getattr(net, critic_net_spec['type'])
+            self.critic_net = CriticNetClass(critic_net_spec, in_dim, critic_out_dim)
+            self.net_names.append('critic_net')
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
         if not self.shared:
-            self.critic_optim = net_util.get_optim(self.critic, self.critic.optim_spec)
-            self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic.lr_scheduler_spec)
+            self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
+            self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -188,7 +185,7 @@ def calc_pdparam(self, x, net=None):

     def calc_v(self, x, net=None, use_cache=True):
         '''
-        Forward-pass to calculate the predicted state-value from critic.
+        Forward-pass to calculate the predicted state-value from critic_net.
         '''
         if self.shared:  # output: policy, value
             if use_cache:  # uses cache from calc_pdparam to prevent double-pass
@@ -197,7 +194,7 @@ def calc_v(self, x, net=None, use_cache=True):
             net = self.net if net is None else net
             v_pred = net(x)[-1].view(-1)
         else:
-            net = self.critic if net is None else net
+            net = self.critic_net if net is None else net
             v_pred = net(x).view(-1)
         return v_pred

@@ -294,10 +291,10 @@ def train(self):
             val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
             if self.shared:  # shared network
                 loss = policy_loss + val_loss
-                self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             else:
-                self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock)
-                self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock)
+                self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
+                self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net)
                 loss = policy_loss + val_loss
             # reset
             self.to_train = 0
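Across the algorithms, the old per-class branching on global_nets in init_nets is replaced by always building the local networks and then calling net_util.set_global_nets(self, global_nets). That helper is not shown in this diff; a minimal sketch of what it might do, assuming it simply exposes each entry of global_nets as a global_{net_name} attribute and defaults the rest to None (the real helper, presumably in slm_lab/agent/net/net_util.py, may differ):

def set_global_nets(algorithm, global_nets):
    '''Sketch: expose each global net as algorithm.global_{net_name}, defaulting to None.'''
    # no global counterpart unless one is provided
    for net_name in algorithm.net_names:
        setattr(algorithm, f'global_{net_name}', None)
    # attach the provided global nets, e.g. {'net': <Net>} -> algorithm.global_net
    if global_nets is not None:
        for net_name, g_net in global_nets.items():
            setattr(algorithm, f'global_{net_name}', g_net)

This would make the train() calls above work unchanged in both single-process runs (global_net is None) and A3C-style runs (global_net is a shared network).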
slm_lab/agent/algorithm/base.py (2 changes: 2 additions & 0 deletions)
@@ -48,6 +48,8 @@ def post_init_nets(self):
         Call at the end of init_nets() after setting self.net_names
         '''
         assert hasattr(self, 'net_names')
+        for net_name in self.net_names:
+            assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}'
         if util.in_eval_lab_modes():
             self.load()
             logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
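The new assertion pins down a naming convention: every entry in self.net_names ends with 'net', so a matching global attribute can be derived mechanically for each one. Under the set_global_nets sketch above, the names seen in this PR would map as, for example, net -> global_net, critic_net -> global_critic_net (renamed from critic in this change), and target_net -> global_target_net.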
slm_lab/agent/algorithm/dqn.py (34 changes: 14 additions & 20 deletions)
@@ -79,18 +79,15 @@ def init_nets(self, global_nets=None):
         '''Initialize the neural network used to learn the Q function from the spec'''
         if self.algorithm_spec['name'] == 'VanillaDQN':
             assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for VanillaDQN; use DQN.'
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     def calc_q_loss(self, batch):
@@ -145,7 +142,7 @@ def train(self):
                 batch = self.sample()
                 for _ in range(self.training_batch_epoch):
                     loss = self.calc_q_loss(batch)
-                    self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     total_loss += loss
             loss = total_loss / (self.training_epoch * self.training_batch_epoch)
             # reset
@@ -182,19 +179,16 @@ def init_nets(self, global_nets=None):
         '''Initialize networks'''
         if self.algorithm_spec['name'] == 'DQNBase':
             assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for DQNBase; use DQN.'
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.target_net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net', 'target_net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.target_net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net', 'target_net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()
         self.online_net = self.target_net
         self.eval_net = self.target_net
slm_lab/agent/algorithm/hydra_dqn.py (15 changes: 6 additions & 9 deletions)
@@ -19,17 +19,14 @@ def init_nets(self, global_nets=None):
         # NOTE: Separate init from MultitaskDQN despite similarities so that this implementation can support arbitrary sized state and action heads (e.g. multiple layers)
         self.state_dims = in_dims = [body.state_dim for body in self.agent.nanflat_body_a]
         self.action_dims = out_dims = [body.action_dim for body in self.agent.nanflat_body_a]
-        if global_nets is None:
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dims, out_dims)
-            self.target_net = NetClass(self.net_spec, in_dims, out_dims)
-            self.net_names = ['net', 'target_net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dims, out_dims)
+        self.target_net = NetClass(self.net_spec, in_dims, out_dims)
+        self.net_names = ['net', 'target_net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()
         self.online_net = self.target_net
         self.eval_net = self.target_net
@@ -100,7 +97,7 @@ def space_train(self):
                 batch = self.space_sample()
                 for _ in range(self.training_batch_epoch):
                     loss = self.calc_q_loss(batch)
-                    self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     total_loss += loss
             loss = total_loss / (self.training_epoch * self.training_batch_epoch)
             # reset
slm_lab/agent/algorithm/ppo.py (6 changes: 3 additions & 3 deletions)
@@ -189,10 +189,10 @@ def train(self):
                     val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                     if self.shared:  # shared network
                         loss = policy_loss + val_loss
-                        self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                        self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     else:
-                        self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock)
-                        self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock)
+                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
+                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net)
                         loss = policy_loss + val_loss
                     total_loss += loss
             loss = total_loss / self.training_epoch / len(minibatches)
slm_lab/agent/algorithm/reinforce.py (17 changes: 7 additions & 10 deletions)
@@ -79,18 +79,15 @@ def init_nets(self, global_nets=None):
         Networks for continuous action spaces have two heads and return two values, the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution.
         Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions
         '''
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -161,7 +158,7 @@ def train(self):
             pdparams = self.calc_pdparam_batch(batch)
             advs = self.calc_ret_advs(batch)
             loss = self.calc_policy_loss(batch, pdparams, advs)
-            self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             # reset
             self.to_train = 0
             logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
slm_lab/agent/algorithm/sarsa.py (17 changes: 7 additions & 10 deletions)
@@ -72,18 +72,15 @@ def init_nets(self, global_nets=None):
         '''Initialize the neural network used to learn the Q function from the spec'''
         if 'Recurrent' in self.net_spec['type']:
             self.net_spec.update(seq_len=self.net_spec['seq_len'])
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -149,7 +146,7 @@ def train(self):
         if self.to_train == 1:
             batch = self.sample()
             loss = self.calc_q_loss(batch)
-            self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             # reset
             self.to_train = 0
             logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
slm_lab/agent/algorithm/sil.py (2 changes: 1 addition & 1 deletion)
@@ -147,7 +147,7 @@ def train(self):
                     pdparams, _v_preds = self.calc_pdparam_v(batch)
                     sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch, pdparams)
                     sil_loss = sil_policy_loss + sil_val_loss
-                    self.net.training_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     total_sil_loss += sil_loss
             sil_loss = total_sil_loss / self.training_epoch
             loss = super_loss + sil_loss
slm_lab/agent/net/conv.py (9 changes: 6 additions & 3 deletions)
@@ -189,15 +189,18 @@ def forward(self, x):
         else:
             return self.model_tail(x)

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
-        '''Takes a single training step: one forward and one backwards pass'''
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss

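train_step (previously training_step) now takes an optional global_net. When one is given, the flow follows the usual A3C gradient-push pattern: compute gradients on the local worker net, push them onto the global network, take the optimizer step, then sync the worker back from the global parameters. Neither net_util.push_global_grads nor net_util.copy appears in this diff; a minimal sketch of what they might look like, assuming local and global nets have identical architectures so their parameters pair up one-to-one:

def push_global_grads(net, global_net):
    '''Sketch: hand each worker gradient to the matching global parameter.'''
    for local_param, global_param in zip(net.parameters(), global_net.parameters()):
        global_param.grad = local_param.grad

def copy(src_net, tar_net):
    '''Sketch: overwrite tar_net parameters with src_net parameters (global -> local after the step).'''
    tar_net.load_state_dict(src_net.state_dict())

The real helpers in slm_lab/agent/net/net_util.py may guard against overwriting existing global gradients or copy tensors explicitly; the above only illustrates the push-then-sync idea.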
slm_lab/agent/net/mlp.py (20 changes: 13 additions & 7 deletions)
@@ -120,17 +120,19 @@ def forward(self, x):
         else:
             return self.model_tail(x)

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
-        '''
-        Train a network given a computed loss
-        '''
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
+        '''Train a network given a computed loss'''
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss

@@ -286,14 +288,18 @@ def forward(self, xs):
             outs.append(model_tail(body_x))
         return outs

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss

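The same train_step change is repeated for ConvNet, MLPNet, and HydraMLPNet, so every algorithm above can pass its global net (or None) straight through. A hedged illustration of the call from an algorithm's train(), with loss, optim, lr_scheduler, and clock assumed to exist as in the algorithm code shown earlier:

# single-process training: no global net attached, behaves as before
loss = net.train_step(loss, optim, lr_scheduler, lr_clock=clock, global_net=None)

# A3C-style worker: push grads to the shared global net, then sync back
loss = net.train_step(loss, optim, lr_scheduler, lr_clock=clock, global_net=global_net)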
