[RLlib] Add log-std clipping to 'MLPHead's. (ray-project#47827)
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent a9e3210 commit e966649
Showing 8 changed files with 150 additions and 15 deletions.
9 changes: 9 additions & 0 deletions rllib/algorithms/bc/bc_catalog.py
@@ -79,6 +79,13 @@ def build_pi_head(self, framework: str) -> Model:
_check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls, framework=framework
)
is_diag_gaussian = True
else:
is_diag_gaussian = _check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls,
framework=framework,
no_error=True,
)
required_output_dim = action_distribution_cls.required_input_dim(
space=self.action_space, model_config=self._model_config_dict
)
@@ -95,6 +102,8 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
clip_log_std=is_diag_gaussian,
log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
)

return self.pi_head_config.build(framework=framework)
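A quick numeric aside on the default clip value of 20 used by the catalogs in this diff (plain Python, not part of the change): clipping log stds to [-20, 20] still spans roughly seventeen orders of magnitude in standard deviation, so it only trims pathological extremes.

import math

log_std_clip_param = 20.0  # default used throughout this commit
print(math.exp(-log_std_clip_param))  # ~2.06e-09 -> smallest reachable std
print(math.exp(log_std_clip_param))   # ~4.85e+08 -> largest reachable std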
9 changes: 9 additions & 0 deletions rllib/algorithms/marwil/marwil_catalog.py
@@ -99,6 +99,13 @@ def build_pi_head(self, framework: str) -> Model:
_check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls, framework=framework
)
is_diag_gaussian = True
else:
is_diag_gaussian = _check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls,
framework=framework,
no_error=True,
)

required_output_dim = action_distribution_cls.required_input_dim(
space=self.action_space, model_config=self._model_config_dict
@@ -116,6 +123,8 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_vf_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
clip_log_std=is_diag_gaussian,
log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
)

return self.pi_head_config.build(framework=framework)
35 changes: 26 additions & 9 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -12,21 +12,29 @@
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic


def _check_if_diag_gaussian(action_distribution_cls, framework):
def _check_if_diag_gaussian(action_distribution_cls, framework, no_error=False):
if framework == "torch":
from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian

assert issubclass(action_distribution_cls, TorchDiagGaussian), (
f"free_log_std is only supported for DiagGaussian action distributions. "
f"Found action distribution: {action_distribution_cls}."
)
is_diag_gaussian = issubclass(action_distribution_cls, TorchDiagGaussian)
if no_error:
return is_diag_gaussian
else:
assert is_diag_gaussian, (
f"free_log_std is only supported for DiagGaussian action "
f"distributions. Found action distribution: {action_distribution_cls}."
)
elif framework == "tf2":
from ray.rllib.models.tf.tf_distributions import TfDiagGaussian

assert issubclass(action_distribution_cls, TfDiagGaussian), (
"free_log_std is only supported for DiagGaussian action distributions. "
"Found action distribution: {}.".format(action_distribution_cls)
)
is_diag_gaussian = issubclass(action_distribution_cls, TfDiagGaussian)
if no_error:
return is_diag_gaussian
else:
assert is_diag_gaussian, (
"free_log_std is only supported for DiagGaussian action distributions. "
"Found action distribution: {}.".format(action_distribution_cls)
)
else:
raise ValueError(f"Framework {framework} not supported for free_log_std.")

@@ -148,6 +156,13 @@ def build_pi_head(self, framework: str) -> Model:
_check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls, framework=framework
)
is_diag_gaussian = True
else:
is_diag_gaussian = _check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls,
framework=framework,
no_error=True,
)
required_output_dim = action_distribution_cls.required_input_dim(
space=self.action_space, model_config=self._model_config_dict
)
@@ -164,6 +179,8 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_vf_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
clip_log_std=is_diag_gaussian,
log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
)

return self.pi_head_config.build(framework=framework)
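The helper in this file gains a `no_error` mode so the catalogs can probe the distribution class without raising. A minimal sketch of that gating, assuming the import paths the diff implies (`_check_if_diag_gaussian` is a private helper, imported here purely for illustration):

from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian
from ray.rllib.models.torch.torch_distributions import (
    TorchCategorical,
    TorchDiagGaussian,
)

# DiagGaussian distributions report True, so the pi head gets clip_log_std=True.
print(_check_if_diag_gaussian(TorchDiagGaussian, framework="torch", no_error=True))  # True
# Any other distribution reports False instead of raising, which disables clipping.
print(_check_if_diag_gaussian(TorchCategorical, framework="torch", no_error=True))   # False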
9 changes: 9 additions & 0 deletions rllib/algorithms/sac/sac_catalog.py
@@ -171,6 +171,13 @@ def build_pi_head(self, framework: str) -> Model:
_check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls, framework=framework
)
is_diag_gaussian = True
else:
is_diag_gaussian = _check_if_diag_gaussian(
action_distribution_cls=action_distribution_cls,
framework=framework,
no_error=True,
)
required_output_dim = action_distribution_cls.required_input_dim(
space=self.action_space, model_config=self._model_config_dict
)
@@ -187,6 +194,8 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_qf_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
clip_log_std=is_diag_gaussian,
log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
)

return self.pi_head_config.build(framework=framework)
16 changes: 16 additions & 0 deletions rllib/core/models/configs.py
@@ -160,6 +160,12 @@ class _MLPConfig(ModelConfig):
"_" are allowed.
output_layer_bias_initializer_config: Configuration to pass into the
initializer defined in `output_layer_bias_initializer`.
clip_log_std: Whether the log std should be clipped by `log_std_clip_param`. This
applies only to the action distribution parameters that encode the log standard
deviation of a `DiagGaussian` distribution.
log_std_clip_param: The clipping parameter for the log std, used only if clipping
is applied, i.e. `clip_log_std=True`. The default value is 20, meaning log
stds are clipped to the range [-20, 20].
"""

hidden_layer_dims: Union[List[int], Tuple[int]] = (256, 256)
@@ -181,6 +187,11 @@
output_layer_bias_initializer: Optional[Union[str, Callable]] = None
output_layer_bias_initializer_config: Optional[Dict] = None

# Optional clipping of log standard deviation.
clip_log_std: bool = False
# Optional clip parameter for the log standard deviation.
log_std_clip_param: float = 20.0

@property
def output_dims(self):
if self.output_layer_dim is None and not self.hidden_layer_dims:
@@ -205,6 +216,11 @@ def _validate(self, framework: str = "torch"):
"1D, e.g. `[32]`! This is an inferred value, hence other settings might"
" be wrong."
)
if self.log_std_clip_param is None:
raise ValueError(
"`log_std_clip_param` of _MLPConfig must be a float value, but is "
"`None`."
)

# Call these already here to catch errors early on.
get_activation_fn(self.hidden_layer_activation, framework=framework)
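For reference, a hedged sketch of using the two new `_MLPConfig` fields directly on an `MLPHeadConfig`; the constructor arguments other than `clip_log_std` and `log_std_clip_param` are assumed from the surrounding config code, not verified here.

import torch
from ray.rllib.core.models.configs import MLPHeadConfig

pi_head_config = MLPHeadConfig(
    input_dims=[16],
    hidden_layer_dims=[32],
    output_layer_dim=4,       # 2 means + 2 log stds for a 2D DiagGaussian
    clip_log_std=True,        # enable clipping (policy heads only)
    log_std_clip_param=10.0,  # override the 20.0 default
)
pi_head = pi_head_config.build(framework="torch")
out = pi_head(torch.randn(8, 16))  # the log-std half of `out` is clamped to [-10, 10]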
35 changes: 33 additions & 2 deletions rllib/core/models/tf/heads.py
@@ -123,6 +123,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
output_bias_initializer=config.output_layer_bias_initializer,
output_bias_initializer_config=config.output_layer_bias_initializer_config,
)
# Whether log standard deviations should be clipped. This should only be true
# for policy heads; value heads should never be clipped.
self.clip_log_std = config.clip_log_std
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = tf.constant([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -135,7 +140,19 @@ def get_output_specs(self) -> Optional[Spec]:
@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
return self.net(inputs)
# Only clip the log standard deviations if the user requests clipping. This
# avoids clipping value heads as well.
if self.clip_log_std:
# Forward pass.
means, log_stds = tf.split(self.net(inputs), num_or_size_splits=2, axis=-1)
# Clip the log standard deviations.
log_stds = tf.clip_by_value(
log_stds, -self.log_std_clip_param, self.log_std_clip_param
)
return tf.concat([means, log_stds], axis=-1)
# Otherwise just return the logits.
else:
return self.net(inputs)


class TfFreeLogStdMLPHead(TfModel):
@@ -178,6 +195,11 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
dtype=tf.float32,
trainable=True,
)
# Whether log standard deviations should be clipped. This should only be true
# for policy heads; value heads should never be clipped.
self.clip_log_std = config.clip_log_std
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = tf.constant([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -192,7 +214,16 @@ def get_output_specs(self) -> Optional[Spec]:
def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
# Compute the mean first, then append the log_std.
mean = self.net(inputs)
log_std_out = tf.tile(tf.expand_dims(self.log_std, 0), [tf.shape(inputs)[0], 1])
# If log standard deviation should be clipped.
if self.clip_log_std:
# Clip the log standard deviations to stabilize training. The clip
# value defaults to 20, i.e. log stds are clipped to [-20, 20].
log_std = tf.clip_by_value(
self.log_std, -self.log_std_clip_param, self.log_std_clip_param
)
else:
log_std = self.log_std
log_std_out = tf.tile(tf.expand_dims(log_std, 0), [tf.shape(inputs)[0], 1])
logits_out = tf.concat([mean, log_std_out], axis=1)
return logits_out
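A standalone TensorFlow illustration (not RLlib code) of the free-log-std path above: a single trainable log-std vector is clipped and then tiled across the batch before being concatenated to the means.

import tensorflow as tf

log_std = tf.Variable([-25.0, 0.5, 3.0])   # state-independent log stds
log_std_clip_param = tf.constant([20.0])   # mirrors the default clip value
clipped = tf.clip_by_value(log_std, -log_std_clip_param, log_std_clip_param)
means = tf.zeros([4, 3])                   # a dummy batch of mean outputs
log_std_out = tf.tile(tf.expand_dims(clipped, 0), [tf.shape(means)[0], 1])
logits = tf.concat([means, log_std_out], axis=1)
print(logits.shape)  # (4, 6); the first log-std entry is clipped from -25 to -20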

46 changes: 42 additions & 4 deletions rllib/core/models/torch/heads.py
@@ -121,6 +121,13 @@ def __init__(self, config: MLPHeadConfig) -> None:
output_bias_initializer=config.output_layer_bias_initializer,
output_bias_initializer_config=config.output_layer_bias_initializer_config,
)
# Whether log standard deviations should be clipped. This should only be true
# for policy heads; value heads should never be clipped.
self.clip_log_std = config.clip_log_std
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = torch.Tensor([config.log_std_clip_param])
# Register a buffer to handle device mapping.
self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -133,7 +140,19 @@ def get_output_specs(self) -> Optional[Spec]:
@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(inputs)
# Only clip the log standard deviations if the user requests clipping. This
# avoids clipping value heads as well.
if self.clip_log_std:
# Forward pass.
means, log_stds = torch.chunk(self.net(inputs), chunks=2, dim=-1)
# Clip the log standard deviations.
log_stds = torch.clamp(
log_stds, -self.log_std_clip_param_const, self.log_std_clip_param_const
)
return torch.cat((means, log_stds), dim=-1)
# Otherwise just return the logits.
else:
return self.net(inputs)


class TorchFreeLogStdMLPHead(TorchModel):
@@ -173,6 +192,15 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
self.log_std = torch.nn.Parameter(
torch.as_tensor([0.0] * self._half_output_dim)
)
# Whether log standard deviations should be clipped. This should only be true
# for policy heads; value heads should never be clipped.
self.clip_log_std = config.clip_log_std
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = torch.Tensor(
[config.log_std_clip_param], device=self.log_std.device
)
# Register a buffer to handle device mapping.
self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -188,9 +216,19 @@ def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
# Compute the mean first, then append the log_std.
mean = self.net(inputs)

return torch.cat(
[mean, self.log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1
)
# If log standard deviation should be clipped.
if self.clip_log_std:
# Clip the log standard deviation to avoid extremely small deviations
# that would effectively collapse the policy.
log_std = torch.clamp(
self.log_std,
-self.log_std_clip_param_const,
self.log_std_clip_param_const,
)
else:
log_std = self.log_std

return torch.cat([mean, log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1)


class TorchCNNTransposeHead(TorchModel):
Expand Down
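A standalone Torch illustration (not RLlib code) of the MLP-head clipping above: the network output is interpreted as [means | log_stds], and only the log-std half is clamped.

import torch

log_std_clip_param_const = torch.tensor([20.0])      # registered as a buffer above
raw_out = torch.tensor([[0.3, -1.2, 42.0, -31.0]])   # 2 means followed by 2 log stds
means, log_stds = torch.chunk(raw_out, chunks=2, dim=-1)
log_stds = torch.clamp(log_stds, -log_std_clip_param_const, log_std_clip_param_const)
print(torch.cat((means, log_stds), dim=-1))
# tensor([[  0.3000,  -1.2000,  20.0000, -20.0000]])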
6 changes: 6 additions & 0 deletions rllib/models/catalog.py
@@ -190,6 +190,12 @@
# outputs floating bias variables instead of state-dependent. This only
# has an effect if using the default fully connected net.
"free_log_std": False,
# Whether to clip the log standard deviation when using a Gaussian (or any
# other continuous control distribution). This can stabilize training and avoid
# very small or large log standard deviations, which can lead to numerical
# instabilities and turn network outputs to `nan`. The default is to clamp the
# log std to between -20 and 20.
"log_std_clip_param": 20.0,
# Whether to skip the final linear layer used to resize the hidden layer
# outputs to size `num_outputs`. If True, then the last hidden layer
# should already match num_outputs.
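A hedged usage sketch for the old-stack model config dict above. The two keys come straight from MODEL_DEFAULTS in this diff; passing them via `.training(model=...)` is an assumption about the usual config plumbing, not something this commit shows.

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(
    model={
        "free_log_std": True,        # learn state-independent log stds
        "log_std_clip_param": 10.0,  # tighten the default [-20, 20] clipping range
    }
)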
