Added log-std clipping to 'MLPHeads', '_MLPConfig' and all algorithms that can use continuous action distributions.

Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
simonsays1980 committed Sep 26, 2024
1 parent 80c2a42 commit 9a6b2dd
Showing 8 changed files with 64 additions and 5 deletions.
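
For context, here is a minimal usage sketch of the new option. It assumes the new RLModule API stack and that `AlgorithmConfig.rl_module(model_config_dict=...)` is available in the Ray version at hand; the exact config surface may differ between releases.

from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical example: opt into log-std clipping through the model config
# dict; the catalog changes below forward this value into the pi-head config.
config = (
    PPOConfig()
    .environment("Pendulum-v1")  # continuous action space
    .rl_module(
        model_config_dict={
            # Clip log-stds to [-20.0, 20.0]; the default is inf (no clipping).
            "log_std_clip_param": 20.0,
        }
    )
)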
1 change: 1 addition & 0 deletions rllib/algorithms/bc/bc_catalog.py
@@ -95,6 +95,7 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
log_std_clip_param=self._model_config_dict["log_std_clip_param"],
)

return self.pi_head_config.build(framework=framework)
1 change: 1 addition & 0 deletions rllib/algorithms/marwil/marwil_catalog.py
@@ -116,6 +116,7 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_vf_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
log_std_clip_param=self._model_config_dict["log_std_clip_param"],
)

return self.pi_head_config.build(framework=framework)
1 change: 1 addition & 0 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -164,6 +164,7 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_vf_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
log_std_clip_param=self._model_config_dict["log_std_clip_param"],
)

return self.pi_head_config.build(framework=framework)
1 change: 1 addition & 0 deletions rllib/algorithms/sac/sac_catalog.py
@@ -187,6 +187,7 @@ def build_pi_head(self, framework: str) -> Model:
hidden_layer_activation=self.pi_and_qf_head_activation,
output_layer_dim=required_output_dim,
output_layer_activation="linear",
log_std_clip_param=self._model_config_dict["log_std_clip_param"],
)

return self.pi_head_config.build(framework=framework)
8 changes: 8 additions & 0 deletions rllib/core/models/configs.py
@@ -181,6 +181,9 @@ class _MLPConfig(ModelConfig):
output_layer_bias_initializer: Optional[Union[str, Callable]] = None
output_layer_bias_initializer_config: Optional[Dict] = None

# Optional clip parameter for the log standard deviation.
log_std_clip_param: float = float("inf")

@property
def output_dims(self):
if self.output_layer_dim is None and not self.hidden_layer_dims:
@@ -205,6 +208,11 @@ def _validate(self, framework: str = "torch"):
"1D, e.g. `[32]`! This is an inferred value, hence other settings might"
" be wrong."
)
if self.log_std_clip_param is None:
raise ValueError(
"`log_std_clip_param` of _MLPConfig must be a float value, but is "
"`None`."
)

# Call these already here to catch errors early on.
get_activation_fn(self.hidden_layer_activation, framework=framework)
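
To illustrate the new validation step, a short sketch (assuming `MLPHeadConfig` keeps its dataclass interface and that `build()` runs `_validate()` internally; the toy dimensions are made up):

from ray.rllib.core.models.configs import MLPHeadConfig

config = MLPHeadConfig(
    input_dims=[256],
    hidden_layer_dims=[256, 256],
    output_layer_dim=8,  # e.g. 4 means + 4 log-stds
    log_std_clip_param=None,  # invalid on purpose
)
try:
    config.build(framework="torch")
except ValueError as e:
    print(e)  # "`log_std_clip_param` of _MLPConfig must be a float value, ..."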
25 changes: 23 additions & 2 deletions rllib/core/models/tf/heads.py
@@ -123,6 +123,8 @@ def __init__(self, config: MLPHeadConfig) -> None:
output_bias_initializer=config.output_layer_bias_initializer,
output_bias_initializer_config=config.output_layer_bias_initializer_config,
)
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = tf.constant([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -135,7 +137,19 @@ def get_output_specs(self) -> Optional[Spec]:
@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
return self.net(inputs)
# Only clip the log standard deviations if the user requests clipping. This
# avoids accidentally clipping value heads, which are built from the same class.
if tf.math.isfinite(self.log_std_clip_param):
# Forward pass.
means, log_stds = tf.split(self.net(inputs), num_or_size_splits=2, axis=-1)
# Clip the log standard deviations.
log_stds = tf.clip_by_value(
log_stds, -self.log_std_clip_param, self.log_std_clip_param
)
return tf.concat([means, log_stds], axis=-1)
# Otherwise just return the logits.
else:
return self.net(inputs)


class TfFreeLogStdMLPHead(TfModel):
@@ -178,6 +192,8 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
dtype=tf.float32,
trainable=True,
)
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = tf.constant([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -192,7 +208,12 @@ def get_output_specs(self) -> Optional[Spec]:
def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
# Compute the mean first, then append the log_std.
mean = self.net(inputs)
log_std_out = tf.tile(tf.expand_dims(self.log_std, 0), [tf.shape(inputs)[0], 1])
# Clip the log standard deviations to stabilize training. Note that the
# default clip value is `inf`, i.e. no clipping.
log_std = tf.clip_by_value(
self.log_std, -self.log_std_clip_param, self.log_std_clip_param
)
log_std_out = tf.tile(tf.expand_dims(log_std, 0), [tf.shape(inputs)[0], 1])
logits_out = tf.concat([mean, log_std_out], axis=1)
return logits_out

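
The clipping itself is a plain element-wise bound on the log-std half of the head output. A toy sketch of the TF semantics (assuming the usual layout of first half means, second half log-stds):

import tensorflow as tf

logits = tf.constant([[0.5, -0.2, 35.0, -42.0]])  # 2 means, 2 log-stds
means, log_stds = tf.split(logits, num_or_size_splits=2, axis=-1)
clip = tf.constant(20.0)
log_stds = tf.clip_by_value(log_stds, -clip, clip)
print(tf.concat([means, log_stds], axis=-1))  # [[0.5, -0.2, 20.0, -20.0]]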
26 changes: 23 additions & 3 deletions rllib/core/models/torch/heads.py
@@ -121,6 +121,8 @@ def __init__(self, config: MLPHeadConfig) -> None:
output_bias_initializer=config.output_layer_bias_initializer,
output_bias_initializer_config=config.output_layer_bias_initializer_config,
)
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = torch.Tensor([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
@@ -133,7 +135,19 @@ def get_output_specs(self) -> Optional[Spec]:
@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(inputs)
# Only clip the log standard deviations if the user requests clipping. This
# avoids accidentally clipping value heads, which are built from the same class.
if torch.isfinite(self.log_std_clip_param):
# Forward pass.
means, log_stds = torch.chunk(self.net(inputs), chunks=2, dim=-1)
# Clip the log standard deviations.
log_stds = torch.clamp(
log_stds, -self.log_std_clip_param, self.log_std_clip_param
)
return torch.cat((means, log_stds), dim=-1)
# Otherwise just return the logits.
else:
return self.net(inputs)


class TorchFreeLogStdMLPHead(TorchModel):
@@ -173,6 +187,8 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
self.log_std = torch.nn.Parameter(
torch.as_tensor([0.0] * self._half_output_dim)
)
# The clipping parameter for the log standard deviation.
self.log_std_clip_param = torch.Tensor([config.log_std_clip_param])

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
Expand All @@ -188,10 +204,14 @@ def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
# Compute the mean first, then append the log_std.
mean = self.net(inputs)

return torch.cat(
[mean, self.log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1
# Clip the log standard deviation to avoid extremely small deviations
# that effectively collapse the policy.
log_std = torch.clamp(
self.log_std, -self.log_std_clip_param, self.log_std_clip_param
)

return torch.cat([mean, log_std.unsqueeze(0).repeat([len(mean), 1])], axis=1)


class TorchCNNTransposeHead(TorchModel):
def __init__(self, config: CNNTransposeHeadConfig) -> None:
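
Equivalently for the torch free-log-std head, a toy sketch of the clamp-then-broadcast step (values made up):

import torch

log_std = torch.nn.Parameter(torch.tensor([-30.0, 0.0, 25.0]))
log_std_clip_param = torch.Tensor([20.0])
mean = torch.zeros(4, 3)  # pretend batch of 4 action means

log_std_clipped = torch.clamp(log_std, -log_std_clip_param, log_std_clip_param)
out = torch.cat([mean, log_std_clipped.unsqueeze(0).repeat(len(mean), 1)], dim=1)
print(log_std_clipped)  # tensor([-20., 0., 20.], grad_fn=...)
print(out.shape)  # torch.Size([4, 6])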
6 changes: 6 additions & 0 deletions rllib/models/catalog.py
@@ -190,6 +190,12 @@
# outputs floating bias variables instead of state-dependent. This only
# has an effect if using the default fully connected net.
"free_log_std": False,
# The clip value for the log standard deviation when using a Gaussian (or any
# other continuous action distribution). Clipping can stabilize training and
# avoid very small or large log standard deviations, which can lead to numerical
# instabilities and turn network outputs into `nan`. The default is infinity,
# i.e. no clipping.
"log_std_clip_param": float("inf"),
# Whether to skip the final linear layer used to resize the hidden layer
# outputs to size `num_outputs`. If True, then the last hidden layer
# should already match num_outputs.
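
Because the new key defaults to infinity, existing configs keep their behavior and the feature is opt-in. A quick sketch (assuming `MODEL_DEFAULTS` stays importable from `ray.rllib.models.catalog`):

from ray.rllib.models.catalog import MODEL_DEFAULTS

assert MODEL_DEFAULTS["log_std_clip_param"] == float("inf")

# Opting in only requires overriding the single key; the catalogs read it from
# the merged model config dict (see the `build_pi_head()` changes above).
model_config_dict = {**MODEL_DEFAULTS, "log_std_clip_param": 20.0}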
