From 35ec39eaceb6ebc86f78da4b0cfffc47979cc42b Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Mon, 1 Jul 2024 14:40:50 +0800 Subject: [PATCH] fix(nyz): fix mappo adv compute bug (#812) --- .../middleware/functional/advantage_estimator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index c365f4d7a3..6daf8d4528 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -31,11 +31,12 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non return task.void() model = policy.get_attribute('model') - # Unify the shape of obs and action - obs_shape = cfg['policy']['model']['obs_shape'] - obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ - else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ - else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) + if buffer_ is not None: + # Unify the shape of obs and action + obs_shape = cfg['policy']['model']['obs_shape'] + obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ + else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ + else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) def _gae(ctx: "OnlineRLContext"): """