diff --git a/run_lab.py b/run_lab.py index ba5a4208e..4d2cb904c 100644 --- a/run_lab.py +++ b/run_lab.py @@ -50,7 +50,7 @@ def run_spec(spec, lab_mode): def read_spec_and_run(spec_file, spec_name, lab_mode): '''Read a spec and run it in lab mode''' - logger.info(f'Running lab: spec_file {spec_file} spec_name {spec_name} in mode: {lab_mode}') + logger.info(f'Running lab spec_file:{spec_file} spec_name:{spec_name} in mode:{lab_mode}') if lab_mode in TRAIN_MODES: spec = spec_util.get(spec_file, spec_name) else: # eval mode diff --git a/slm_lab/agent/algorithm/actor_critic.py b/slm_lab/agent/algorithm/actor_critic.py index 2d7e7a326..54b8b3862 100644 --- a/slm_lab/agent/algorithm/actor_critic.py +++ b/slm_lab/agent/algorithm/actor_critic.py @@ -146,27 +146,24 @@ def init_nets(self, global_nets=None): if critic_net_spec['use_same_optim']: critic_net_spec = actor_net_spec - if global_nets is None: - in_dim = self.body.state_dim - out_dim = net_util.get_out_dim(self.body, add_critic=self.shared) - # main actor network, may contain out_dim self.shared == True - NetClass = getattr(net, actor_net_spec['type']) - self.net = NetClass(actor_net_spec, in_dim, out_dim) - self.net_names = ['net'] - if not self.shared: # add separate network for critic - critic_out_dim = 1 - CriticNetClass = getattr(net, critic_net_spec['type']) - self.critic = CriticNetClass(critic_net_spec, in_dim, critic_out_dim) - self.net_names.append('critic') - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body, add_critic=self.shared) + # main actor network; out_dim includes the critic output if self.shared == True + NetClass = getattr(net, actor_net_spec['type']) + self.net = NetClass(actor_net_spec, in_dim, out_dim) + self.net_names = ['net'] + if not self.shared: # add separate network for critic + critic_out_dim = 1 + CriticNetClass = getattr(net, critic_net_spec['type']) + self.critic_net = CriticNetClass(critic_net_spec, in_dim, critic_out_dim) + self.net_names.append('critic_net') # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) if not self.shared: - self.critic_optim = net_util.get_optim(self.critic, self.critic.optim_spec) - self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic.lr_scheduler_spec) + self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec) + self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() @lab_api @@ -188,7 +185,7 @@ def calc_pdparam(self, x, net=None): def calc_v(self, x, net=None, use_cache=True): ''' - Forward-pass to calculate the predicted state-value from critic. + Forward-pass to calculate the predicted state-value from critic_net. 
''' if self.shared: # output: policy, value if use_cache: # uses cache from calc_pdparam to prevent double-pass @@ -197,7 +194,7 @@ def calc_v(self, x, net=None, use_cache=True): net = self.net if net is None else net v_pred = net(x)[-1].view(-1) else: - net = self.critic if net is None else net + net = self.critic_net if net is None else net v_pred = net(x).view(-1) return v_pred @@ -294,10 +291,10 @@ def train(self): val_loss = self.calc_val_loss(v_preds, v_targets) # from critic if self.shared: # shared network loss = policy_loss + val_loss - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) else: - self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock) - self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock) + self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) + self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net) loss = policy_loss + val_loss # reset self.to_train = 0 diff --git a/slm_lab/agent/algorithm/base.py b/slm_lab/agent/algorithm/base.py index 7e42bae6a..10c15293c 100644 --- a/slm_lab/agent/algorithm/base.py +++ b/slm_lab/agent/algorithm/base.py @@ -48,6 +48,8 @@ def post_init_nets(self): Call at the end of init_nets() after setting self.net_names ''' assert hasattr(self, 'net_names') + for net_name in self.net_names: + assert net_name.endswith('net'), f'Naming convention: net_name must end with "net"; got {net_name}' if util.in_eval_lab_modes(): self.load() logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}') diff --git a/slm_lab/agent/algorithm/dqn.py b/slm_lab/agent/algorithm/dqn.py index 694e9c31e..79fa2bc72 100644 --- a/slm_lab/agent/algorithm/dqn.py +++ b/slm_lab/agent/algorithm/dqn.py @@ -79,18 +79,15 @@ def init_nets(self, global_nets=None): '''Initialize the neural network used to learn the Q function from the spec''' if self.algorithm_spec['name'] == 'VanillaDQN': assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for VanillaDQN; use DQN.' 
- if global_nets is None: - in_dim = self.body.state_dim - out_dim = net_util.get_out_dim(self.body) - NetClass = getattr(net, self.net_spec['type']) - self.net = NetClass(self.net_spec, in_dim, out_dim) - self.net_names = ['net'] - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dim, out_dim) + self.net_names = ['net'] # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() def calc_q_loss(self, batch): @@ -145,7 +142,7 @@ def train(self): batch = self.sample() for _ in range(self.training_batch_epoch): loss = self.calc_q_loss(batch) - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) total_loss += loss loss = total_loss / (self.training_epoch * self.training_batch_epoch) # reset @@ -182,19 +179,16 @@ def init_nets(self, global_nets=None): '''Initialize networks''' if self.algorithm_spec['name'] == 'DQNBase': assert all(k not in self.net_spec for k in ['update_type', 'update_frequency', 'polyak_coef']), 'Network update not available for DQNBase; use DQN.' - if global_nets is None: - in_dim = self.body.state_dim - out_dim = net_util.get_out_dim(self.body) - NetClass = getattr(net, self.net_spec['type']) - self.net = NetClass(self.net_spec, in_dim, out_dim) - self.target_net = NetClass(self.net_spec, in_dim, out_dim) - self.net_names = ['net', 'target_net'] - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dim, out_dim) + self.target_net = NetClass(self.net_spec, in_dim, out_dim) + self.net_names = ['net', 'target_net'] # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() self.online_net = self.target_net self.eval_net = self.target_net diff --git a/slm_lab/agent/algorithm/hydra_dqn.py b/slm_lab/agent/algorithm/hydra_dqn.py index 71b967b35..c34e1a90f 100644 --- a/slm_lab/agent/algorithm/hydra_dqn.py +++ b/slm_lab/agent/algorithm/hydra_dqn.py @@ -19,17 +19,14 @@ def init_nets(self, global_nets=None): # NOTE: Separate init from MultitaskDQN despite similarities so that this implementation can support arbitrary sized state and action heads (e.g. 
multiple layers) self.state_dims = in_dims = [body.state_dim for body in self.agent.nanflat_body_a] self.action_dims = out_dims = [body.action_dim for body in self.agent.nanflat_body_a] - if global_nets is None: - NetClass = getattr(net, self.net_spec['type']) - self.net = NetClass(self.net_spec, in_dims, out_dims) - self.target_net = NetClass(self.net_spec, in_dims, out_dims) - self.net_names = ['net', 'target_net'] - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dims, out_dims) + self.target_net = NetClass(self.net_spec, in_dims, out_dims) + self.net_names = ['net', 'target_net'] # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() self.online_net = self.target_net self.eval_net = self.target_net @@ -100,7 +97,7 @@ def space_train(self): batch = self.space_sample() for _ in range(self.training_batch_epoch): loss = self.calc_q_loss(batch) - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) total_loss += loss loss = total_loss / (self.training_epoch * self.training_batch_epoch) # reset diff --git a/slm_lab/agent/algorithm/ppo.py b/slm_lab/agent/algorithm/ppo.py index fc9050542..38555bd6f 100644 --- a/slm_lab/agent/algorithm/ppo.py +++ b/slm_lab/agent/algorithm/ppo.py @@ -189,10 +189,10 @@ def train(self): val_loss = self.calc_val_loss(v_preds, v_targets) # from critic if self.shared: # shared network loss = policy_loss + val_loss - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) else: - self.net.training_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock) - self.critic.training_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock) + self.net.train_step(policy_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) + self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, lr_clock=clock, global_net=self.global_critic_net) loss = policy_loss + val_loss total_loss += loss loss = total_loss / self.training_epoch / len(minibatches) diff --git a/slm_lab/agent/algorithm/reinforce.py b/slm_lab/agent/algorithm/reinforce.py index 356e10ae8..779a9a1b3 100644 --- a/slm_lab/agent/algorithm/reinforce.py +++ b/slm_lab/agent/algorithm/reinforce.py @@ -79,18 +79,15 @@ def init_nets(self, global_nets=None): Networks for continuous action spaces have two heads and return two values, the first is a tensor containing the mean of the action policy, the second is a tensor containing the std deviation of the action policy. The distribution is assumed to be a Gaussian (Normal) distribution. 
Networks for discrete action spaces have a single head and return the logits for a categorical probability distribution over the discrete actions ''' - if global_nets is None: - in_dim = self.body.state_dim - out_dim = net_util.get_out_dim(self.body) - NetClass = getattr(net, self.net_spec['type']) - self.net = NetClass(self.net_spec, in_dim, out_dim) - self.net_names = ['net'] - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dim, out_dim) + self.net_names = ['net'] # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() @lab_api @@ -161,7 +158,7 @@ def train(self): pdparams = self.calc_pdparam_batch(batch) advs = self.calc_ret_advs(batch) loss = self.calc_policy_loss(batch, pdparams, advs) - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) # reset self.to_train = 0 logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}') diff --git a/slm_lab/agent/algorithm/sarsa.py b/slm_lab/agent/algorithm/sarsa.py index 98e9505f6..5ff9c8ba3 100644 --- a/slm_lab/agent/algorithm/sarsa.py +++ b/slm_lab/agent/algorithm/sarsa.py @@ -72,18 +72,15 @@ def init_nets(self, global_nets=None): '''Initialize the neural network used to learn the Q function from the spec''' if 'Recurrent' in self.net_spec['type']: self.net_spec.update(seq_len=self.net_spec['seq_len']) - if global_nets is None: - in_dim = self.body.state_dim - out_dim = net_util.get_out_dim(self.body) - NetClass = getattr(net, self.net_spec['type']) - self.net = NetClass(self.net_spec, in_dim, out_dim) - self.net_names = ['net'] - else: - util.set_attr(self, global_nets) - self.net_names = list(global_nets.keys()) + in_dim = self.body.state_dim + out_dim = net_util.get_out_dim(self.body) + NetClass = getattr(net, self.net_spec['type']) + self.net = NetClass(self.net_spec, in_dim, out_dim) + self.net_names = ['net'] # init net optimizer and its lr scheduler self.optim = net_util.get_optim(self.net, self.net.optim_spec) self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec) + net_util.set_global_nets(self, global_nets) self.post_init_nets() @lab_api @@ -149,7 +146,7 @@ def train(self): if self.to_train == 1: batch = self.sample() loss = self.calc_q_loss(batch) - self.net.training_step(loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) # reset self.to_train = 0 logger.debug(f'Trained {self.name} at epi: {clock.epi}, total_t: {clock.total_t}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}') diff --git a/slm_lab/agent/algorithm/sil.py b/slm_lab/agent/algorithm/sil.py index fd7903996..72cf80168 100644 --- a/slm_lab/agent/algorithm/sil.py +++ b/slm_lab/agent/algorithm/sil.py @@ -147,7 +147,7 @@ def train(self): pdparams, _v_preds = self.calc_pdparam_v(batch) sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch, pdparams) sil_loss = sil_policy_loss + 
sil_val_loss - self.net.training_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock) + self.net.train_step(sil_loss, self.optim, self.lr_scheduler, lr_clock=clock, global_net=self.global_net) total_sil_loss += sil_loss sil_loss = total_sil_loss / self.training_epoch loss = super_loss + sil_loss diff --git a/slm_lab/agent/net/conv.py b/slm_lab/agent/net/conv.py index 7bf6a49ec..a5c239eb5 100644 --- a/slm_lab/agent/net/conv.py +++ b/slm_lab/agent/net/conv.py @@ -189,15 +189,18 @@ def forward(self, x): else: return self.model_tail(x) - @net_util.dev_check_training_step - def training_step(self, loss, optim, lr_scheduler, lr_clock=None): - '''Takes a single training step: one forward and one backwards pass''' + @net_util.dev_check_train_step + def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None): lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t')) optim.zero_grad() loss.backward() if self.clip_grad_val is not None: nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val) + if global_net is not None: + net_util.push_global_grads(self, global_net) optim.step() + if global_net is not None: + net_util.copy(global_net, self) lr_clock.tick('grad_step') return loss diff --git a/slm_lab/agent/net/mlp.py b/slm_lab/agent/net/mlp.py index 75f065463..8cf249048 100644 --- a/slm_lab/agent/net/mlp.py +++ b/slm_lab/agent/net/mlp.py @@ -120,17 +120,19 @@ def forward(self, x): else: return self.model_tail(x) - @net_util.dev_check_training_step - def training_step(self, loss, optim, lr_scheduler, lr_clock=None): - ''' - Train a network given a computed loss - ''' + @net_util.dev_check_train_step + def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None): + '''Train a network given a computed loss''' lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t')) optim.zero_grad() loss.backward() if self.clip_grad_val is not None: nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val) + if global_net is not None: + net_util.push_global_grads(self, global_net) optim.step() + if global_net is not None: + net_util.copy(global_net, self) lr_clock.tick('grad_step') return loss @@ -286,14 +288,18 @@ def forward(self, xs): outs.append(model_tail(body_x)) return outs - @net_util.dev_check_training_step - def training_step(self, loss, optim, lr_scheduler, lr_clock=None): + @net_util.dev_check_train_step + def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None): lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t')) optim.zero_grad() loss.backward() if self.clip_grad_val is not None: nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val) + if global_net is not None: + net_util.push_global_grads(self, global_net) optim.step() + if global_net is not None: + net_util.copy(global_net, self) lr_clock.tick('grad_step') return loss diff --git a/slm_lab/agent/net/net_util.py b/slm_lab/agent/net/net_util.py index c9b18a60a..1764cf67c 100644 --- a/slm_lab/agent/net/net_util.py +++ b/slm_lab/agent/net/net_util.py @@ -1,6 +1,5 @@ from functools import partial, wraps -from slm_lab import ROOT_DIR -from slm_lab.lib import logger, util +from slm_lab.lib import logger, optimizer, util import os import pydash as ps import torch @@ -8,6 +7,9 @@ logger = logger.get_logger(__name__) +# register custom torch.optim +setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam) + class NoOpLRScheduler: '''Symbolic LRScheduler class for API consistency''' @@ -179,7 +181,8 @@ def save_algorithm(algorithm, ckpt=None): net = getattr(algorithm, net_name) 
model_path = f'{prepath}_{net_name}_model.pth' save(net, model_path) - optim = getattr(algorithm, net_name.replace('net', 'optim'), None) + optim_name = net_name.replace('net', 'optim') + optim = getattr(algorithm, optim_name, None) if optim is not None: # only trainable net has optim optim_path = f'{prepath}_{net_name}_optim.pth' save(optim, optim_path) @@ -206,7 +209,8 @@ def load_algorithm(algorithm): net = getattr(algorithm, net_name) model_path = f'{prepath}_{net_name}_model.pth' load(net, model_path) - optim = getattr(algorithm, net_name.replace('net', 'optim'), None) + optim_name = net_name.replace('net', 'optim') + optim = getattr(algorithm, optim_name, None) if optim is not None: # only trainable net has optim optim_path = f'{prepath}_{net_name}_optim.pth' load(optim, optim_path) @@ -226,31 +230,31 @@ def polyak_update(src_net, tar_net, old_ratio=0.5): tar_param.data.copy_(old_ratio * src_param.data + (1.0 - old_ratio) * tar_param.data) -def to_check_training_step(): +def to_check_train_step(): '''Condition for running assert_trained''' return os.environ.get('PY_ENV') == 'test' or util.get_lab_mode() == 'dev' -def dev_check_training_step(fn): +def dev_check_train_step(fn): ''' - Decorator to check if net.training_step actually updates the network weights properly - Triggers only if to_check_training_step is True (dev/test mode) + Decorator to check if net.train_step actually updates the network weights properly + Triggers only if to_check_train_step is True (dev/test mode) @example - @net_util.dev_check_training_step - def training_step(self, ...): + @net_util.dev_check_train_step + def train_step(self, ...): ... ''' @wraps(fn) def check_fn(*args, **kwargs): - if not to_check_training_step(): + if not to_check_train_step(): return fn(*args, **kwargs) net = args[0] # first arg self # get pre-update parameters to compare pre_params = [param.clone() for param in net.parameters()] - # run training_step, get loss + # run train_step, get loss loss = fn(*args, **kwargs) assert not torch.isnan(loss).any(), loss @@ -264,8 +268,8 @@ def check_fn(*args, **kwargs): else: # check parameter updates try: - assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params)), f'Model parameter is not updated in training_step(), check if your tensor is detached from graph. Loss: {loss:g}' - logger.info(f'Model parameter is updated in training_step(). Loss: {loss: g}') + assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_params, post_params)), f'Model parameter is not updated in train_step(); check if your tensor is detached from the graph. Loss: {loss:g}' + logger.info(f'Model parameter is updated in train_step(). Loss: {loss: g}') except Exception as e: logger.error(e) if os.environ.get('PY_ENV') == 'test': @@ -296,3 +300,54 @@ def get_grad_norms(algorithm): if net.grad_norms is not None: grad_norms.extend(net.grad_norms) return grad_norms + + +def init_global_nets(algorithm): + ''' + Initialize global_nets for Hogwild using an identical instance of an algorithm from an isolated Session + In spec.meta.distributed, specify either: + - 'shared': global network parameters are shared all the time. In this mode, the algorithm's local network is replaced directly by the global_net via the identical attribute name + - 'synced': global network parameters are periodically synced to the local network after each gradient push. 
In this mode, the algorithm keeps a separate reference to `global_{net}` for each of its networks + ''' + dist_mode = algorithm.agent.spec['meta']['distributed'] + assert dist_mode in ('shared', 'synced'), f'Unrecognized distributed mode: {dist_mode}' + global_nets = {} + for net_name in algorithm.net_names: + optim_name = net_name.replace('net', 'optim') + if not hasattr(algorithm, optim_name): # only for trainable network, i.e. has an optim + continue + g_net = getattr(algorithm, net_name) + g_net.share_memory() # make net global + if dist_mode == 'shared': # use the same name to override the local net + global_nets[net_name] = g_net + else: # keep a separate reference for syncing + global_nets[f'global_{net_name}'] = g_net + # if optim is Global, set to override the local optim and its scheduler + optim = getattr(algorithm, optim_name) + if 'Global' in util.get_class_name(optim): + optim.share_memory() # make optim global + global_nets[optim_name] = optim + lr_scheduler_name = net_name.replace('net', 'lr_scheduler') + lr_scheduler = getattr(algorithm, lr_scheduler_name) + global_nets[lr_scheduler_name] = lr_scheduler + logger.info(f'Initialized global_nets attr {list(global_nets.keys())} for Hogwild') + return global_nets + + +def set_global_nets(algorithm, global_nets): + '''For Hogwild, set the attrs built in init_global_nets above. Use in algorithm init.''' + # set attr first so algorithm always has self.global_{net} to pass into train_step + for net_name in algorithm.net_names: + setattr(algorithm, f'global_{net_name}', None) + # set attr created in init_global_nets + if global_nets is not None: + util.set_attr(algorithm, global_nets) + logger.info(f'Set global_nets attr {list(global_nets.keys())} for Hogwild') + + +def push_global_grads(net, global_net): + '''Push gradients to global_net; call inside train_step between loss.backward() and optim.step()''' + for param, global_param in zip(net.parameters(), global_net.parameters()): + if global_param.grad is not None: + return # quick skip: grads already aliased to the global net + global_param._grad = param.grad diff --git a/slm_lab/agent/net/recurrent.py b/slm_lab/agent/net/recurrent.py index 00337a181..6521d18bd 100644 --- a/slm_lab/agent/net/recurrent.py +++ b/slm_lab/agent/net/recurrent.py @@ -169,13 +169,17 @@ def forward(self, x): else: return self.model_tail(hid_x) - @net_util.dev_check_training_step - def training_step(self, loss, optim, lr_scheduler, lr_clock=None): + @net_util.dev_check_train_step + def train_step(self, loss, optim, lr_scheduler, lr_clock=None, global_net=None): lr_scheduler.step(epoch=ps.get(lr_clock, 'total_t')) optim.zero_grad() loss.backward() if self.clip_grad_val is not None: nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val) + if global_net is not None: + net_util.push_global_grads(self, global_net) optim.step() + if global_net is not None: + net_util.copy(global_net, self) lr_clock.tick('grad_step') return loss diff --git a/slm_lab/env/base.py b/slm_lab/env/base.py index c8034fd29..7ebe5bd07 100644 --- a/slm_lab/env/base.py +++ b/slm_lab/env/base.py @@ -7,7 +7,6 @@ import time ENV_DATA_NAMES = ['state', 'reward', 'done'] -NUM_EVAL_EPI = 100 # set the number of episodes to eval a model ckpt logger = logger.get_logger(__name__) @@ -115,21 +114,18 @@ def __init__(self, spec, e=None, env_space=None): 'max_tick', 'reward_scale', ]) - # infer if using RNN seq_len = ps.get(spec, 'agent.0.net.seq_len') - if seq_len is not None: + if seq_len is not None: # infer if using RNN self.frame_op = 'stack' self.frame_op_len = seq_len - if util.get_lab_mode() == 
'eval': - self.num_envs = None # use singleton for eval - # override for eval, offset so epi is 0 - (num_eval_epi - 1) - self.max_tick = NUM_EVAL_EPI - 1 + if util.get_lab_mode() == 'eval': # use singleton for eval + self.num_envs = 1 self.max_tick_unit = 'epi' - if self.num_envs == 1: # guard: if 1, dont used venvs at all - self.num_envs = None - self.is_venv = self.num_envs is not None + if spec['meta']['distributed'] != False: # divide max_tick for distributed + self.max_tick = int(self.max_tick / spec['meta']['max_session']) + self.is_venv = (self.num_envs is not None and self.num_envs > 1) if self.is_venv: - assert self.log_frequency is not None, f'Specify log_frequency when using num_envs' + assert self.log_frequency is not None, f'Specify log_frequency when using venv' self.clock_speed = 1 * (self.num_envs or 1) # tick with a multiple of num_envs to properly count frames self.clock = Clock(self.max_tick, self.max_tick_unit, self.clock_speed) self.to_render = util.to_render() diff --git a/slm_lab/env/vec_env.py b/slm_lab/env/vec_env.py index 9b10e84a2..6619e2b69 100644 --- a/slm_lab/env/vec_env.py +++ b/slm_lab/env/vec_env.py @@ -213,7 +213,7 @@ def step_wait(self): def close_extras(self): ''' - Clean up the extra resources, beyond what's in this base class. + Clean up the extra resources, beyond what's in this base class. Only runs when not self.closed. ''' pass diff --git a/slm_lab/experiment/analysis.py b/slm_lab/experiment/analysis.py index 7d1e26470..d4edef27f 100644 --- a/slm_lab/experiment/analysis.py +++ b/slm_lab/experiment/analysis.py @@ -290,11 +290,11 @@ def plot_session(session_spec, session_data): aeb_df = session_data[(a, e, b)] aeb_df.fillna(0, inplace=True) # for saving plot, cant have nan fig_1 = viz.plot_line(aeb_df, 'reward_ma', max_tick_unit, legend_name=aeb_str, draw=False, trace_kwargs={'legendgroup': aeb_str, 'line': {'color': palette[idx]}}) - fig.append_trace(fig_1.data[0], 1, 1) + fig.add_trace(fig_1.data[0], 1, 1) fig_2 = viz.plot_line(aeb_df, ['loss'], max_tick_unit, y2_col=['explore_var'], trace_kwargs={'legendgroup': aeb_str, 'showlegend': False, 'line': {'color': palette[idx]}}, draw=False) - fig.append_trace(fig_2.data[0], 2, 1) - fig.append_trace(fig_2.data[1], 3, 1) + fig.add_trace(fig_2.data[0], 2, 1) + fig.add_trace(fig_2.data[1], 3, 1) fig.layout['xaxis1'].update(title=max_tick_unit, zerolinewidth=1) fig.layout['yaxis1'].update(fig_1.layout['yaxis']) @@ -426,7 +426,7 @@ def plot_experiment(experiment_spec, experiment_df): 'colorscale': 'YlGnBu', 'reversescale': True }, ) - fig.append_trace(trace, row_idx + 1, col_idx + 1) + fig.add_trace(trace, row_idx + 1, col_idx + 1) fig.layout[f'xaxis{col_idx+1}'].update(title='
'.join(ps.chunk(x, 20)), zerolinewidth=1, categoryarray=sorted(guard_cat_x.unique())) fig.layout[f'yaxis{row_idx+1}'].update(title=y, rangemode='tozero') fig.layout.update(title=f'experiment graph: {experiment_spec["name"]}', width=max(600, len(x_cols) * 300), height=700) diff --git a/slm_lab/experiment/control.py b/slm_lab/experiment/control.py index 4998bfed9..c15a6877f 100644 --- a/slm_lab/experiment/control.py +++ b/slm_lab/experiment/control.py @@ -5,6 +5,7 @@ from copy import deepcopy from importlib import reload from slm_lab.agent import AgentSpace, Agent +from slm_lab.agent.net import net_util from slm_lab.env import EnvSpace, make_env from slm_lab.experiment import analysis, retro_analysis, search from slm_lab.experiment.monitor import AEBSpace, Body, enable_aeb_space @@ -254,23 +255,14 @@ def run_sessions(self): break return session_datas - def make_global_nets(self, agent): - global_nets = {} - for net_name in agent.algorithm.net_names: - g_net = getattr(agent.algorithm, net_name) - g_net.share_memory() # make net global - # TODO also create shared optimizer here - global_nets[net_name] = g_net - return global_nets - def init_global_nets(self): session = self.SessionClass(deepcopy(self.spec)) if self.is_singleton: session.env.close() # safety - global_nets = self.make_global_nets(session.agent) + global_nets = net_util.init_global_nets(session.agent.algorithm) else: session.env_space.close() # safety - global_nets = [self.make_global_nets(agent) for agent in session.agent_space.agents] + global_nets = [net_util.init_global_nets(agent.algorithm) for agent in session.agent_space.agents] return global_nets def run_distributed_sessions(self): @@ -283,10 +275,10 @@ def close(self): logger.info('Trial done and closed.') def run(self): - if self.spec['meta'].get('distributed'): - session_datas = self.run_distributed_sessions() - else: + if self.spec['meta'].get('distributed') == False: session_datas = self.run_sessions() + else: + session_datas = self.run_distributed_sessions() self.session_data_dict = {data.index[0]: data for data in session_datas} self.data = analysis.analyze_trial(self) self.close() diff --git a/slm_lab/experiment/monitor.py b/slm_lab/experiment/monitor.py index 59bb63b15..13ec3ea4a 100644 --- a/slm_lab/experiment/monitor.py +++ b/slm_lab/experiment/monitor.py @@ -148,7 +148,7 @@ def calc_df_row(self, env): fps = 0 if wall_t == 0 else total_t / wall_t # update debugging variables - if net_util.to_check_training_step(): + if net_util.to_check_train_step(): grad_norms = net_util.get_grad_norms(self.agent.algorithm) self.mean_grad_norm = np.nan if ps.is_empty(grad_norms) else np.mean(grad_norms) @@ -197,9 +197,9 @@ def get_mean_lr(self): if not hasattr(self.agent.algorithm, 'net_names'): return np.nan lrs = [] - for k, attr in self.agent.algorithm.__dict__.items(): - if k.endswith('lr_scheduler'): - lrs.append(attr.get_lr()) + for attr, obj in self.agent.algorithm.__dict__.items(): + if attr.endswith('lr_scheduler'): + lrs.append(obj.get_lr()) return np.mean(lrs) def get_log_prefix(self): diff --git a/slm_lab/experiment/search.py b/slm_lab/experiment/search.py index bf6b4ca3a..b43a95fdd 100644 --- a/slm_lab/experiment/search.py +++ b/slm_lab/experiment/search.py @@ -172,7 +172,7 @@ def run(self): run_trial = create_remote_fn(self.experiment) meta_spec = self.experiment.spec['meta'] logging.getLogger('ray').propagate = True - ray.init(**meta_spec.get('resources', {})) + ray.init(**meta_spec.get('search_resources', {})) register_ray_serializer() max_trial = 
meta_spec['max_trial'] trial_data_dict = {} @@ -250,7 +250,7 @@ def run(self): run_trial = create_remote_fn(self.experiment) meta_spec = self.experiment.spec['meta'] logging.getLogger('ray').propagate = True - ray.init(**meta_spec.get('resources', {})) + ray.init(**meta_spec.get('search_resources', {})) register_ray_serializer() max_generation = meta_spec['max_generation'] pop_size = meta_spec['max_trial'] or calc_population_size(self.experiment) diff --git a/slm_lab/lib/logger.py b/slm_lab/lib/logger.py index bec5fe820..ca312b556 100644 --- a/slm_lab/lib/logger.py +++ b/slm_lab/lib/logger.py @@ -67,7 +67,7 @@ def info(msg, *args, **kwargs): def warn(msg, *args, **kwargs): - return lab_logger.warn(msg, *args, **kwargs) + return lab_logger.warning(msg, *args, **kwargs) def get_logger(__name__): diff --git a/slm_lab/lib/optimizer.py b/slm_lab/lib/optimizer.py new file mode 100644 index 000000000..932b85d4a --- /dev/null +++ b/slm_lab/lib/optimizer.py @@ -0,0 +1,101 @@ +import math +import torch + + +class GlobalAdam(torch.optim.Adam): + ''' + Global Adam algorithm with shared states for Hogwild. + Adapted from https://github.com/ikostrikov/pytorch-a3c/blob/master/my_optim.py (MIT) + ''' + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + super().__init__(params, lr, betas, eps, weight_decay) + + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + state['step'] = torch.zeros(1) + state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() + state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() + + def share_memory(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + state['step'].share_memory_() + state['exp_avg'].share_memory_() + state['exp_avg_sq'].share_memory_() + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) + bias_correction1 = 1 - beta1 ** state['step'].item() + bias_correction2 = 1 - beta2 ** state['step'].item() + step_size = group['lr'] * math.sqrt( + bias_correction2) / bias_correction1 + p.data.addcdiv_(-step_size, exp_avg, denom) + return loss + + +class GlobalRMSprop(torch.optim.RMSprop): + ''' + Global RMSprop algorithm with shared states for Hogwild. 
+ Adapted from https://github.com/jingweiz/pytorch-rl/blob/master/optims/sharedRMSprop.py (MIT) + ''' + + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0): + super().__init__(params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=0, centered=False) + + # State initialisation (must be done before step, else will not be shared between threads) + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + state['step'] = p.data.new().resize_(1).zero_() + state['square_avg'] = p.data.new().resize_as_(p.data).zero_() + + def share_memory(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + state['step'].share_memory_() + state['square_avg'].share_memory_() + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + square_avg = state['square_avg'] + alpha = group['alpha'] + state['step'] += 1 + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) + avg = square_avg.sqrt().add_(group['eps']) + p.data.addcdiv_(-group['lr'], grad, avg) + return loss diff --git a/slm_lab/lib/util.py b/slm_lab/lib/util.py index 1dc02ce7f..8294eb745 100644 --- a/slm_lab/lib/util.py +++ b/slm_lab/lib/util.py @@ -568,10 +568,13 @@ def set_cuda_id(spec): for agent_spec in spec['agent']: if not agent_spec['net'].get('gpu'): return - trial_idx = spec['meta']['trial'] or 0 - session_idx = spec['meta']['session'] or 0 - job_idx = trial_idx * spec['meta']['max_session'] + session_idx - job_idx += spec['meta']['cuda_offset'] + meta_spec = spec['meta'] + trial_idx = meta_spec['trial'] or 0 + session_idx = meta_spec['session'] or 0 + if meta_spec['distributed'] == 'shared': # shared hogwild uses only global networks, offset them to idx 0 + session_idx = 0 + job_idx = trial_idx * meta_spec['max_session'] + session_idx + job_idx += meta_spec['cuda_offset'] device_count = torch.cuda.device_count() cuda_id = None if not device_count else job_idx % device_count diff --git a/slm_lab/spec/base.json b/slm_lab/spec/base.json index 100e70a65..bf587f45c 100644 --- a/slm_lab/spec/base.json +++ b/slm_lab/spec/base.json @@ -31,7 +31,7 @@ "max_session": 1, "max_trial": 1, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 4, "num_gpus": 0 } diff --git a/slm_lab/spec/benchmark/ddqn_lunar.json b/slm_lab/spec/benchmark/ddqn_lunar.json index cf20aaab2..b3da25d31 100644 --- a/slm_lab/spec/benchmark/ddqn_lunar.json +++ b/slm_lab/spec/benchmark/ddqn_lunar.json @@ -72,7 +72,7 @@ "max_session": 4, "max_trial": 62, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 62 } }, diff --git a/slm_lab/spec/benchmark/dqn_lunar.json b/slm_lab/spec/benchmark/dqn_lunar.json index 95dfd3cd9..651bcaecf 100644 --- a/slm_lab/spec/benchmark/dqn_lunar.json +++ b/slm_lab/spec/benchmark/dqn_lunar.json @@ -71,7 +71,7 @@ "max_session": 4, "max_trial": 62, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 62 } }, diff --git a/slm_lab/spec/demo.json b/slm_lab/spec/demo.json index 79e415432..97548f419 100644 --- a/slm_lab/spec/demo.json +++ b/slm_lab/spec/demo.json @@ -65,7 +65,7 @@ "max_trial": 4, "max_session": 1, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 4, "num_gpus": 0 } diff --git 
a/slm_lab/spec/experimental/a2c.json b/slm_lab/spec/experimental/a2c.json index 064c78756..9dc79fe42 100644 --- a/slm_lab/spec/experimental/a2c.json +++ b/slm_lab/spec/experimental/a2c.json @@ -848,7 +848,7 @@ "max_session": 4, "max_trial": 1, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16, } } diff --git a/slm_lab/spec/experimental/a2c/a2c_pong.json b/slm_lab/spec/experimental/a2c/a2c_pong.json new file mode 100644 index 000000000..36076443c --- /dev/null +++ b/slm_lab/spec/experimental/a2c/a2c_pong.json @@ -0,0 +1,84 @@ +{ + "a2c_pong": { + "agent": [{ + "name": "A2C", + "algorithm": { + "name": "ActorCritic", + "action_pdtype": "default", + "action_policy": "default", + "explore_var_spec": null, + "gamma": 0.99, + "lam": null, + "num_step_returns": 5, + "entropy_coef_spec": { + "name": "no_decay", + "start_val": 0.01, + "end_val": 0.01, + "start_step": 0, + "end_step": 0 + }, + "val_loss_coef": 0.5, + "training_frequency": 5, + "normalize_state": false + }, + "memory": { + "name": "OnPolicyBatchReplay" + }, + "net": { + "type": "ConvNet", + "shared": true, + "conv_hid_layers": [ + [32, 8, 4, 0, 1], + [64, 4, 2, 0, 1], + [32, 3, 1, 0, 1] + ], + "fc_hid_layers": [512], + "hid_layers_activation": "relu", + "init_fn": "orthogonal_", + "normalize": true, + "batch_norm": false, + "clip_grad_val": 0.5, + "use_same_optim": false, + "loss_spec": { + "name": "MSELoss" + }, + "actor_optim_spec": { + "name": "RMSprop", + "lr": 7e-4, + "alpha": 0.99, + "eps": 1e-5 + }, + "critic_optim_spec": { + "name": "RMSprop", + "lr": 7e-4, + "alpha": 0.99, + "eps": 1e-5 + }, + "lr_scheduler_spec": null, + "gpu": true + } + }], + "env": [{ + "name": "PongNoFrameskip-v4", + "frame_op": "concat", + "frame_op_len": 4, + "reward_scale": "sign", + "num_envs": 16, + "max_t": null, + "max_tick": 1e7 + }], + "body": { + "product": "outer", + "num": 1, + }, + "meta": { + "distributed": false, + "log_frequency": 50000, + "eval_frequency": 50000, + "max_tick_unit": "total_t", + "max_session": 4, + "max_trial": 1, + "param_spec_process": 4 + } + } +} diff --git a/slm_lab/spec/experimental/a3c/a3c.json b/slm_lab/spec/experimental/a3c/a3c.json index 13a7081e2..40ce15376 100644 --- a/slm_lab/spec/experimental/a3c/a3c.json +++ b/slm_lab/spec/experimental/a3c/a3c.json @@ -59,7 +59,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -143,7 +143,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -227,7 +227,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -315,7 +315,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -403,7 +403,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -487,7 +487,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -571,7 +571,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -659,7 +659,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", 
"max_session": 4, @@ -747,7 +747,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -836,7 +836,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 1, @@ -909,7 +909,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 1, diff --git a/slm_lab/spec/experimental/a3c/a3c_atari.json b/slm_lab/spec/experimental/a3c/a3c_atari.json new file mode 100644 index 000000000..2479f5f4c --- /dev/null +++ b/slm_lab/spec/experimental/a3c/a3c_atari.json @@ -0,0 +1,89 @@ +{ + "a3c_pong": { + "agent": [{ + "name": "A3C", + "algorithm": { + "name": "ActorCritic", + "action_pdtype": "default", + "action_policy": "default", + "explore_var_spec": null, + "gamma": 0.99, + "lam": null, + "num_step_returns": 5, + "entropy_coef_spec": { + "name": "no_decay", + "start_val": 0.01, + "end_val": 0.01, + "start_step": 0, + "end_step": 0 + }, + "val_loss_coef": 0.5, + "training_frequency": 5, + "normalize_state": false + }, + "memory": { + "name": "OnPolicyBatchReplay", + }, + "net": { + "type": "ConvNet", + "shared": true, + "conv_hid_layers": [ + [32, 8, 4, 0, 1], + [64, 4, 2, 0, 1], + [32, 3, 1, 0, 1] + ], + "fc_hid_layers": [512], + "hid_layers_activation": "relu", + "init_fn": "orthogonal_", + "normalize": true, + "batch_norm": false, + "clip_grad_val": 0.5, + "use_same_optim": false, + "loss_spec": { + "name": "MSELoss" + }, + "actor_optim_spec": { + "name": "RMSprop", + "lr": 7e-4, + "alpha": 0.99, + "eps": 1e-5 + }, + "critic_optim_spec": { + "name": "RMSprop", + "lr": 7e-4, + "alpha": 0.99, + "eps": 1e-5 + }, + "lr_scheduler_spec": null, + "gpu": false + } + }], + "env": [{ + "name": "${env}", + "frame_op": "concat", + "frame_op_len": 4, + "reward_scale": "sign", + "num_envs": 1, + "max_t": null, + "max_tick": 1e7 + }], + "body": { + "product": "outer", + "num": 1 + }, + "meta": { + "distributed": "synced", + "log_frequency": 50000, + "eval_frequency": 50000, + "max_tick_unit": "total_t", + "max_session": 16, + "max_trial": 1, + "param_spec_process": 4 + }, + "spec_params": { + "env": [ + "BeamRiderNoFrameskip-v4", "BreakoutNoFrameskip-v4", "EnduroNoFrameskip-v4", "MsPacmanNoFrameskip-v4", "PongNoFrameskip-v4", "QbertNoFrameskip-v4", "SeaquestNoFrameskip-v4", "SpaceInvadersNoFrameskip-v4" + ] + } + } +} diff --git a/slm_lab/spec/experimental/a3c/a3c_gae_atari.json b/slm_lab/spec/experimental/a3c/a3c_gae_atari.json index 672a49b80..3ff0d84df 100644 --- a/slm_lab/spec/experimental/a3c/a3c_gae_atari.json +++ b/slm_lab/spec/experimental/a3c/a3c_gae_atari.json @@ -1,7 +1,7 @@ { "a3c_gae_atari": { "agent": [{ - "name": "A2C", + "name": "A3C", "algorithm": { "name": "ActorCritic", "action_pdtype": "default", @@ -72,7 +72,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "log_frequency": 50000, "eval_frequency": 50000, "max_tick_unit": "total_t", diff --git a/slm_lab/spec/experimental/a3c/a3c_gae_pong.json b/slm_lab/spec/experimental/a3c/a3c_gae_pong.json index f6605aeb4..927816795 100644 --- a/slm_lab/spec/experimental/a3c/a3c_gae_pong.json +++ b/slm_lab/spec/experimental/a3c/a3c_gae_pong.json @@ -1,7 +1,7 @@ { "a3c_gae_pong": { "agent": [{ - "name": "A2C", + "name": "A3C", "algorithm": { "name": "ActorCritic", "action_pdtype": "default", @@ -72,7 +72,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", 
"log_frequency": 50000, "eval_frequency": 50000, "max_tick_unit": "total_t", diff --git a/slm_lab/spec/experimental/a3c/a3c_pong.json b/slm_lab/spec/experimental/a3c/a3c_pong.json index 01fca7e29..8d9369beb 100644 --- a/slm_lab/spec/experimental/a3c/a3c_pong.json +++ b/slm_lab/spec/experimental/a3c/a3c_pong.json @@ -1,7 +1,7 @@ { "a3c_pong": { "agent": [{ - "name": "A2C", + "name": "A3C", "algorithm": { "name": "ActorCritic", "action_pdtype": "default", @@ -43,16 +43,12 @@ "name": "MSELoss" }, "actor_optim_spec": { - "name": "RMSprop", - "lr": 7e-4, - "alpha": 0.99, - "eps": 1e-5 + "name": "GlobalAdam", + "lr": 1e-4 }, "critic_optim_spec": { - "name": "RMSprop", - "lr": 7e-4, - "alpha": 0.99, - "eps": 1e-5 + "name": "GlobalAdam", + "lr": 1e-4 }, "lr_scheduler_spec": null, "gpu": false @@ -63,7 +59,7 @@ "frame_op": "concat", "frame_op_len": 4, "reward_scale": "sign", - "num_envs": 1, + "num_envs": 8, "max_t": null, "max_tick": 1e7 }], @@ -72,8 +68,8 @@ "num": 1 }, "meta": { - "distributed": true, - "log_frequency": 50000, + "distributed": "synced", + "log_frequency": 10000, "eval_frequency": 50000, "max_tick_unit": "total_t", "max_session": 16, diff --git a/slm_lab/spec/experimental/cartpole.json b/slm_lab/spec/experimental/cartpole.json index 9b474109e..66e504da3 100644 --- a/slm_lab/spec/experimental/cartpole.json +++ b/slm_lab/spec/experimental/cartpole.json @@ -57,7 +57,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -140,7 +140,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -227,7 +227,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -314,7 +314,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -404,7 +404,7 @@ "max_session": 4, "max_trial": 23, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -498,7 +498,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -588,7 +588,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -676,7 +676,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -761,13 +761,13 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "total_t", "max_session": 4, "max_trial": 23, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -858,7 +858,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -954,7 +954,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1056,7 +1056,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1153,7 +1153,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1236,7 +1236,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1317,7 +1317,7 @@ "max_session": 4, "max_trial": 
95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1405,7 +1405,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -1490,13 +1490,13 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "total_t", "max_session": 4, "max_trial": 23, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -1584,7 +1584,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -1678,7 +1678,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -1772,7 +1772,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -1864,7 +1864,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -1956,7 +1956,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -2050,7 +2050,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -2144,7 +2144,7 @@ "max_session": 4, "max_trial": 64, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16 } }, @@ -2235,7 +2235,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, diff --git a/slm_lab/spec/experimental/dqn.json b/slm_lab/spec/experimental/dqn.json index 7620d7528..49ba8ccd7 100644 --- a/slm_lab/spec/experimental/dqn.json +++ b/slm_lab/spec/experimental/dqn.json @@ -583,7 +583,7 @@ "max_session": 1, "max_trial": 16, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 16, } }, diff --git a/slm_lab/spec/experimental/dqn/lunar_dqn.json b/slm_lab/spec/experimental/dqn/lunar_dqn.json index 616cd4abe..a5465677d 100644 --- a/slm_lab/spec/experimental/dqn/lunar_dqn.json +++ b/slm_lab/spec/experimental/dqn/lunar_dqn.json @@ -66,7 +66,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -163,7 +163,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -260,7 +260,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -357,7 +357,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -454,7 +454,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -551,7 +551,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -648,7 +648,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -745,7 +745,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -846,7 +846,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -943,7 +943,7 @@ 
"max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, diff --git a/slm_lab/spec/experimental/misc/gridworld.json b/slm_lab/spec/experimental/misc/gridworld.json index a03b37444..49ead2c2b 100644 --- a/slm_lab/spec/experimental/misc/gridworld.json +++ b/slm_lab/spec/experimental/misc/gridworld.json @@ -59,7 +59,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -148,7 +148,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -233,7 +233,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -322,7 +322,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -415,7 +415,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -512,7 +512,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -607,7 +607,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -704,7 +704,7 @@ "max_session": 4, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, diff --git a/slm_lab/spec/experimental/misc/lunar_pg.json b/slm_lab/spec/experimental/misc/lunar_pg.json index ba67bcea8..848c976d0 100644 --- a/slm_lab/spec/experimental/misc/lunar_pg.json +++ b/slm_lab/spec/experimental/misc/lunar_pg.json @@ -57,7 +57,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -150,7 +150,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -249,7 +249,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95, } }, @@ -350,7 +350,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -449,7 +449,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -551,7 +551,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -656,7 +656,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -756,7 +756,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -848,7 +848,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -942,7 +942,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1040,7 +1040,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1129,7 +1129,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1223,7 +1223,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - 
"resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1321,7 +1321,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1415,7 +1415,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1512,7 +1512,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1608,7 +1608,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1704,7 +1704,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1807,7 +1807,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -1908,7 +1908,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -2009,7 +2009,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -2107,7 +2107,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -2216,7 +2216,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, @@ -2323,7 +2323,7 @@ "max_session": 2, "max_trial": 95, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 91, } }, diff --git a/slm_lab/spec/experimental/misc/mountain_car.json b/slm_lab/spec/experimental/misc/mountain_car.json index d3529a20f..3139e5f54 100644 --- a/slm_lab/spec/experimental/misc/mountain_car.json +++ b/slm_lab/spec/experimental/misc/mountain_car.json @@ -65,7 +65,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -166,7 +166,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -261,7 +261,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -358,7 +358,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -454,7 +454,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -549,7 +549,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -643,7 +643,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -738,7 +738,7 @@ "max_session": 4, "max_trial": 200, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, diff --git a/slm_lab/spec/experimental/misc/pendulum.json b/slm_lab/spec/experimental/misc/pendulum.json index 8e9b5b0d4..39b83a2db 100644 --- a/slm_lab/spec/experimental/misc/pendulum.json +++ b/slm_lab/spec/experimental/misc/pendulum.json @@ -65,7 +65,7 @@ "max_session": 4, "max_trial": 190, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -162,7 +162,7 @@ "max_session": 4, "max_trial": 190, "search": "RandomSearch", - "resources": { + 
"search_resources": { "num_cpus": 95 } }, @@ -256,7 +256,7 @@ "max_session": 4, "max_trial": 190, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -353,7 +353,7 @@ "max_session": 4, "max_trial": 190, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, @@ -454,7 +454,7 @@ "max_session": 4, "max_trial": 190, "search": "RandomSearch", - "resources": { + "search_resources": { "num_cpus": 95 } }, diff --git a/slm_lab/spec/experimental/ppo/dppo.json b/slm_lab/spec/experimental/ppo/dppo.json index 866759768..a5279c532 100644 --- a/slm_lab/spec/experimental/ppo/dppo.json +++ b/slm_lab/spec/experimental/ppo/dppo.json @@ -64,7 +64,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -156,7 +156,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -252,7 +252,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -348,7 +348,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -440,7 +440,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -532,7 +532,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -628,7 +628,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -724,7 +724,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 4, @@ -821,7 +821,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 1, @@ -899,7 +899,7 @@ "num": 1 }, "meta": { - "distributed": true, + "distributed": "synced", "eval_frequency": 1000, "max_tick_unit": "epi", "max_session": 1, diff --git a/slm_lab/spec/spec_util.py b/slm_lab/spec/spec_util.py index bc6105230..4778235f2 100644 --- a/slm_lab/spec/spec_util.py +++ b/slm_lab/spec/spec_util.py @@ -37,7 +37,6 @@ "num": (int, list), }, "meta": { - "distributed": bool, "eval_frequency": (int, float), "max_tick_unit": str, "max_session": int, @@ -80,8 +79,8 @@ def check_body_spec(spec): def check_compatibility(spec): '''Check compatibility among spec setups''' # TODO expand to be more comprehensive - if spec['meta'].get('distributed'): - assert ps.get(spec, 'agent.0.net.gpu') != True, f'Hogwild lock-free does not work with GPU locked CUDA tensors. Set gpu: false.' + if spec['meta'].get('distributed') == 'synced': + assert ps.get(spec, 'agent.0.net.gpu') == False, f'Distributed mode "synced" works with CPU only. Set gpu: false.' 
diff --git a/test/agent/net/test_conv.py b/test/agent/net/test_conv.py
index 929ac33b4..264b774fd 100644
--- a/test/agent/net/test_conv.py
+++ b/test/agent/net/test_conv.py
@@ -56,11 +56,11 @@ def test_forward():
     assert y.shape == (batch_size, out_dim)
 
 
-def test_training_step():
+def test_train_step():
     y = torch.rand((batch_size, out_dim))
     clock = Clock(100, 'total_t', 1)
     loss = net.loss_fn(net.forward(x), y)
-    net.training_step(loss, optim, lr_scheduler, lr_clock=clock)
+    net.train_step(loss, optim, lr_scheduler, lr_clock=clock)
     assert loss != 0.0
 
 
diff --git a/test/agent/net/test_mlp.py b/test/agent/net/test_mlp.py
index c032c2fa0..d70ab8235 100644
--- a/test/agent/net/test_mlp.py
+++ b/test/agent/net/test_mlp.py
@@ -52,11 +52,11 @@ def test_forward():
     assert y.shape == (batch_size, out_dim)
 
 
-def test_training_step():
+def test_train_step():
    y = torch.rand((batch_size, out_dim))
     clock = Clock(100, 'total_t', 1)
     loss = net.loss_fn(net.forward(x), y)
-    net.training_step(loss, optim, lr_scheduler, lr_clock=clock)
+    net.train_step(loss, optim, lr_scheduler, lr_clock=clock)
     assert loss != 0.0
 
 
diff --git a/test/agent/net/test_recurrent.py b/test/agent/net/test_recurrent.py
index 418a54bbc..642202219 100644
--- a/test/agent/net/test_recurrent.py
+++ b/test/agent/net/test_recurrent.py
@@ -59,11 +59,11 @@ def test_forward():
     assert y.shape == (batch_size, out_dim)
 
 
-def test_training_step():
+def test_train_step():
     y = torch.rand((batch_size, out_dim))
     clock = Clock(100, 'total_t', 1)
     loss = net.loss_fn(net.forward(x), y)
-    net.training_step(loss, optim, lr_scheduler, lr_clock=clock)
+    net.train_step(loss, optim, lr_scheduler, lr_clock=clock)
     assert loss != 0.0
 
 
diff --git a/test/experiment/test_control.py b/test/experiment/test_control.py
index bc0116a0c..a100881c3 100644
--- a/test/experiment/test_control.py
+++ b/test/experiment/test_control.py
@@ -25,7 +25,7 @@ def test_session_total_t(test_spec):
     env_spec['max_tick'] = 30
     spec['meta']['max_tick_unit'] = 'total_t'
     session = Session(spec)
-    assert session.env.max_tick_unit == 'total_t'
+    assert session.env.clock.max_tick_unit == 'total_t'
     session_data = session.run()
     assert isinstance(session_data, pd.DataFrame)
 
diff --git a/test/spec/test_dist_spec.py b/test/spec/test_dist_spec.py
index b30493770..08734bb3a 100644
--- a/test/spec/test_dist_spec.py
+++ b/test/spec/test_dist_spec.py
@@ -15,7 +15,7 @@ def run_trial_test_dist(spec_file, spec_name=False):
     spec = spec_util.get(spec_file, spec_name)
     spec = spec_util.override_test_spec(spec)
     spec_util.tick(spec, 'trial')
-    spec['meta']['distributed'] = True
+    spec['meta']['distributed'] = 'synced'
     spec['meta']['max_session'] = 2
     trial = Trial(spec)
 
diff --git a/test/spec/test_spec.py b/test/spec/test_spec.py
index dbce3ff95..7c104da91 100644
--- a/test/spec/test_spec.py
+++ b/test/spec/test_spec.py
@@ -194,6 +194,7 @@ def test_atari(spec_file, spec_name):
     run_trial_test(spec_file, spec_name)
 
 
+@flaky
 @pytest.mark.parametrize('spec_file,spec_name', [
     ('experimental/reinforce.json', 'reinforce_conv_vizdoom'),
 ])
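The test updates above track the rename of `training_step` to `train_step` and the relocation of `max_tick_unit` onto `env.clock`. For reference, the new call shape as exercised by those tests; `net`, `optim`, `lr_scheduler`, `x`, `batch_size`, and `out_dim` are module-level fixtures assumed from the test files, and the `Clock` import path is likewise an assumption taken from them:

```python
import torch
from slm_lab.env.base import Clock  # import path assumed from the tests

# One training update through the renamed API; otherwise identical to the
# old training_step() call
y = torch.rand((batch_size, out_dim))
clock = Clock(100, 'total_t', 1)
loss = net.loss_fn(net.forward(x), y)
net.train_step(loss, optim, lr_scheduler, lr_clock=clock)
assert loss != 0.0
```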