[RLlib] Add log-std clipping to 'MLPHead's. #47827
Conversation
… that can use continuous action distributions. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
rllib/core/models/configs.py
Outdated
@@ -181,6 +181,9 @@ class _MLPConfig(ModelConfig):
    output_layer_bias_initializer: Optional[Union[str, Callable]] = None
    output_layer_bias_initializer_config: Optional[Dict] = None

    # Optional clip parameter for the log standard deviation.
    log_std_clip_param: float = float("inf")
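For context, here is a minimal sketch of how such a clip parameter could be applied in a continuous-action head's forward pass. This is illustrative only, not RLlib's actual implementation; the class name and the mean/log-std output layout are assumptions.

```python
import torch
from torch import nn


class DiagGaussianMLPHead(nn.Module):
    """Illustrative policy head that emits means and clipped log-stds."""

    def __init__(self, in_dim: int, action_dim: int,
                 log_std_clip_param: float = float("inf")) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, 2 * action_dim)
        # Register the clip constant as a buffer so it lives on the same
        # device as the network outputs.
        self.register_buffer("log_std_clip_param",
                             torch.tensor(log_std_clip_param))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        mean, log_std = torch.chunk(out, 2, dim=-1)
        # Clip the log standard deviation to avoid exp() over-/underflow.
        log_std = torch.clamp(log_std, -self.log_std_clip_param,
                              self.log_std_clip_param)
        return torch.cat([mean, log_std], dim=-1)
```

With the default of float("inf"), the clamp is a no-op, so existing behavior is preserved unless a finite clip value is configured.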
Could we set this to 20 by default?
Computer says no ... Yes, we can :)
rllib/algorithms/bc/bc_catalog.py
Outdated
@@ -95,6 +95,7 @@ def build_pi_head(self, framework: str) -> Model:
    hidden_layer_activation=self.pi_head_activation,
    output_layer_dim=required_output_dim,
    output_layer_activation="linear",
    log_std_clip_param=self._model_config_dict["log_std_clip_param"],
What if the user doesn't define this in model_config_dict?
Then the default kicks in, doesn't it? (rllib/models/catalog.py)
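For illustration, a minimal sketch of how such a catalog-level default falls back when the user's model config dict omits the key. The dict contents and variable names below are assumptions, not RLlib's actual defaults.

```python
# Illustrative only: the catalog's defaults are merged with the user's
# model_config_dict, so a missing key falls back to the catalog value.
catalog_defaults = {
    "log_std_clip_param": float("inf"),  # i.e. no clipping unless overridden
}
user_model_config = {"fcnet_hiddens": [256, 256]}  # user did not set the clip param

model_config_dict = {**catalog_defaults, **user_model_config}
assert model_config_dict["log_std_clip_param"] == float("inf")
```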
rllib/models/catalog.py
Outdated
# very small or large log standard deviations leading to numerical instabilities
# which can turn network outputs to `nan`. The default is infinity, i.e. no
# clipping.
"log_std_clip_param": float("inf"),
Ah, ok, here it is. But still, as we'll soon get rid of the old stack model config dict, we should be defensive against users bringing their own model_config_dict to BC or other algos.
@sven1977 you are right. I am a bit reluctant to bring in an intermediate solution that does not hold in general for all other attributes of the model_config_dict. I thought that with AlgorithmConfig._model_config_auto_includes we had solved the problem - it still uses the old rllib/models/catalog.py defaults, but it can be replaced by another logic in the near future (imo we will not be able to get around some default model config, to ensure that (a) users do not need to provide all inputs, and (b) we keep generality so that we do not have to add every config to every algorithm anew).
We could add log_std_clip_param to the overridden model_config_auto_includes. We also have a default value in the MLPHeadConfig.
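A hypothetical sketch of that idea follows. Only _model_config_auto_includes and log_std_clip_param are taken from the discussion above; the class and method names are made up for illustration and do not reflect RLlib's actual API.

```python
from typing import Any, Dict


class MyAlgorithmConfig:  # hypothetical stand-in for an AlgorithmConfig subclass
    @property
    def _model_config_auto_includes(self) -> Dict[str, Any]:
        # Defaults that are always merged into the user's model config dict.
        return {"log_std_clip_param": 20.0}

    def _complete_model_config(self, user_dict: Dict[str, Any]) -> Dict[str, Any]:
        # User-provided keys win; missing keys fall back to the auto-includes.
        return {**self._model_config_auto_includes, **user_dict}


config = MyAlgorithmConfig()
print(config._complete_model_config({}))  # {'log_std_clip_param': 20.0}
```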
Looks good to me. Just 1-2 nits/change requests (default should be 20, not inf).
…'clip_log_std' to enable no clipping for value heads. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…e that we have no categorical distribution when applying log std clipping. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…tant is on the same device as the network output. Furthermore, fixed a bug where the newly registered buffer was not used. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…MODEL_DEFAULT'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Why are these changes needed?
Many implementations of continuous control algorithms suffer from instabilities in training, where the log standard deviation takes on extreme values (most often very small ones), leading to numerical overflow in backward calculations (see this discussion).
These instabilities can be partially controlled by using a free-floating log standard deviation that acts like a bias (i.e. it is not produced inside the neural network, but is still optimized during training). Nevertheless, clipping the log standard deviation is still an often-used technique to stabilize training further.
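To make the instability concrete, here is a small, illustrative numeric sketch (the values are assumptions chosen for demonstration): a very small log standard deviation makes terms such as 1/std**2 in the Gaussian log-prob gradient explode, while clipping keeps them representable.

```python
import math

# A very small log-std, as a policy head might emit during unstable training.
log_std = -60.0
std = math.exp(log_std)            # ~8.7e-27
print(1.0 / std**2)                # ~1.3e52 -- overflows float32 -> inf/nan grads

# Clipping the log-std (e.g. to [-20, 20]) keeps such terms finite in float32.
clip = 20.0
clipped_std = math.exp(max(-clip, min(clip, log_std)))
print(1.0 / clipped_std**2)        # ~2.4e17 -- large, but representable
```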
This PR proposes a clip parameter for the log standard deviation and applies it in all MLPHeads, and therefore in all algorithms that use continuous actions. More specifically, it:
- adds log_std_clip_param to the default model config in rllib/models/catalog.py, defaulting to inf, i.e. effectively no clipping;
- applies the clipping in PPO, APPO, IMPALA, SAC, BC, and MARWIL (note: DreamerV3 already uses log-std clipping).

Related issue number
Closes #46442
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.