fix(pu): fix hppo entropy_weight to avoid nan error in log_prob (#761)
puyuan1996 authored Dec 29, 2023
1 parent beb91d9 commit d7a61c2
Showing 3 changed files with 10 additions and 5 deletions.
5 changes: 3 additions & 2 deletions ding/policy/ppo.py
@@ -75,7 +75,7 @@ class PPOPolicy(Policy):
         # collect_mode config
         collect=dict(
             # (int) How many training samples collected in one collection procedure.
-            # Only one of [n_sample, n_episode] shoule be set.
+            # Only one of [n_sample, n_episode] should be set.
             # n_sample=64,
             # (int) Split episodes or trajectories into pieces with length `unroll_len`.
             unroll_len=1,
@@ -511,7 +511,8 @@ def _init_eval(self) -> None:
         elif self._action_space == 'discrete':
             self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
         elif self._action_space == 'hybrid':
-            self._eval_model = model_wrap(self._model, wrapper_name='hybrid_deterministic_argmax_sample')
+            self._eval_model = model_wrap(self._model, wrapper_name='hybrid_reparam_multinomial_sample')
+
         self._eval_model.reset()

     def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
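Note on the change above: evaluation for the hybrid action space now uses the 'hybrid_reparam_multinomial_sample' wrapper, i.e. it samples from the reparameterized policy distribution instead of taking a deterministic argmax/mean action. The sketch below is plain PyTorch, not DI-engine's wrapper code, and the tensor names and values are illustrative assumptions; it only shows how a near-collapsed hybrid policy distribution pushes log_prob toward extreme values, which is the kind of numerical blow-up the sampling wrapper and a larger entropy bonus guard against.

import torch
from torch.distributions import Categorical, Normal

# Hypothetical hybrid-policy outputs: logits for the discrete action type plus
# mean/std for the continuous action arguments (all values made up here).
logits = torch.tensor([[8.0, -8.0, -8.0]])            # near one-hot -> low entropy
mu = torch.zeros(1, 2)
sigma = torch.full((1, 2), 1e-6)                      # std collapsed toward zero

disc = Categorical(logits=logits)
cont = Normal(mu, sigma)

a_disc = disc.sample()                                # shape (1,)
a_cont = cont.rsample()                               # reparameterized, on-distribution
logp_sampled = disc.log_prob(a_disc) + cont.log_prob(a_cont).sum(-1)
print(logp_sampled)                                   # finite

# An action even slightly off the collapsed mean gets an enormous negative
# log-probability, which can overflow to inf/NaN further downstream
# (e.g. when exponentiated in importance ratios).
a_off = mu + 0.1
print(cont.log_prob(a_off).sum(-1))                   # around -1e10

Sampling with rsample keeps the evaluated action on-distribution, so its log-probability stays in a numerically safe range even when the distribution is sharp.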
8 changes: 6 additions & 2 deletions ding/rl_utils/td.py
@@ -687,14 +687,18 @@ def q_nstep_td_error(
     q, next_n_q, action, next_n_action, reward, done, weight = data
     if weight is None:
         weight = torch.ones_like(reward)
-    if len(action.shape) > 1:  # MARL case
+
+    if len(action.shape) == 1:  # single agent case
+        action = action.unsqueeze(-1)
+    elif len(action.shape) > 1:  # MARL case
         reward = reward.unsqueeze(-1)
         weight = weight.unsqueeze(-1)
         done = done.unsqueeze(-1)
         if value_gamma is not None:
             value_gamma = value_gamma.unsqueeze(-1)

-    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
+    q_s_a = q.gather(-1, action).squeeze(-1)
+
     target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

     if cum_reward:
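For reference, the reshaping above makes q.gather always receive an index tensor with a trailing dimension: a 1-D single-agent action is unsqueezed once up front, so the gather call itself no longer has to unsqueeze. A small self-contained sketch of that single-agent path (the batch size, action count, and random tensors are assumptions for illustration, not DI-engine data):

import torch

batch_size, num_actions = 4, 6
q = torch.randn(batch_size, num_actions)                # Q-values, shape (B, N)
action = torch.randint(num_actions, (batch_size,))      # chosen actions, shape (B,)

# torch.gather requires the index to have the same number of dims as the input,
# so the 1-D action tensor gets a trailing dim before gathering.
if len(action.shape) == 1:  # single agent case
    action = action.unsqueeze(-1)                        # (B,) -> (B, 1)

q_s_a = q.gather(-1, action).squeeze(-1)                 # Q(s, a), shape (B,)
assert q_s_a.shape == (batch_size,)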
2 changes: 1 addition & 1 deletion dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
@@ -31,7 +31,7 @@
             epoch_per_collect=10,
             batch_size=320,
             learning_rate=3e-4,
-            entropy_weight=0.03,
+            entropy_weight=0.5,
             adv_norm=True,
             value_norm=True,
         ),
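The entropy_weight bump from 0.03 to 0.5 strengthens the entropy bonus in the PPO objective, which discourages the hybrid policy's distributions from collapsing toward near-deterministic values (the situation behind the nan in log_prob that the commit title refers to). Below is a schematic sketch of where the weight enters a PPO-style total loss; the concrete loss terms and the value_weight are illustrative assumptions, not values computed by DI-engine:

import torch

# Made-up loss terms for illustration only.
policy_loss = torch.tensor(0.8)
value_loss = torch.tensor(0.4)
entropy = torch.tensor(1.2)      # larger entropy = broader, more exploratory policy

value_weight = 0.5               # assumed critic-loss weight
entropy_weight = 0.5             # the value this commit sets for gym_hybrid HPPO

# Subtracting the weighted entropy rewards keeping the distribution broad.
total_loss = policy_loss + value_weight * value_loss - entropy_weight * entropy
print(total_loss)                # tensor(0.4000)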
