From d7a61c2fdfe4532b802ce9ce66e8855c41fb7dea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com>
Date: Fri, 29 Dec 2023 17:26:05 +0800
Subject: [PATCH] fix(pu): fix hppo entropy_weight to avoid nan error in log_prob (#761)

---
 ding/policy/ppo.py                                | 5 +++--
 ding/rl_utils/td.py                               | 8 ++++++--
 dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py | 2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py
index 6eb4da1876..289bc72c44 100644
--- a/ding/policy/ppo.py
+++ b/ding/policy/ppo.py
@@ -75,7 +75,7 @@ class PPOPolicy(Policy):
         # collect_mode config
         collect=dict(
             # (int) How many training samples collected in one collection procedure.
-            # Only one of [n_sample, n_episode] shoule be set.
+            # Only one of [n_sample, n_episode] should be set.
             # n_sample=64,
             # (int) Split episodes or trajectories into pieces with length `unroll_len`.
             unroll_len=1,
@@ -511,7 +511,8 @@ def _init_eval(self) -> None:
         elif self._action_space == 'discrete':
             self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
         elif self._action_space == 'hybrid':
-            self._eval_model = model_wrap(self._model, wrapper_name='hybrid_deterministic_argmax_sample')
+            self._eval_model = model_wrap(self._model, wrapper_name='hybrid_reparam_multinomial_sample')
+
         self._eval_model.reset()
 
     def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py
index 4dd2df6c4b..1622d2c289 100644
--- a/ding/rl_utils/td.py
+++ b/ding/rl_utils/td.py
@@ -687,14 +687,18 @@ def q_nstep_td_error(
     q, next_n_q, action, next_n_action, reward, done, weight = data
     if weight is None:
         weight = torch.ones_like(reward)
-    if len(action.shape) > 1:  # MARL case
+
+    if len(action.shape) == 1:  # single agent case
+        action = action.unsqueeze(-1)
+    elif len(action.shape) > 1:  # MARL case
         reward = reward.unsqueeze(-1)
         weight = weight.unsqueeze(-1)
         done = done.unsqueeze(-1)
         if value_gamma is not None:
             value_gamma = value_gamma.unsqueeze(-1)
 
-    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
+    q_s_a = q.gather(-1, action).squeeze(-1)
+
     target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)
 
     if cum_reward:
diff --git a/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py b/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
index 45e0bebba2..2011972e19 100644
--- a/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
+++ b/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
@@ -31,7 +31,7 @@
             epoch_per_collect=10,
             batch_size=320,
             learning_rate=3e-4,
-            entropy_weight=0.03,
+            entropy_weight=0.5,
             adv_norm=True,
             value_norm=True,
         ),
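
Why raising entropy_weight helps, as a minimal sketch (assumed values, not DI-engine code; this shows one plausible way the NaN in the subject line can arise): if the Gaussian branch of the hybrid policy head collapses to a near-zero scale, Normal(...).log_prob() stops being finite and the PPO surrogate loss becomes NaN. Because a Normal distribution's entropy shrinks with log(scale), a larger entropy bonus such as entropy_weight=0.5 penalizes that collapse and keeps log_prob well-behaved.

# Minimal sketch (assumed values, not DI-engine code).
import torch
from torch.distributions import Normal

mu = torch.zeros(3)
act = torch.ones(3)

collapsed = Normal(mu, torch.full((3, ), 1e-45))  # scale underflowed to a float32 denormal
healthy = Normal(mu, torch.full((3, ), 0.3))      # scale kept away from zero

print(collapsed.log_prob(act))  # tensor([-inf, -inf, -inf]); non-finite log-probs break the PPO ratio and loss
print(healthy.log_prob(act))    # finite log-probabilities

# The entropy bonus adds entropy_weight * dist.entropy() to the objective; a Normal's
# entropy is 0.5 * log(2 * pi * e * scale^2), so maximizing it pushes scale away from zero.
print(collapsed.entropy())  # roughly -102: a collapsed policy is heavily penalized
print(healthy.entropy())    # roughly 0.2: healthy exploration noise

The 0.03 -> 0.5 change above is the config-level version of this fix; an alternative not taken in this patch would be clamping the predicted scale inside the model itself.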