Skip to content

Commit

Permalink
fix(nyz): fix mappo adv compute bug (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jul 1, 2024
1 parent b4ab08a commit 35ec39e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions ding/framework/middleware/functional/advantage_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
         return task.void()

     model = policy.get_attribute('model')
-    # Unify the shape of obs and action
-    obs_shape = cfg['policy']['model']['obs_shape']
-    obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
-        else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
-        else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
+    if buffer_ is not None:
+        # Unify the shape of obs and action
+        obs_shape = cfg['policy']['model']['obs_shape']
+        obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
+            else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
+            else torch.Size(torch.tensor(obs_shape).unsqueeze(0))

     def _gae(ctx: "OnlineRLContext"):
         """
Expand Down

0 comments on commit 35ec39e

Please sign in to comment.