diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 81a7997df6..f5f89d30ec 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -159,10 +159,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): requires_grad=True, ) _cont_log_ent_coef = torch.nn.Parameter( - torch.log( - torch.as_tensor([self.init_entcoef] * self._action_spec.continuous_size) - ), - requires_grad=True, + torch.log(torch.as_tensor([self.init_entcoef])), requires_grad=True ) self._log_ent_coef = TorchSACOptimizer.LogEntCoef( discrete=_disc_log_ent_coef, continuous=_cont_log_ent_coef @@ -426,7 +423,7 @@ def sac_entropy_loss( ) # We update all the _cont_ent_coef as one block entropy_loss += -1 * ModelUtils.masked_mean( - torch.mean(_cont_ent_coef) * target_current_diff, loss_masks + _cont_ent_coef * target_current_diff, loss_masks ) return entropy_loss