A3C distributed modes #340
Merged · 39 commits · May 18, 2019

Commits
All 39 commits are by kengz.

May 17, 2019:
bf66248  remove check_compatibility
0a922a7  add and register GlobalAdam
3f49674  move make_global_nets to net_util
dae5989  Merge remote-tracking branch 'origin/v4-dev' into globalopt
074bd0c  enforce net naming convention
2b29495  move global_nets init to net_util
5800dcf  simplify net and global_net init
ec92f45  use global adam for a3cpong
07311bc  lower lr
657196b  override with global sync
ab8db09  eval less
3736c84  increase lr
7a249be  sync net after training
700780a  set and sync hogwild
0e2883a  lower lr
4636653  fix ppo misnaming critic

May 18, 2019:
1d44716  remove curren sync_global_nets
2ca515d  do grad push and param pull inside training_step
67c6b3f  guard set_global_nets, pass global_net into training_step
099d435  fix typo
aa44829  move local grad to cpu first
f725bcb  global_net to CPU
00f3384  revert
82918fa  rename to train_step
2bed5f4  allow for synced and shared distributed modes
cb9b1ad  add basic compat check
64782e8  name a3c
b248e90  add a2c pong spec
fdae0e1  cleanup is_venv setting
6bf48f2  remove useless NUM_EVAL_EPI
3453a91  divide max_tick by max session if distributed
7e6957a  rename resources to search_resources for clarity
2e617de  add a3c atari spec
be5bc12  add GlobalRMSProp
b2cbd1b  update deprecated warning and add_trace methods
981f740  improve run log
e739b6f  fix a3c shared hogwild cuda id assignment to offset 0
bdaac5e  disable a3c gpu with synced in spec
fe78439  add flaky to vizdoom test
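
Several of these commits (0a922a7, be5bc12, 657196b) revolve around global optimizers whose state lives in shared memory, so that all worker processes accumulate into a single optimizer state. The `GlobalAdam` implementation itself is not part of the diff below; what follows is a minimal sketch, assuming the SharedAdam pattern from the pytorch-a3c reference implementation, not the PR's actual code:

```python
import torch

class GlobalAdam(torch.optim.Adam):
    '''Adam with per-parameter state pre-built and moved to shared memory,
    so workers forked from the same parent process share one optimizer state.'''

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # build the state eagerly (Adam normally builds it lazily on the first step)
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # move the moment buffers to shared memory for cross-process updates
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
```

A fuller version would also share the step counter across processes (e.g. as a tensor); the sketch keeps it process-local for brevity, and `GlobalRMSProp` (be5bc12) presumably applies the same treatment to RMSProp's state.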
2 changes: 1 addition & 1 deletion run_lab.py

```diff
@@ -50,7 +50,7 @@ def run_spec(spec, lab_mode):

 def read_spec_and_run(spec_file, spec_name, lab_mode):
     '''Read a spec and run it in lab mode'''
-    logger.info(f'Running lab: spec_file {spec_file} spec_name {spec_name} in mode: {lab_mode}')
+    logger.info(f'Running lab spec_file:{spec_file} spec_name:{spec_name} in mode:{lab_mode}')
     if lab_mode in TRAIN_MODES:
         spec = spec_util.get(spec_file, spec_name)
     else:  # eval mode
```
41 changes: 19 additions & 22 deletions slm_lab/agent/algorithm/actor_critic.py

```diff
@@ -146,27 +146,24 @@ def init_nets(self, global_nets=None):
         if critic_net_spec['use_same_optim']:
             critic_net_spec = actor_net_spec

-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body, add_critic=self.shared)
-            # main actor network, may contain critic out_dim if self.shared == True
-            NetClass = getattr(net, actor_net_spec['type'])
-            self.net = NetClass(actor_net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-            if not self.shared:  # add separate network for critic
-                critic_out_dim = 1
-                CriticNetClass = getattr(net, critic_net_spec['type'])
-                self.critic = CriticNetClass(critic_net_spec, in_dim, critic_out_dim)
-                self.net_names.append('critic')
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body, add_critic=self.shared)
+        # main actor network, may contain critic out_dim if self.shared == True
+        NetClass = getattr(net, actor_net_spec['type'])
+        self.net = NetClass(actor_net_spec, in_dim, out_dim)
+        self.net_names = ['net']
+        if not self.shared:  # add separate network for critic
+            critic_out_dim = 1
+            CriticNetClass = getattr(net, critic_net_spec['type'])
+            self.critic_net = CriticNetClass(critic_net_spec, in_dim, critic_out_dim)
+            self.net_names.append('critic_net')
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
         if not self.shared:
-            self.critic_optim = net_util.get_optim(self.critic, self.critic.optim_spec)
-            self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic.lr_scheduler_spec)
+            self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
+            self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -188,7 +185,7 @@ def calc_pdparam(self, x, net=None):

     def calc_v(self, x, net=None, use_cache=True):
         '''
-        Forward-pass to calculate the predicted state-value from critic.
+        Forward-pass to calculate the predicted state-value from critic_net.
         '''
         if self.shared:  # output: policy, value
             if use_cache:  # uses cache from calc_pdparam to prevent double-pass
@@ -197,7 +194,7 @@ def calc_v(self, x, net=None, use_cache=True):
                 net = self.net if net is None else net
                 v_pred = net(x)[-1].view(-1)
         else:
-            net = self.critic if net is None else net
+            net = self.critic_net if net is None else net
             v_pred = net(x).view(-1)
         return v_pred

@@ -294,10 +291,10 @@ def train(self):
             val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
             if self.shared:  # shared network
                 loss = policy_loss + val_loss
-                self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             else:
-                self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock)
-                self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock)
+                self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
+                self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net)
             loss = policy_loss + val_loss
             # reset
             self.to_train = 0
```
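
The init_nets refactor above replaces the per-algorithm `if global_nets is None` branch with a single `net_util.set_global_nets(self, global_nets)` call, and `train()` now passes `self.global_net` / `self.global_critic_net` into `train_step`. The helper itself is not shown in this diff; here is a plausible minimal sketch of its contract, inferred from how those attributes are used (an assumption, not the PR's exact code):

```python
def set_global_nets(algorithm, global_nets):
    '''Attach global nets to an algorithm as self.global_{net_name} attributes.

    Every global net defaults to None, so train_step(..., global_net=...) can be
    called uniformly whether or not the session is distributed.
    '''
    for net_name in algorithm.net_names:
        setattr(algorithm, f'global_{net_name}', None)
    if global_nets is not None:  # distributed session: overwrite with the shared nets
        # assumes global_nets keys are already prefixed, e.g. 'global_net', 'global_critic_net'
        for global_net_name, global_net in global_nets.items():
            setattr(algorithm, global_net_name, global_net)
```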
2 changes: 2 additions & 0 deletions slm_lab/agent/algorithm/base.py

```diff
@@ -48,6 +48,8 @@ def post_init_nets(self):

         Call at the end of init_nets() after setting self.net_names
         '''
         assert hasattr(self, 'net_names')
+        for net_name in self.net_names:
+            assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}'
         if util.in_eval_lab_modes():
             self.load()
             logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
```
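
This naming convention (commit 074bd0c) is what makes the net-to-global-net pairing mechanical: 'net' maps to 'global_net', 'critic_net' to 'global_critic_net', 'target_net' to 'global_target_net'. For illustration, a hypothetical check mirroring the assert above:

```python
for name in ['net', 'critic_net', 'target_net']:
    assert name.endswith('net')  # passes: convention satisfied
# the old attribute name 'critic' (renamed to 'critic_net' in this PR, see 4636653) fails:
assert not 'critic'.endswith('net')
```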
34 changes: 14 additions & 20 deletions slm_lab/agent/algorithm/dqn.py

```diff
@@ -79,18 +79,15 @@ def init_nets(self, global_nets=None):
         '''Initialize the neural network used to learn the Q function from the spec'''
         if self.algorithm_spec['name'] == 'VanillaDQN':
             assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for VanillaDQN; use DQN.'
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     def calc_q_loss(self, batch):
@@ -145,7 +142,7 @@ def train(self):
                 batch = self.sample()
                 for _ in range(self.training_batch_epoch):
                     loss = self.calc_q_loss(batch)
-                    self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     total_loss += loss
             loss = total_loss / (self.training_epoch * self.training_batch_epoch)
             # reset
@@ -182,19 +179,16 @@ def init_nets(self, global_nets=None):
         '''Initialize networks'''
         if self.algorithm_spec['name'] == 'DQNBase':
             assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for DQNBase; use DQN.'
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.target_net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net', 'target_net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.target_net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net', 'target_net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()
         self.online_net = self.target_net
         self.eval_net = self.target_net
```
15 changes: 6 additions & 9 deletions slm_lab/agent/algorithm/hydra_dqn.py

```diff
@@ -19,17 +19,14 @@ def init_nets(self, global_nets=None):
         # NOTE: Separate init from MultitaskDQN despite similarities so that this implementation can support arbitrary sized state and action heads (e.g. multiple layers)
         self.state_dims = in_dims = [body.state_dim for body in self.agent.nanflat_body_a]
         self.action_dims = out_dims = [body.action_dim for body in self.agent.nanflat_body_a]
-        if global_nets is None:
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dims, out_dims)
-            self.target_net = NetClass(self.net_spec, in_dims, out_dims)
-            self.net_names = ['net', 'target_net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dims, out_dims)
+        self.target_net = NetClass(self.net_spec, in_dims, out_dims)
+        self.net_names = ['net', 'target_net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()
         self.online_net = self.target_net
         self.eval_net = self.target_net
@@ -100,7 +97,7 @@ def space_train(self):
                 batch = self.space_sample()
                 for _ in range(self.training_batch_epoch):
                     loss = self.calc_q_loss(batch)
-                    self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                     total_loss += loss
             loss = total_loss / (self.training_epoch * self.training_batch_epoch)
             # reset
```
6 changes: 3 additions & 3 deletions slm_lab/agent/algorithm/ppo.py

```diff
@@ -189,10 +189,10 @@ def train(self):
                 val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
-                    self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                    self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                 else:
-                    self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock)
-                    self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock)
+                    self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
+                    self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net)
                 loss = policy_loss + val_loss
                 total_loss += loss
             loss = total_loss / self.training_epoch / len(minibatches)
```
17 changes: 7 additions & 10 deletions slm_lab/agent/algorithm/reinforce.py

```diff
@@ -79,18 +79,15 @@ def init_nets(self, global_nets=None):
         Networks for continuous action spaces have two heads and return two values, the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution.
         Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions
         '''
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -161,7 +158,7 @@ def train(self):
             pdparams = self.calc_pdparam_batch(batch)
             advs = self.calc_ret_advs(batch)
             loss = self.calc_policy_loss(batch, pdparams, advs)
-            self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             # reset
             self.to_train = 0
             logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
```
17 changes: 7 additions & 10 deletions slm_lab/agent/algorithm/sarsa.py

```diff
@@ -72,18 +72,15 @@ def init_nets(self, global_nets=None):
         '''Initialize the neural network used to learn the Q function from the spec'''
         if 'Recurrent' in self.net_spec['type']:
             self.net_spec.update(seq_len=self.net_spec['seq_len'])
-        if global_nets is None:
-            in_dim = self.body.state_dim
-            out_dim = net_util.get_out_dim(self.body)
-            NetClass = getattr(net, self.net_spec['type'])
-            self.net = NetClass(self.net_spec, in_dim, out_dim)
-            self.net_names = ['net']
-        else:
-            util.set_attr(self, global_nets)
-            self.net_names = list(global_nets.keys())
+        in_dim = self.body.state_dim
+        out_dim = net_util.get_out_dim(self.body)
+        NetClass = getattr(net, self.net_spec['type'])
+        self.net = NetClass(self.net_spec, in_dim, out_dim)
+        self.net_names = ['net']
         # init net optimizer and its lr scheduler
         self.optim = net_util.get_optim(self.net, self.net.optim_spec)
         self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
+        net_util.set_global_nets(self, global_nets)
         self.post_init_nets()

     @lab_api
@@ -149,7 +146,7 @@ def train(self):
         if self.to_train == 1:
             batch = self.sample()
             loss = self.calc_q_loss(batch)
-            self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock)
+            self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
             # reset
             self.to_train = 0
             logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
```
2 changes: 1 addition & 1 deletion slm_lab/agent/algorithm/sil.py

```diff
@@ -147,7 +147,7 @@ def train(self):
                 pdparams, _v_preds = self.calc_pdparam_v(batch)
                 sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch, pdparams)
                 sil_loss = sil_policy_loss + sil_val_loss
-                self.net.training_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock)
+                self.net.train_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net)
                 total_sil_loss += sil_loss
         sil_loss = total_sil_loss / self.training_epoch
         loss = super_loss + sil_loss
```
9 changes: 6 additions & 3 deletions slm_lab/agent/net/conv.py

```diff
@@ -189,15 +189,18 @@ def forward(self, x):
         else:
             return self.model_tail(x)

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
-        '''Takes a single training step: one forward and one backwards pass'''
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss
```
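
`train_step` now performs the synced-mode gradient push and parameter pull inline (commit 2ca515d): local gradients are pushed onto the CPU-resident global net before `optim.step()`, then the updated global parameters are pulled back into the local net. In this mode the optimizer is presumably constructed over the global net's parameters (e.g. a GlobalAdam), so the step applies the pushed grads to the shared weights. The two `net_util` helpers are not in this diff; a minimal sketch, assuming the standard A3C pattern and the CPU-first grad handling from commits aa44829/f725fcb (an assumption, not the PR's exact code):

```python
def push_global_grads(net, global_net):
    '''Push the local net's gradients onto the global net's parameters.
    Grads are moved to CPU first, since the global net is kept on CPU.'''
    for param, global_param in zip(net.parameters(), global_net.parameters()):
        if param.grad is not None:
            global_param._grad = param.grad.cpu()  # consumed by the global optim.step()

def copy(src_net, tar_net):
    '''Copy parameters from src_net into tar_net; train_step uses this to pull
    the freshly updated global parameters back into the local net.'''
    tar_net.load_state_dict(src_net.state_dict())
```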
20 changes: 13 additions & 7 deletions slm_lab/agent/net/mlp.py

```diff
@@ -120,17 +120,19 @@ def forward(self, x):
         else:
             return self.model_tail(x)

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
-        '''
-        Train a network given a computed loss
-        '''
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
+        '''Train a network given a computed loss'''
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss

@@ -286,14 +288,18 @@ def forward(self, xs):
             outs.append(model_tail(body_x))
         return outs

-    @net_util.dev_check_training_step
-    def training_step(self, loss, optim, lr_scheduler, lr_clock=None):
+    @net_util.dev_check_train_step
+    def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None):
         lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t'))
         optim.zero_grad()
         loss.backward()
         if self.clip_grad_val is not None:
             nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
+        if global_net is not None:
+            net_util.push_global_grads(self, global_net)
         optim.step()
+        if global_net is not None:
+            net_util.copy(global_net, self)
         lr_clock.tick('grad_step')
         return loss
```
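
Commit 2bed5f4 names the two distributed modes these changes enable: "synced", where each worker keeps a local net and syncs through the global net via the push/pull in `train_step` above (`global_net` is not None), and "shared", where Hogwild workers train directly on one shared-memory net (`global_net` stays None and no explicit sync is needed). A toy, self-contained sketch of the shared/Hogwild mechanics, independent of SLM-Lab (all names here are illustrative):

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp

def worker(net):
    '''Each worker backprops its own grads but steps the shared parameters in place.'''
    optim = torch.optim.SGD(net.parameters(), lr=0.01)  # A3C would use a shared GlobalAdam instead
    for _ in range(100):
        x = torch.randn(8, 4)
        loss = net(x).pow(2).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()  # updates the shared-memory params directly, without locks

if __name__ == '__main__':
    net = nn.Linear(4, 2)
    net.share_memory()  # all workers see and update this single copy
    workers = [mp.Process(target=worker, args=(net,)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```

Nothing locks the parameters here; Hogwild relies on lock-free updates staying approximately correct, which is also why the shared mode needs care with device placement (see commits e739b6f and bdaac5e).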