diff --git a/rllib/algorithms/bc/bc_catalog.py b/rllib/algorithms/bc/bc_catalog.py
index 1c8adf71b174..6f4a1f8468fa 100644
--- a/rllib/algorithms/bc/bc_catalog.py
+++ b/rllib/algorithms/bc/bc_catalog.py
@@ -79,6 +79,13 @@ def build_pi_head(self, framework: str) -> Model:
             _check_if_diag_gaussian(
                 action_distribution_cls=action_distribution_cls, framework=framework
             )
+            is_diag_gaussian = True
+        else:
+            is_diag_gaussian = _check_if_diag_gaussian(
+                action_distribution_cls=action_distribution_cls,
+                framework=framework,
+                no_error=True,
+            )
         required_output_dim = action_distribution_cls.required_input_dim(
             space=self.action_space, model_config=self._model_config_dict
         )
@@ -95,6 +102,8 @@ def build_pi_head(self, framework: str) -> Model:
             hidden_layer_activation=self.pi_head_activation,
             output_layer_dim=required_output_dim,
             output_layer_activation="linear",
+            clip_log_std=is_diag_gaussian,
+            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
         )
         return self.pi_head_config.build(framework=framework)
 
diff --git a/rllib/algorithms/marwil/marwil_catalog.py b/rllib/algorithms/marwil/marwil_catalog.py
index 6ef2202a828d..913cf543ff53 100644
--- a/rllib/algorithms/marwil/marwil_catalog.py
+++ b/rllib/algorithms/marwil/marwil_catalog.py
@@ -99,6 +99,13 @@ def build_pi_head(self, framework: str) -> Model:
             _check_if_diag_gaussian(
                 action_distribution_cls=action_distribution_cls, framework=framework
             )
+            is_diag_gaussian = True
+        else:
+            is_diag_gaussian = _check_if_diag_gaussian(
+                action_distribution_cls=action_distribution_cls,
+                framework=framework,
+                no_error=True,
+            )
         required_output_dim = action_distribution_cls.required_input_dim(
             space=self.action_space, model_config=self._model_config_dict
         )
@@ -116,6 +123,8 @@ def build_pi_head(self, framework: str) -> Model:
             hidden_layer_activation=self.pi_and_vf_activation,
             output_layer_dim=required_output_dim,
             output_layer_activation="linear",
+            clip_log_std=is_diag_gaussian,
+            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
         )
         return self.pi_head_config.build(framework=framework)
 
diff --git a/rllib/algorithms/ppo/ppo_catalog.py b/rllib/algorithms/ppo/ppo_catalog.py
index cc2209eecd2e..958607ade616 100644
--- a/rllib/algorithms/ppo/ppo_catalog.py
+++ b/rllib/algorithms/ppo/ppo_catalog.py
@@ -12,21 +12,29 @@
 from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
 
 
-def _check_if_diag_gaussian(action_distribution_cls, framework):
+def _check_if_diag_gaussian(action_distribution_cls, framework, no_error=False):
     if framework == "torch":
         from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian
 
-        assert issubclass(action_distribution_cls, TorchDiagGaussian), (
-            f"free_log_std is only supported for DiagGaussian action distributions. "
-            f"Found action distribution: {action_distribution_cls}."
-        )
+        is_diag_gaussian = issubclass(action_distribution_cls, TorchDiagGaussian)
+        if no_error:
+            return is_diag_gaussian
+        else:
+            assert is_diag_gaussian, (
+                f"free_log_std is only supported for DiagGaussian action "
+                f"distributions. Found action distribution: {action_distribution_cls}."
+            )
     elif framework == "tf2":
         from ray.rllib.models.tf.tf_distributions import TfDiagGaussian
 
-        assert issubclass(action_distribution_cls, TfDiagGaussian), (
-            "free_log_std is only supported for DiagGaussian action distributions. "
-            "Found action distribution: {}.".format(action_distribution_cls)
-        )
+        is_diag_gaussian = issubclass(action_distribution_cls, TfDiagGaussian)
+        if no_error:
+            return is_diag_gaussian
+        else:
+            assert is_diag_gaussian, (
+                "free_log_std is only supported for DiagGaussian action distributions. "
+                "Found action distribution: {}.".format(action_distribution_cls)
+            )
     else:
         raise ValueError(f"Framework {framework} not supported for free_log_std.")
 
@@ -148,6 +156,13 @@ def build_pi_head(self, framework: str) -> Model:
             _check_if_diag_gaussian(
                 action_distribution_cls=action_distribution_cls, framework=framework
             )
+            is_diag_gaussian = True
+        else:
+            is_diag_gaussian = _check_if_diag_gaussian(
+                action_distribution_cls=action_distribution_cls,
+                framework=framework,
+                no_error=True,
+            )
         required_output_dim = action_distribution_cls.required_input_dim(
             space=self.action_space, model_config=self._model_config_dict
         )
@@ -164,6 +179,8 @@ def build_pi_head(self, framework: str) -> Model:
             hidden_layer_activation=self.pi_and_vf_head_activation,
             output_layer_dim=required_output_dim,
             output_layer_activation="linear",
+            clip_log_std=is_diag_gaussian,
+            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
         )
         return self.pi_head_config.build(framework=framework)
 
diff --git a/rllib/algorithms/sac/sac_catalog.py b/rllib/algorithms/sac/sac_catalog.py
index 7fc5a9feb713..a748fd6f00d7 100644
--- a/rllib/algorithms/sac/sac_catalog.py
+++ b/rllib/algorithms/sac/sac_catalog.py
@@ -171,6 +171,13 @@ def build_pi_head(self, framework: str) -> Model:
             _check_if_diag_gaussian(
                 action_distribution_cls=action_distribution_cls, framework=framework
             )
+            is_diag_gaussian = True
+        else:
+            is_diag_gaussian = _check_if_diag_gaussian(
+                action_distribution_cls=action_distribution_cls,
+                framework=framework,
+                no_error=True,
+            )
         required_output_dim = action_distribution_cls.required_input_dim(
             space=self.action_space, model_config=self._model_config_dict
         )
@@ -187,6 +194,8 @@ def build_pi_head(self, framework: str) -> Model:
             hidden_layer_activation=self.pi_and_qf_head_activation,
             output_layer_dim=required_output_dim,
             output_layer_activation="linear",
+            clip_log_std=is_diag_gaussian,
+            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
         )
         return self.pi_head_config.build(framework=framework)
 
diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py
index 1db76587f970..60a0758bbd76 100644
--- a/rllib/core/models/configs.py
+++ b/rllib/core/models/configs.py
@@ -160,6 +160,12 @@ class _MLPConfig(ModelConfig):
             "_" are allowed.
         output_layer_bias_initializer_config: Configuration to pass into the
             initializer defined in `output_layer_bias_initializer`.
+        clip_log_std: If log std should be clipped by `log_std_clip_param`. This applies
+            only to the action distribution parameters that encode the log standard
+            deviation of a `DiagGaussian` distribution.
+        log_std_clip_param: The clipping parameter for the log std, if clipping should
+            be applied - i.e. `clip_log_std=True`. The default value is 20, i.e. log
+            stds are clipped in between -20 and 20.
     """
 
     hidden_layer_dims: Union[List[int], Tuple[int]] = (256, 256)
@@ -181,6 +187,11 @@ class _MLPConfig(ModelConfig):
     output_layer_bias_initializer: Optional[Union[str, Callable]] = None
     output_layer_bias_initializer_config: Optional[Dict] = None
 
+    # Optional clipping of log standard deviation.
+    clip_log_std: bool = False
+    # Optional clip parameter for the log standard deviation.
+    log_std_clip_param: float = 20.0
+
     @property
     def output_dims(self):
         if self.output_layer_dim is None and not self.hidden_layer_dims:
@@ -205,6 +216,11 @@ def _validate(self, framework: str = "torch"):
                 "1D, e.g. `[32]`! This is an inferred value, hence other settings might"
                 " be wrong."
             )
+        if self.log_std_clip_param is None:
+            raise ValueError(
+                "`log_std_clip_param` of _MLPConfig must be a float value, but is "
+                "`None`."
+            )
 
         # Call these already here to catch errors early on.
         get_activation_fn(self.hidden_layer_activation, framework=framework)
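To see the two new `_MLPConfig` fields in isolation, here is a hedged sketch that builds a small policy-style head directly from the public `MLPHeadConfig`; the dimensions, the framework choice, and the direct call on a random batch are assumptions for illustration only:

```python
import torch

from ray.rllib.core.models.configs import MLPHeadConfig

# A pi-head-like config for a 2D DiagGaussian: 4 outputs = 2 means + 2 log stds.
config = MLPHeadConfig(
    input_dims=[16],
    hidden_layer_dims=[32, 32],
    output_layer_dim=4,
    output_layer_activation="linear",
    clip_log_std=True,        # Clip only the log-std half of the output.
    log_std_clip_param=10.0,  # Clamp log stds to [-10, 10] instead of the default 20.
)
pi_head = config.build(framework="torch")

# The last two output columns (the log stds) now lie in [-10, 10].
out = pi_head(torch.randn(8, 16))
```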
diff --git a/rllib/core/models/tf/heads.py b/rllib/core/models/tf/heads.py
index a09595e6a02e..823946efc3e8 100644
--- a/rllib/core/models/tf/heads.py
+++ b/rllib/core/models/tf/heads.py
@@ -123,6 +123,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
             output_bias_initializer=config.output_layer_bias_initializer,
             output_bias_initializer_config=config.output_layer_bias_initializer_config,
         )
+        # If log standard deviations should be clipped. This should be only true for
+        # policy heads. Value heads should never be clipped.
+        self.clip_log_std = config.clip_log_std
+        # The clipping parameter for the log standard deviation.
+        self.log_std_clip_param = tf.constant([config.log_std_clip_param])
 
     @override(Model)
     def get_input_specs(self) -> Optional[Spec]:
@@ -135,7 +140,19 @@ def get_output_specs(self) -> Optional[Spec]:
     @override(Model)
     @auto_fold_unfold_time("input_specs")
     def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
-        return self.net(inputs)
+        # Only clip the log standard deviations, if the user wants to clip. This
+        # avoids also clipping value heads.
+        if self.clip_log_std:
+            # Forward pass.
+            means, log_stds = tf.split(self.net(inputs), num_or_size_splits=2, axis=-1)
+            # Clip the log standard deviations.
+            log_stds = tf.clip_by_value(
+                log_stds, -self.log_std_clip_param, self.log_std_clip_param
+            )
+            return tf.concat([means, log_stds], axis=-1)
+        # Otherwise just return the logits.
+        else:
+            return self.net(inputs)
 
 
 class TfFreeLogStdMLPHead(TfModel):
@@ -178,6 +195,11 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
             dtype=tf.float32,
             trainable=True,
         )
+        # If log standard deviations should be clipped. This should be only true for
+        # policy heads. Value heads should never be clipped.
+        self.clip_log_std = config.clip_log_std
+        # The clipping parameter for the log standard deviation.
+        self.log_std_clip_param = tf.constant([config.log_std_clip_param])
 
     @override(Model)
     def get_input_specs(self) -> Optional[Spec]:
@@ -192,7 +214,16 @@ def get_output_specs(self) -> Optional[Spec]:
     def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
         # Compute the mean first, then append the log_std.
         mean = self.net(inputs)
-        log_std_out = tf.tile(tf.expand_dims(self.log_std, 0), [tf.shape(inputs)[0], 1])
+        # If log standard deviation should be clipped.
+        if self.clip_log_std:
+            # Clip log standard deviations to stabilize training. Note, the
+            # default clip value is `inf`, i.e. no clipping.
+            log_std = tf.clip_by_value(
+                self.log_std, -self.log_std_clip_param, self.log_std_clip_param
+            )
+        else:
+            log_std = self.log_std
+        log_std_out = tf.tile(tf.expand_dims(log_std, 0), [tf.shape(inputs)[0], 1])
 
         logits_out = tf.concat([mean, log_std_out], axis=1)
         return logits_out
diff --git a/rllib/core/models/torch/heads.py b/rllib/core/models/torch/heads.py
index d634ff2ef4c8..2e9e23cd969a 100644
--- a/rllib/core/models/torch/heads.py
+++ b/rllib/core/models/torch/heads.py
@@ -121,6 +121,13 @@ def __init__(self, config: MLPHeadConfig) -> None:
             output_bias_initializer=config.output_layer_bias_initializer,
             output_bias_initializer_config=config.output_layer_bias_initializer_config,
         )
+        # If log standard deviations should be clipped. This should be only true for
+        # policy heads. Value heads should never be clipped.
+        self.clip_log_std = config.clip_log_std
+        # The clipping parameter for the log standard deviation.
+        self.log_std_clip_param = torch.Tensor([config.log_std_clip_param])
+        # Register a buffer to handle device mapping.
+        self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)
 
     @override(Model)
     def get_input_specs(self) -> Optional[Spec]:
@@ -133,7 +140,19 @@ def get_output_specs(self) -> Optional[Spec]:
     @override(Model)
     @auto_fold_unfold_time("input_specs")
     def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
-        return self.net(inputs)
+        # Only clip the log standard deviations, if the user wants to clip. This
+        # avoids also clipping value heads.
+        if self.clip_log_std:
+            # Forward pass.
+            means, log_stds = torch.chunk(self.net(inputs), chunks=2, dim=-1)
+            # Clip the log standard deviations.
+            log_stds = torch.clamp(
+                log_stds, -self.log_std_clip_param_const, self.log_std_clip_param_const
+            )
+            return torch.cat((means, log_stds), dim=-1)
+        # Otherwise just return the logits.
+        else:
+            return self.net(inputs)
 
 
 class TorchFreeLogStdMLPHead(TorchModel):
@@ -173,6 +192,15 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
         self.log_std = torch.nn.Parameter(
             torch.as_tensor([0.0] * self._half_output_dim)
         )
+        # If log standard deviations should be clipped. This should be only true for
+        # policy heads. Value heads should never be clipped.
+        self.clip_log_std = config.clip_log_std
+        # The clipping parameter for the log standard deviation.
+        self.log_std_clip_param = torch.Tensor(
+            [config.log_std_clip_param], device=self.log_std.device
+        )
+        # Register a buffer to handle device mapping.
+        self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)
 
     @override(Model)
     def get_input_specs(self) -> Optional[Spec]:
@@ -188,9 +216,19 @@ def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
         # Compute the mean first, then append the log_std.
         mean = self.net(inputs)
 
-        return torch.cat(
-            [mean, self.log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1
-        )
+        # If log standard deviation should be clipped.
+        if self.clip_log_std:
+            # Clip the log standard deviation to avoid running into too small
+            # deviations that factually collapses the policy.
+            log_std = torch.clamp(
+                self.log_std,
+                -self.log_std_clip_param_const,
+                self.log_std_clip_param_const,
+            )
+        else:
+            log_std = self.log_std
+
+        return torch.cat([mean, log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1)
 
 
 class TorchCNNTransposeHead(TorchModel):
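The clipping branch added to `TorchMLPHead._forward` above reduces to a chunk/clamp/cat over the head output. Restated as plain PyTorch on made-up numbers:

```python
import torch

LOG_STD_CLIP = 20.0  # Matches the default in `_MLPConfig` / `MODEL_DEFAULTS`.

# Pretend pi-head output for a 2D action space: 2 means followed by 2 log stds
# that have drifted to extreme values during training.
raw = torch.tensor([[0.3, -0.1, 45.0, -37.0]])

means, log_stds = torch.chunk(raw, chunks=2, dim=-1)
log_stds = torch.clamp(log_stds, -LOG_STD_CLIP, LOG_STD_CLIP)
clipped = torch.cat((means, log_stds), dim=-1)
# clipped -> [[0.3, -0.1, 20.0, -20.0]]: the means pass through untouched, while
# exp(log_std) stays finite for the downstream DiagGaussian.
```

Value heads never take this branch, because the catalogs above only pass `clip_log_std=True` for pi heads.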
diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py
index 64a4bf6af39a..614a38c45884 100644
--- a/rllib/models/catalog.py
+++ b/rllib/models/catalog.py
@@ -190,6 +190,12 @@
     # outputs floating bias variables instead of state-dependent. This only
     # has an effect is using the default fully connected net.
     "free_log_std": False,
+    # Whether to clip the log standard deviation when using a Gaussian (or any
+    # other continuous control distribution). This can stabilize training and avoid
+    # very small or large log standard deviations leading to numerical instabilities
+    # which can turn network outputs to `nan`. The default is to clamp the log std
+    # in between -20 and 20.
+    "log_std_clip_param": 20.0,
     # Whether to skip the final linear layer used to resize the hidden layer
     # outputs to size `num_outputs`. If True, then the last hidden layer
     # should already match num_outputs.
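Finally, because the catalogs read the value via `self._model_config_dict.get("log_std_clip_param", 20)`, the clip range can be tightened through the usual `model` config dict. A hedged usage sketch; the PPO/Pendulum setup is an assumed example, and only the `"log_std_clip_param"` key itself comes from this change:

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("Pendulum-v1")  # Any continuous-action env with a DiagGaussian policy.
    .training(
        model={
            # Clamp log stds to [-5, 5] instead of the new default of [-20, 20].
            "log_std_clip_param": 5.0,
            # With free_log_std=True the catalogs build the FreeLogStd heads, which
            # clamp their single free log-std parameter instead (see heads above).
            "free_log_std": False,
        }
    )
)
algo = config.build()
```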