
feature(pu): add wandb support (#294)
* feature(pu): add wandb support in lz

puyuan1996 authored Nov 15, 2024
1 parent dd7a5eb commit 60be9e3
Showing 10 changed files with 104 additions and 16 deletions.
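At a high level the commit wires wandb in at three points: the training entries open and close a run, `_init_learn` registers the model with `wandb.watch`, and `_forward_learn` plus the workers log namespaced scalars keyed by environment step. The following is a minimal, self-contained sketch of that flow — not LightZero code; the toy model, loop, and loss are placeholders, while the `wandb` calls and the `learner_step/` prefix mirror the diff below.

```python
import torch
import torch.nn as nn
import wandb


def train(cfg: dict, model: nn.Module, total_iters: int = 10) -> None:
    """Sketch of the logging flow this commit wires up (stand-alone toy loop)."""
    if cfg["use_wandb"]:
        # One run per training job, as in train_muzero/train_unizero.
        wandb.init(project="LightZero", config=cfg, sync_tensorboard=False,
                   monitor_gym=False, save_code=True)
        # Track gradients and parameters, as in _init_learn.
        wandb.watch(model, log="all")

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    env_step = 0
    for train_iter in range(total_iters):
        env_step += 32  # stand-in for collector.envstep growth
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if cfg["use_wandb"]:
            # Learner metrics are namespaced and keyed by env step, as in _forward_learn.
            wandb.log({"learner_step/total_loss": loss.item()}, step=env_step)
            wandb.log({"learner_iter_vs_env_step": train_iter}, step=env_step)

    if cfg["use_wandb"]:
        wandb.finish()  # close the run, as at the end of the entry functions


if __name__ == "__main__":
    train({"use_wandb": False}, nn.Linear(4, 2))
```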
19 changes: 18 additions & 1 deletion lzero/entry/train_muzero.py
@@ -4,6 +4,7 @@
from typing import Optional, Tuple

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
@@ -81,6 +82,16 @@ def train_muzero(
if cfg.policy.eval_offline:
cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq

if cfg.policy.use_wandb:
# Initialize wandb
wandb.init(
project="LightZero",
config=cfg,
sync_tensorboard=False,
monitor_gym=False,
save_code=True,
)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
@@ -103,7 +114,7 @@ def train_muzero(
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
policy_config=policy_config,
)
evaluator = Evaluator(
eval_freq=cfg.policy.eval_freq,
@@ -121,6 +132,8 @@ def train_muzero(
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
if policy_config.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

if cfg.policy.update_per_collect is not None:
update_per_collect = cfg.policy.update_per_collect
@@ -199,6 +212,9 @@ def train_muzero(
)
break

if policy_config.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

@@ -223,4 +239,5 @@ def train_muzero(

# Learner's after_run hook.
learner.call_hook('after_run')
wandb.finish()
return policy
17 changes: 17 additions & 0 deletions lzero/entry/train_unizero.py
@@ -4,6 +4,7 @@
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
@@ -76,6 +77,16 @@ def train_unizero(
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

if cfg.policy.use_wandb:
# Initialize wandb
wandb.init(
project="LightZero",
config=cfg,
sync_tensorboard=False,
monitor_gym=False,
save_code=True,
)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
@@ -99,6 +110,8 @@ def train_unizero(

# Learner's before_run hook
learner.call_hook('before_run')
if policy_config.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Collect random data before training
if cfg.policy.random_collect_episode_num > 0:
@@ -172,6 +185,9 @@ def train_unizero(
# Clear caches and precompute positional embedding matrices
policy.recompute_pos_emb_diff_and_clear_cache() # TODO

if policy_config.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

train_data.append({'train_which_component': 'transformer'})
log_vars = learner.train(train_data, collector.envstep)

@@ -185,4 +201,5 @@ def train_unizero(
break

learner.call_hook('after_run')
wandb.finish()
return policy
32 changes: 28 additions & 4 deletions lzero/policy/muzero.py
@@ -4,10 +4,13 @@
import numpy as np
import torch
import torch.optim as optim
import wandb
from ding.model import model_wrap
from ding.policy.base_policy import Policy
from ding.torch_utils import to_tensor
from ding.utils import POLICY_REGISTRY
from torch.nn import L1Loss

from lzero.entry.utils import initialize_zeros_batch
from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
from lzero.mcts import MuZeroMCTSPtree as MCTSPtree
@@ -16,7 +19,6 @@
from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \
prepare_obs, configure_optimizers
from torch.nn import L1Loss


@POLICY_REGISTRY.register('muzero')
@@ -75,6 +77,8 @@ class MuZeroPolicy(Policy):
harmony_balance=False
),
# ****** common ******
# (bool) Whether to use wandb to log the training process.
use_wandb=False,
# (bool) whether to use rnd model.
use_rnd_model=False,
# (bool) Whether to use multi-gpu training.
@@ -253,6 +257,17 @@ def default_model(self) -> Tuple[str, List[str]]:
else:
raise ValueError("model type {} is not supported".format(self._cfg.model.model_type))

def set_train_iter_env_step(self, train_iter, env_step) -> None:
"""
Overview:
Set the train_iter and env_step for the policy.
Arguments:
- train_iter (:obj:`int`): The train_iter for the policy.
- env_step (:obj:`int`): The env_step for the policy.
"""
self.train_iter = train_iter
self.env_step = env_step

def _init_learn(self) -> None:
"""
Overview:
@@ -338,6 +353,10 @@ def _init_learn(self) -> None:
self.dormant_ratio_encoder = 0.
self.dormant_ratio_dynamics = 0.

if self._cfg.use_wandb:
# TODO: add the model to wandb
wandb.watch(self._learn_model.representation_network, log="all")

def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
Overview:
@@ -596,7 +615,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1)

return_dict = {
return_log_dict = {
'collect_mcts_temperature': self._collect_mcts_temperature,
'collect_epsilon': self.collect_epsilon,
'cur_lr': self._optimizer.param_groups[0]['lr'],
@@ -644,8 +663,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"harmony_entropy": self.harmony_entropy.item(),
"harmony_entropy_exp_recip": (1 / torch.exp(self.harmony_entropy)).item(),
}
return_dict.update(harmony_dict)
return return_dict
return_log_dict.update(harmony_dict)

if self._cfg.use_wandb:
wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)

return return_log_dict

def _init_collect(self) -> None:
"""
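The `set_train_iter_env_step` hook added above exists so `_forward_learn` can log against the collector's environment-step counter instead of wandb's default internal step, which keeps learner curves on the same x-axis as collector and evaluator metrics. A small stand-alone sketch of that call pattern (metric names and values are illustrative; only the hook signature and the `step=` usage come from the diff):

```python
import wandb


class PolicySketch:
    """Minimal stand-in showing how the hook and the logging call fit together."""

    def set_train_iter_env_step(self, train_iter: int, env_step: int) -> None:
        # Stored by the entry right before each learner.train(...) call.
        self.train_iter = train_iter
        self.env_step = env_step

    def forward_learn(self, loss: float) -> None:
        # step=self.env_step puts learner curves on the environment-step axis,
        # alongside collector and evaluator scalars.
        wandb.log({"learner_step/total_loss": loss}, step=self.env_step)
        wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)


if __name__ == "__main__":
    wandb.init(project="LightZero", mode="offline")  # offline so the sketch runs without a login
    policy = PolicySketch()
    for it, (env_step, loss) in enumerate([(200, 1.0), (400, 0.7), (600, 0.5)]):
        policy.set_train_iter_env_step(it, env_step)  # what train_muzero/train_unizero do
        policy.forward_learn(loss)
    wandb.finish()
```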
13 changes: 9 additions & 4 deletions lzero/policy/sampled_muzero.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
import wandb
import torch.optim as optim
from ding.model import model_wrap
from ding.torch_utils import to_tensor
@@ -520,7 +521,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1)

return_data = {
return_log_dict = {
'cur_lr': self._optimizer.param_groups[0]['lr'],
'collect_mcts_temperature': self._collect_mcts_temperature,
'weighted_total_loss': weighted_total_loss.item(),
@@ -546,7 +547,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
}

if self._cfg.model.continuous_action_space:
return_data.update({
return_log_dict.update({
# ==============================================================
# sampled related core code
# ==============================================================
@@ -563,7 +564,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
})
else:
return_data.update({
return_log_dict.update({
# ==============================================================
# sampled related core code
# ==============================================================
@@ -574,7 +575,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
})

return return_data
if self._cfg.use_wandb:
wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)

return return_log_dict

def _calculate_policy_loss_cont(
self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor,
11 changes: 8 additions & 3 deletions lzero/policy/sampled_unizero.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
import wandb
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY

@@ -532,7 +533,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
current_memory_allocated_gb = 0.
max_memory_allocated_gb = 0.

return_loss_dict = {
return_log_dict = {
'analysis/first_step_loss_value': first_step_losses['loss_value'].item(),
'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(),
'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(),
@@ -579,7 +580,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
}

if self._cfg.model.continuous_action_space:
return_loss_dict.update({
return_log_dict.update({
# ==============================================================
# sampled related core code
# ==============================================================
@@ -595,7 +596,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
'target_sampled_actions_mean': target_sampled_actions_mean
})

return return_loss_dict
if self._cfg.use_wandb:
wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)

return return_log_dict

def monitor_weights_and_grads(self, model):
for name, param in model.named_parameters():
13 changes: 11 additions & 2 deletions lzero/policy/unizero.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
import wandb
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY

@@ -329,6 +330,10 @@ def _init_learn(self) -> None:
self.l2_norm_after = 0.
self.grad_norm_before = 0.
self.grad_norm_after = 0.

if self._cfg.use_wandb:
# TODO: add the model to wandb
wandb.watch(self._learn_model.representation_network, log="all")

# @profile
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -468,7 +473,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
current_memory_allocated_gb = 0.
max_memory_allocated_gb = 0.

return_loss_dict = {
return_log_dict = {
'analysis/first_step_loss_value': first_step_losses['loss_value'].item(),
'analysis/first_step_loss_policy': first_step_losses['loss_policy'].item(),
'analysis/first_step_loss_rewards': first_step_losses['loss_rewards'].item(),
@@ -513,8 +518,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
'analysis/grad_norm_before': self.grad_norm_before,
'analysis/grad_norm_after': self.grad_norm_after,
}

if self._cfg.use_wandb:
wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step)
wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step)

return return_loss_dict
return return_log_dict

def monitor_weights_and_grads(self, model):
for name, param in model.named_parameters():
6 changes: 5 additions & 1 deletion lzero/worker/muzero_collector.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
import wandb
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \
@@ -776,4 +777,7 @@ def _output_log(self, train_iter: int) -> None:
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
if k in ['total_envstep_count']:
continue
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)

if self.policy_config.use_wandb:
wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count)
4 changes: 4 additions & 0 deletions lzero/worker/muzero_evaluator.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
import wandb
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray, to_item, to_tensor
from ding.utils import build_logger, EasyTimer
@@ -433,6 +434,9 @@ def eval(
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
if self.policy_config.use_wandb:
wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep)

episode_return = np.mean(episode_return)
if episode_return > self._max_episode_return:
if save_ckpt_fn:
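Together with the learner-side logging, the collector and evaluator changes keep the wandb layout aligned with the existing TensorBoard layout: every scalar is prefixed by its producer (`'{instance_name}_step/'`) and keyed by the total environment-step count. A rough sketch of the resulting keys, assuming the default instance names `collector` and `evaluator`, with purely illustrative metric names and values:

```python
import wandb

wandb.init(project="LightZero", mode="offline")  # offline so the sketch runs without a login
envstep = 10_000  # total environment steps collected so far

# Collector: mirrors MuZeroCollector._output_log's '{instance_name}_step/' prefix.
wandb.log({"collector_step/episode_return_mean": 150.0}, step=envstep)
# Evaluator: mirrors the prefixing in MuZeroEvaluator.eval.
wandb.log({"evaluator_step/episode_return_mean": 180.0}, step=envstep)
# Learner: mirrors the 'learner_step/' prefix used in the policies above.
wandb.log({"learner_step/weighted_total_loss": 0.42}, step=envstep)

wandb.finish()
```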
@@ -19,6 +19,7 @@
exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_seed0',
env=dict(
env_id='CartPole-v0',
stop_value=200,
continuous=False,
manually_discretization=False,
collector_env_num=collector_env_num,
@@ -27,6 +28,7 @@
manager=dict(shared_memory=False, ),
),
policy=dict(
use_wandb=True,
model=dict(
observation_shape=4,
action_space_size=2,
@@ -52,7 +54,7 @@
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e2),
eval_freq=int(100),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
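The two config diffs at the end show how a user opts in: set `use_wandb=True` inside the `policy` block of a zoo config and run it as usual (after `wandb login`, or with `WANDB_MODE=offline`). Below is a hedged, self-contained sketch of how such a flag and the config object reach wandb — the field values are illustrative and the config is heavily trimmed, not copied from the repository:

```python
from easydict import EasyDict
import wandb

# Trimmed, hypothetical config: only the wandb-relevant pieces are shown.
cfg = EasyDict(dict(
    exp_name='data_muzero/cartpole_muzero_wandb_seed0',
    env=dict(env_id='CartPole-v0', stop_value=200),
    policy=dict(use_wandb=True, eval_freq=int(100), num_simulations=25),
))

if cfg.policy.use_wandb:
    # Passing the whole cfg mirrors wandb.init(config=cfg, ...) in the entries,
    # so every hyperparameter shows up in the run's config panel.
    wandb.init(project="LightZero", config=cfg, sync_tensorboard=False,
               monitor_gym=False, save_code=True, mode="offline")
    wandb.finish()
```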
@@ -52,6 +52,7 @@
norm_type='BN',
),
),
use_wandb=True,
# (str) The path of the pretrained model. If None, the model will be initialized by the default model.
model_path=None,
num_unroll_steps=num_unroll_steps,
