diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py
index 65706d00f..0d1b6f266 100644
--- a/lzero/entry/train_muzero.py
+++ b/lzero/entry/train_muzero.py
@@ -4,6 +4,7 @@
 from typing import Optional, Tuple
 
 import torch
+import wandb
 from ding.config import compile_config
 from ding.envs import create_env_manager
 from ding.envs import get_vec_env_setting
@@ -81,6 +82,16 @@ def train_muzero(
     if cfg.policy.eval_offline:
         cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq
 
+    if cfg.policy.use_wandb:
+        # Initialize wandb
+        wandb.init(
+            project="LightZero",
+            config=cfg,
+            sync_tensorboard=False,
+            monitor_gym=False,
+            save_code=True,
+        )
+
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
 
     # load pretrained model
@@ -103,7 +114,7 @@ def train_muzero(
         policy=policy.collect_mode,
         tb_logger=tb_logger,
         exp_name=cfg.exp_name,
-        policy_config=policy_config
+        policy_config=policy_config,
     )
     evaluator = Evaluator(
         eval_freq=cfg.policy.eval_freq,
@@ -121,6 +132,8 @@ def train_muzero(
     # ==============================================================
     # Learner's before_run hook.
     learner.call_hook('before_run')
+    if policy_config.use_wandb:
+        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
     if cfg.policy.update_per_collect is not None:
         update_per_collect = cfg.policy.update_per_collect
@@ -199,6 +212,9 @@ def train_muzero(
                 )
                 break
 
+            if policy_config.use_wandb:
+                policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
             # The core train steps for MCTS+RL algorithms.
             log_vars = learner.train(train_data, collector.envstep)
 
@@ -223,4 +239,5 @@ def train_muzero(
 
     # Learner's after_run hook.
     learner.call_hook('after_run')
+    wandb.finish()
     return policy
diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py
index e12dd19d5..cd7ff7605 100644
--- a/lzero/entry/train_unizero.py
+++ b/lzero/entry/train_unizero.py
@@ -4,6 +4,7 @@
 from typing import Tuple, Optional
 
 import torch
+import wandb
 from ding.config import compile_config
 from ding.envs import create_env_manager
 from ding.envs import get_vec_env_setting
@@ -76,6 +77,16 @@ def train_unizero(
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
 
+    if cfg.policy.use_wandb:
+        # Initialize wandb
+        wandb.init(
+            project="LightZero",
+            config=cfg,
+            sync_tensorboard=False,
+            monitor_gym=False,
+            save_code=True,
+        )
+
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
 
     # Load pretrained model if specified
@@ -99,6 +110,8 @@ def train_unizero(
 
     # Learner's before_run hook
     learner.call_hook('before_run')
+    if policy_config.use_wandb:
+        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
     # Collect random data before training
    if cfg.policy.random_collect_episode_num > 0:
@@ -172,6 +185,9 @@ def train_unizero(
                 # Clear caches and precompute positional embedding matrices
                 policy.recompute_pos_emb_diff_and_clear_cache()  # TODO
 
+                if policy_config.use_wandb:
+                    policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
                 train_data.append({'train_which_component': 'transformer'})
                 log_vars = learner.train(train_data, collector.envstep)
 
@@ -185,4 +201,5 @@ def train_unizero(
             break
 
     learner.call_hook('after_run')
+    wandb.finish()
     return policy
diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py
index 9769c10cd..ccd9831a3 100644
--- a/lzero/policy/muzero.py
+++ b/lzero/policy/muzero.py
@@ -4,10 +4,13 @@
 import numpy as np
 import torch
 import torch.optim as optim
+import wandb
 from ding.model import model_wrap
 from ding.policy.base_policy import Policy
 from ding.torch_utils import to_tensor
 from ding.utils import POLICY_REGISTRY
+from torch.nn import L1Loss
+
 from lzero.entry.utils import initialize_zeros_batch
 from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
 from lzero.mcts import MuZeroMCTSPtree as MCTSPtree
@@ -16,7 +19,6 @@
 from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
     DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \
     prepare_obs, configure_optimizers
-from torch.nn import L1Loss
 
 
 @POLICY_REGISTRY.register('muzero')
@@ -75,6 +77,8 @@ class MuZeroPolicy(Policy):
             harmony_balance=False
         ),
         # ****** common ******
+        # (bool) Whether to use wandb to log the training process.
+        use_wandb=False,
        # (bool) whether to use rnd model.
        use_rnd_model=False,
        # (bool) Whether to use multi-gpu training.
@@ -253,6 +257,17 @@ def default_model(self) -> Tuple[str, List[str]]:
         else:
             raise ValueError("model type {} is not supported".format(self._cfg.model.model_type))
 
+    def set_train_iter_env_step(self, train_iter, env_step) -> None:
+        """
+        Overview:
+            Set the train_iter and env_step for the policy.
+        Arguments:
+            - train_iter (:obj:`int`): The train_iter for the policy.
+            - env_step (:obj:`int`): The env_step for the policy.
+        """
+        self.train_iter = train_iter
+        self.env_step = env_step
+
     def _init_learn(self) -> None:
         """
         Overview:
@@ -338,6 +353,10 @@ def _init_learn(self) -> None:
         self.dormant_ratio_encoder = 0.
         self.dormant_ratio_dynamics = 0.
 
+        if self._cfg.use_wandb:
+            # TODO: add the model to wandb
+            wandb.watch(self._learn_model.representation_network, log="all")
+
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         """
         Overview:
@@ -596,7 +615,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
             predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1)
 
-        return_dict = {
+        return_log_dict = {
             'collect_mcts_temperature': self._collect_mcts_temperature,
             'collect_epsilon': self.collect_epsilon,
             'cur_lr': self._optimizer.param_groups[0]['lr'],
@@ -644,8 +663,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
                 "harmony_entropy": self.harmony_entropy.item(),
                 "harmony_entropy_exp_recip": (1 / torch.exp(self.harmony_entropy)).item(),
             }
-            return_dict.update(harmony_dict)
-        return return_dict
+            return_log_dict.update(harmony_dict)
+
+        if self._cfg.use_wandb:
+            wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
+            wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)
+
+        return return_log_dict
 
     def _init_collect(self) -> None:
         """
diff --git a/lzero/policy/sampled_muzero.py b/lzero/policy/sampled_muzero.py
index d47471852..636683e1d 100644
--- a/lzero/policy/sampled_muzero.py
+++ b/lzero/policy/sampled_muzero.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+import wandb
 import torch.optim as optim
 from ding.model import model_wrap
 from ding.torch_utils import to_tensor
@@ -520,7 +521,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
             predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
             predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1)
 
-        return_data = {
+        return_log_dict = {
             'cur_lr': self._optimizer.param_groups[0]['lr'],
             'collect_mcts_temperature': self._collect_mcts_temperature,
             'weighted_total_loss': weighted_total_loss.item(),
@@ -546,7 +547,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
         }
 
         if self._cfg.model.continuous_action_space:
-            return_data.update({
+            return_log_dict.update({
                 # ==============================================================
                 # sampled related core code
                 # ==============================================================
@@ -563,7 +564,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
                 'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
             })
         else:
-            return_data.update({
+            return_log_dict.update({
                 # ==============================================================
                 # sampled related core code
                 # ==============================================================
@@ -574,7 +575,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
                 'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
             })
 
-        return return_data
+        if self._cfg.use_wandb:
+            wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
+            wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)
+
+        return return_log_dict
 
     def _calculate_policy_loss_cont(
             self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor,
diff --git a/lzero/policy/sampled_unizero.py b/lzero/policy/sampled_unizero.py
index 7694b9f13..8848e4a1e 100644
--- a/lzero/policy/sampled_unizero.py
+++ b/lzero/policy/sampled_unizero.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+import wandb
 from ding.model import model_wrap
 from ding.utils import POLICY_REGISTRY
 
@@ -532,7 +533,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             current_memory_allocated_gb = 0.
             max_memory_allocated_gb = 0.
 
-        return_loss_dict = {
+        return_log_dict = {
             'analysis/first_step_loss_value': first_step_losses['loss_value'].item(),
             'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(),
             'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(),
@@ -579,7 +580,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         }
 
         if self._cfg.model.continuous_action_space:
-            return_loss_dict.update({
+            return_log_dict.update({
                 # ==============================================================
                 # sampled related core code
                 # ==============================================================
@@ -595,7 +596,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
                 'target_sampled_actions_mean': target_sampled_actions_mean
             })
 
-        return return_loss_dict
+        if self._cfg.use_wandb:
+            wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
+            wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)
+
+        return return_log_dict
 
     def monitor_weights_and_grads(self, model):
         for name, param in model.named_parameters():
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 238c1daba..cf95c46d9 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 import torch
+import wandb
 from ding.model import model_wrap
 from ding.utils import POLICY_REGISTRY
 
@@ -329,6 +330,10 @@ def _init_learn(self) -> None:
         self.l2_norm_after = 0.
         self.grad_norm_before = 0.
         self.grad_norm_after = 0.
+
+        if self._cfg.use_wandb:
+            # TODO: add the model to wandb
+            wandb.watch(self._learn_model.representation_network, log="all")
 
     # @profile
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -468,7 +473,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             current_memory_allocated_gb = 0.
             max_memory_allocated_gb = 0.
 
-        return_loss_dict = {
+        return_log_dict = {
             'analysis/first_step_loss_value': first_step_losses['loss_value'].item(),
             'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(),
             'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(),
@@ -513,8 +518,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             'analysis/grad_norm_before': self.grad_norm_before,
             'analysis/grad_norm_after': self.grad_norm_after,
         }
+
+        if self._cfg.use_wandb:
+            wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
+            wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)
 
-        return return_loss_dict
+        return return_log_dict
 
     def monitor_weights_and_grads(self, model):
         for name, param in model.named_parameters():
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index f9e225a2a..eff413df6 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 import torch
+import wandb
 from ding.envs import BaseEnvManager
 from ding.torch_utils import to_ndarray
 from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \
@@ -776,4 +777,7 @@ def _output_log(self, train_iter: int) -> None:
                 self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                 if k in ['total_envstep_count']:
                     continue
-                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
\ No newline at end of file
+                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
+
+            if self.policy_config.use_wandb:
+                wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count)
diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py
index f40fb90e4..f7cc39047 100644
--- a/lzero/worker/muzero_evaluator.py
+++ b/lzero/worker/muzero_evaluator.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+import wandb
 from ding.envs import BaseEnvManager
 from ding.torch_utils import to_ndarray, to_item, to_tensor
 from ding.utils import build_logger, EasyTimer
@@ -433,6 +434,9 @@ def eval(
                     continue
                 self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                 self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+                if self.policy_config.use_wandb:
+                    wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep)
+
             episode_return = np.mean(episode_return)
             if episode_return > self._max_episode_return:
                 if save_ckpt_fn:
diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
index 1f01ef5a3..abb214bc1 100644
--- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
@@ -19,6 +19,7 @@
     exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_seed0',
     env=dict(
         env_id='CartPole-v0',
+        stop_value=200,
         continuous=False,
         manually_discretization=False,
         collector_env_num=collector_env_num,
@@ -27,6 +28,7 @@
         manager=dict(shared_memory=False, ),
     ),
     policy=dict(
+        use_wandb=True,
         model=dict(
             observation_shape=4,
             action_space_size=2,
@@ -52,7 +54,7 @@
         num_simulations=num_simulations,
         reanalyze_ratio=reanalyze_ratio,
         n_episode=n_episode,
-        eval_freq=int(2e2),
+        eval_freq=int(100),
         replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in the terms of transitions.
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
diff --git a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py
index eef0309b1..3c5171137 100644
--- a/zoo/classic_control/cartpole/config/cartpole_unizero_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_unizero_config.py
@@ -52,6 +52,7 @@
                 norm_type='BN',
             ),
         ),
+        use_wandb=True,
         # (str) The path of the pretrained model. If None, the model will be initialized by the default model.
         model_path=None,
         num_unroll_steps=num_unroll_steps,