[RLlib] RLModule API change: If "actions" key returned from forward_inference|exploration, use actions as-is. #36067

Merged
106 changes: 98 additions & 8 deletions doc/source/rllib/rllib-rlmodule.rst
@kouroshHakha (Contributor) commented on Jun 7, 2023:

This added documentation is not specific to those who are migrating from the Policy API. We should put it under the right section. I think adding it somewhere close to "Writing Custom Single Agent RL Modules" would be the way to go (it requires getting rid of the policy-specific nomenclature).

Maybe we can consolidate your paragraph with the description of what needs to be implemented for each forward method shown here?


Contributor Author:

Yup, I moved this up into the suggested section and created a new table to explain the difference between returning the "actions" key and NOT returning the "actions" key.

@@ -16,7 +16,7 @@ RL Modules (Alpha)

.. note::

This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` you should not experience siginficant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.
This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` you should not experience significant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.

The table below shows the list of migrated algorithms and their current supported features, which will be updated as we progress.

@@ -33,19 +33,19 @@ RL Modules (Alpha)
* - **PPO**
- |pytorch| |tensorflow|
- |pytorch| |tensorflow|
- |pytorch|
- |pytorch| |tensorflow|
-
- |pytorch|
* - **Impala**
- |pytorch| |tensorflow|
- |pytorch| |tensorflow|
- |pytorch|
- |pytorch| |tensorflow|
-
- |pytorch|
* - **APPO**
- |tensorflow|
- |tensorflow|
-
- |pytorch| |tensorflow|
- |pytorch| |tensorflow|
- |pytorch| |tensorflow|
-
-

@@ -140,7 +140,93 @@ The minimum requirement is for sub-classes of :py:class:`~ray.rllib.core.rl_modu

- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_exploration`: Forward pass for exploration.

Also the class's constrcutor requires a dataclass config object called `~ray.rllib.core.rl_module.rl_module.RLModuleConfig` which contains the following fields:
For your custom :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`
methods, you must return a dictionary that contains the "actions" key, the "action_dist_inputs" key, or both.

If you return the "actions" key:

- RLlib will use the actions under that key as-is.
- If you also return the "action_dist_inputs" key, RLlib will additionally create a :py:class:`~ray.rllib.models.distributions.Distribution` object from the distribution parameters under that key and, in the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`, compute action probs and logp values for the given actions automatically.

If you do not return the "actions" key:

- You must return the "action_dist_inputs" key instead from your :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference` methods.
- RLlib will create a :py:class:`~ray.rllib.models.distributions.Distribution` object from the distribution parameters under that key and sample actions from this distribution.
- In the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`, RLlib will also compute action probs and logp values from the sampled actions automatically.

.. note::

In the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`,
the generated distributions (from returned key "action_dist_inputs") will always be made deterministic first via
the :py:meth:`~ray.rllib.models.distributions.Distribution.to_deterministic` utility before a possible action sample step.
Thus, for example, sampling from a Categorical distribution will be reduced to simply selecting the argmax actions from the distribution's logits/probs.

Commonly used distribution implementations can be found under ``ray.rllib.models.tf.tf_distributions`` for tensorflow and
``ray.rllib.models.torch.torch_distributions`` for torch. You can choose to return deterministic actions by creating a deterministic distribution instance.
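
For illustration, here is a minimal sketch of how such a distribution behaves. It assumes that a ``TorchCategorical`` class with ``from_logits``, ``sample``, and ``to_deterministic`` methods is importable from the torch module path above; exact class names may differ between RLlib versions.

.. code-block:: python

    import torch
    from ray.rllib.models.torch.torch_distributions import TorchCategorical

    # Action logits for a single observation in a 3-action discrete space.
    logits = torch.tensor([[0.1, 2.0, -0.5]])
    dist = TorchCategorical.from_logits(logits)

    # Stochastic sample, as RLlib would produce from `_forward_exploration` outputs.
    exploration_action = dist.sample()

    # Deterministic variant, as RLlib would produce from `_forward_inference` outputs:
    # `to_deterministic()` reduces sampling to an argmax over the logits.
    inference_action = dist.to_deterministic().sample()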


.. tab-set::

.. tab-item:: Returning "actions" key

.. code-block:: python

"""
An RLModule whose forward_exploration/inference methods return the
"actions" key.
"""

class MyRLModule(TorchRLModule):
...

def _forward_inference(self, batch):
...
return {
"actions": ... # actions will be used as-is
}

def _forward_exploration(self, batch):
...
return {
"actions": ... # actions will be used as-is (no sampling step!)
"action_dist_inputs": ... # optional: If provided, will be used to compute action probs and logp.
}

.. tab-item:: Not returning "actions" key

.. code-block:: python

"""
An RLModule whose forward_exploration/inference methods do NOT return the
"actions" key.
"""

class MyRLModule(TorchRLModule):
...

def _forward_inference(self, batch):
...
return {
# RLlib will:
# - Generate distribution from these parameters.
# - Convert distribution to a deterministic equivalent.
# - "sample" from the deterministic distribution.
"action_dist_inputs": ...
}

def _forward_exploration(self, batch):
...
return {
# RLlib will:
# - Generate distribution from these parameters.
# - "sample" from the (stochastic) distribution.
# - Compute action probs/logs automatically using the sampled
# actions and the generated distribution object.
"action_dist_inputs": ...
}


Also the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` class's constructor requires a dataclass config object called `~ray.rllib.core.rl_module.rl_module.RLModuleConfig`, which contains the following fields:

- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.observation_space`: The observation space of the environment (either processed or raw).
- :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.action_space`: The action space of the environment.
@@ -426,7 +512,11 @@ What your customization could have looked like before:
return None, None, None


All of the ``Policy.compute_***`` functions expect that `~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and `~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference` return a dictionary that contains the key "action_dist_inputs", whose value are the parameters (inputs) of a ``ray.rllib.models.distributions.Distribution`` class. Commonly used distribution implementations can be found under ``ray.rllib.models.tf.tf_distributions`` for tensorflow and ``ray.rllib.models.torch.torch_distributions`` for torch. You can choose to return determinstic actions, by creating a determinstic distribution instance. See `Writing Custom Single Agent RL Modules`_ for more details on how to implement your own custom RL Module.
All of the ``Policy.compute_***`` functions expect that
:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`
return a dictionary that contains the "actions" key, the "action_dist_inputs" key, or both.

See `Writing Custom Single Agent RL Modules`_ for more details on how to implement your own custom RL Modules.
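
For example, a migrated module that only returns "action_dist_inputs", and therefore lets RLlib handle sampling and logp computation, could look roughly like the following sketch. The import paths, the ``setup()`` hook, and the tiny linear policy head are illustrative assumptions only; adapt them to your own model and RLlib version.

.. code-block:: python

    import torch
    from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
    from ray.rllib.policy.sample_batch import SampleBatch


    class MyMigratedRLModule(TorchRLModule):
        """Sketch of a module that lets RLlib handle action sampling."""

        def setup(self):
            # Hypothetical single-layer policy head producing action logits;
            # replace with your own network.
            self.policy_net = torch.nn.Linear(
                self.config.observation_space.shape[0],
                self.config.action_space.n,
            )

        def _forward_inference(self, batch):
            logits = self.policy_net(batch[SampleBatch.OBS])
            # No "actions" key returned: RLlib builds the distribution, converts
            # it to its deterministic equivalent, and "samples" (argmax) actions.
            return {SampleBatch.ACTION_DIST_INPUTS: logits}

        def _forward_exploration(self, batch):
            logits = self.policy_net(batch[SampleBatch.OBS])
            # No "actions" key returned: RLlib samples stochastically and
            # computes action probs and logp values automatically.
            return {SampleBatch.ACTION_DIST_INPUTS: logits}

        def _forward_train(self, batch):
            # The same distribution parameters are used for loss computation.
            logits = self.policy_net(batch[SampleBatch.OBS])
            return {SampleBatch.ACTION_DIST_INPUTS: logits}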

.. tab-set::

56 changes: 41 additions & 15 deletions rllib/policy/eager_tf_policy_v2.py
@@ -854,22 +854,36 @@ def _compute_actions_helper_rl_module_explore(
input_dict[STATE_IN] = None
input_dict[SampleBatch.SEQ_LENS] = None

action_dist_class = self.model.get_exploration_action_dist_cls()
fwd_out = self.model.forward_exploration(input_dict)
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
actions = action_dist.sample()

# ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
# Create a distribution object.
action_dist = None
if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
action_dist_class = self.model.get_exploration_action_dist_cls()
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)

# If `forward_exploration()` returned actions, use them here as-is.
if SampleBatch.ACTIONS in fwd_out:
actions = fwd_out[SampleBatch.ACTIONS]
# Otherwise, sample actions from the distribution.
else:
assert action_dist
actions = action_dist.sample()

# Anything but actions and state_out is an extra fetch
for k, v in fwd_out.items():
if k not in [SampleBatch.ACTIONS, "state_out"]:
extra_fetches[k] = v

# Action-logp and action-prob.
logp = action_dist.logp(actions)
extra_fetches[SampleBatch.ACTION_LOGP] = logp
extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
# Compute action-logp and action-prob from distribution and add to
# `extra_fetches`, if possible.
if action_dist is not None:
logp = action_dist.logp(actions)
extra_fetches[SampleBatch.ACTION_LOGP] = logp
extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)

return actions, {}, extra_fetches

@@ -895,13 +909,25 @@ def _compute_actions_helper_rl_module_inference(
input_dict[STATE_IN] = None
input_dict[SampleBatch.SEQ_LENS] = None

action_dist_class = self.model.get_inference_action_dist_cls()
fwd_out = self.model.forward_inference(input_dict)
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
action_dist = action_dist.to_deterministic()
actions = action_dist.sample()

# ACTION_DIST_INPUTS field returned by `forward_inference()` ->
# Create a (deterministic) distribution object.
action_dist = None
if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
action_dist_class = self.model.get_inference_action_dist_cls()
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
action_dist = action_dist.to_deterministic()

# If `forward_inference()` returned actions, use them here as-is.
if SampleBatch.ACTIONS in fwd_out:
actions = fwd_out[SampleBatch.ACTIONS]
# Otherwise, sample actions from the distribution.
else:
assert action_dist
actions = action_dist.sample()

# Anything but actions and state_out is an extra fetch
for k, v in fwd_out.items():
57 changes: 42 additions & 15 deletions rllib/policy/torch_policy_v2.py
@@ -1127,30 +1127,57 @@ def _compute_action_helper(
if self.model:
self.model.eval()

extra_fetches = None
extra_fetches = dist_inputs = logp = None

# New API stack: `self.model` is-a RLModule.
if isinstance(self.model, RLModule):
if explore:
action_dist_class = self.model.get_exploration_action_dist_cls()
fwd_out = self.model.forward_exploration(input_dict)
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
actions = action_dist.sample()
logp = action_dist.logp(actions)

# ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
# Create a distribution object.
action_dist = None
if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
action_dist_class = self.model.get_exploration_action_dist_cls()
action_dist = action_dist_class.from_logits(dist_inputs)

# If `forward_exploration()` returned actions, use them here as-is.
if SampleBatch.ACTIONS in fwd_out:
actions = fwd_out[SampleBatch.ACTIONS]
# Otherwise, sample actions from the distribution.
else:
assert action_dist
actions = action_dist.sample()

# Compute action-logp from the distribution, if possible.
if action_dist is not None:
logp = action_dist.logp(actions)
else:
action_dist_class = self.model.get_inference_action_dist_cls()
fwd_out = self.model.forward_inference(input_dict)
action_dist = action_dist_class.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
action_dist = action_dist.to_deterministic()
actions = action_dist.sample()
logp = None

# ACTION_DIST_INPUTS field returned by `forward_inference()` ->
# Create a (deterministic) distribution object.
action_dist = None
if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
action_dist_class = self.model.get_inference_action_dist_cls()
action_dist = action_dist_class.from_logits(dist_inputs)
action_dist = action_dist.to_deterministic()

# If `forward_inference()` returned actions, use them here as-is.
if SampleBatch.ACTIONS in fwd_out:
actions = fwd_out[SampleBatch.ACTIONS]
# Otherwise, sample actions from the distribution.
else:
assert action_dist
actions = action_dist.sample()

# Anything but actions and state_out is an extra fetch.
state_out = fwd_out.pop(STATE_OUT, {})
extra_fetches = fwd_out
dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]

elif is_overridden(self.action_sampler_fn):
action_dist = None
actions, logp, dist_inputs, state_out = self.action_sampler_fn(