diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch/action_model.py index 0b0d53706e..c5de586e4d 100644 --- a/ml-agents/mlagents/trainers/torch/action_model.py +++ b/ml-agents/mlagents/trainers/torch/action_model.py @@ -164,6 +164,9 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten if self.action_spec.continuous_size > 0 and dists.continuous is not None: continuous_out = dists.continuous.exported_model_output() action_out_deprecated = dists.continuous.exported_model_output() + if self._clip_action_on_export: + continuous_out = torch.clamp(continuous_out, -3, 3) / 3 + action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3 if self.action_spec.discrete_size > 0 and dists.discrete is not None: discrete_out_list = [ discrete_dist.exported_model_output()