diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index c365f4d7a3..6daf8d4528 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -31,11 +31,12 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non return task.void() model = policy.get_attribute('model') - # Unify the shape of obs and action - obs_shape = cfg['policy']['model']['obs_shape'] - obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ - else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ - else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) + if buffer_ is not None: + # Unify the shape of obs and action + obs_shape = cfg['policy']['model']['obs_shape'] + obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ + else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ + else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) def _gae(ctx: "OnlineRLContext"): """