diff --git a/doc/source/rllib/rllib-rlmodule.rst b/doc/source/rllib/rllib-rlmodule.rst
index cf7be524ea4a..91473cc98452 100644
--- a/doc/source/rllib/rllib-rlmodule.rst
+++ b/doc/source/rllib/rllib-rlmodule.rst
@@ -16,7 +16,7 @@ RL Modules (Alpha)
 
 .. note::
 
-    This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` you should not experience siginficant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.
+    This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` you should not experience significant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.
 
     The table below shows the list of migrated algorithms and their current supported features, which will be updated as we progress.
 
@@ -33,19 +33,19 @@ RL Modules (Alpha)
    * - **PPO**
      - |pytorch| |tensorflow|
      - |pytorch| |tensorflow|
-     - |pytorch|
+     - |pytorch| |tensorflow|
      -
      - |pytorch|
    * - **Impala**
      - |pytorch| |tensorflow|
      - |pytorch| |tensorflow|
-     - |pytorch|
+     - |pytorch| |tensorflow|
      -
      - |pytorch|
    * - **APPO**
-     - |tensorflow|
-     - |tensorflow|
-     -
+     - |pytorch| |tensorflow|
+     - |pytorch| |tensorflow|
+     - |pytorch| |tensorflow|
      -
      -
@@ -140,7 +140,93 @@ The minimum requirement is for sub-classes of :py:class:`~ray.rllib.core.rl_modu
 
 - :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule._forward_exploration`: Forward pass for exploration.
 
-Also the class's constrcutor requires a dataclass config object called `~ray.rllib.core.rl_module.rl_module.RLModuleConfig` which contains the following fields:
+For your custom :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`
+methods, you must return a dictionary that contains the key "actions", the key "action_dist_inputs", or both.
+
+If you return the "actions" key:
+
+- RLlib will use the actions provided under that key as-is.
+- If you also return the "action_dist_inputs" key, RLlib will create a :py:class:`~ray.rllib.models.distributions.Distribution` object from the distribution parameters under that key and, in the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`, automatically compute action probs and logp values for the given actions.
+
+If you do not return the "actions" key:
+
+- You must instead return the "action_dist_inputs" key from your :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference` methods.
+- RLlib will create a :py:class:`~ray.rllib.models.distributions.Distribution` object from the distribution parameters under that key and sample actions from the resulting distribution.
+- In the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration`, RLlib will also automatically compute action probs and logp values for the sampled actions.
+
+.. note::
+
+    In the case of :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`,
+    the generated distributions (from the returned "action_dist_inputs" key) will always be made deterministic first via
+    the :py:meth:`~ray.rllib.models.distributions.Distribution.to_deterministic` utility before a possible action sample step.
+    Thus, for example, sampling from a Categorical distribution will be reduced to simply selecting the argmax actions from the distribution's logits/probs.
+
+Commonly used distribution implementations can be found under ``ray.rllib.models.tf.tf_distributions`` for tensorflow and
+``ray.rllib.models.torch.torch_distributions`` for torch. You can choose to return deterministic actions by creating a deterministic distribution instance.
+
+
+.. tab-set::
+
+    .. tab-item:: Returning "actions" key
+
+        .. code-block:: python
+
+            """
+            An RLModule whose forward_exploration/inference methods return the
+            "actions" key.
+            """
+
+            class MyRLModule(TorchRLModule):
+                ...
+
+                def _forward_inference(self, batch):
+                    ...
+                    return {
+                        "actions": ...  # actions will be used as-is
+                    }
+
+                def _forward_exploration(self, batch):
+                    ...
+                    return {
+                        "actions": ...,  # actions will be used as-is (no sampling step!)
+                        "action_dist_inputs": ...  # optional: If provided, will be used to compute action probs and logp.
+                    }
+
+    .. tab-item:: Not returning "actions" key
+
+        .. code-block:: python
+
+            """
+            An RLModule whose forward_exploration/inference methods do NOT return the
+            "actions" key.
+            """
+
+            class MyRLModule(TorchRLModule):
+                ...
+
+                def _forward_inference(self, batch):
+                    ...
+                    return {
+                        # RLlib will:
+                        # - Generate distribution from these parameters.
+                        # - Convert distribution to a deterministic equivalent.
+                        # - "sample" from the deterministic distribution.
+                        "action_dist_inputs": ...
+                    }
+
+                def _forward_exploration(self, batch):
+                    ...
+                    return {
+                        # RLlib will:
+                        # - Generate distribution from these parameters.
+                        # - Sample from the (stochastic) distribution.
+                        # - Compute action probs/logps automatically using the sampled
+                        #   actions and the generated distribution object.
+                        "action_dist_inputs": ...
+                    }
+
+
+Also, the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` class's constructor requires a dataclass config object called :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig`, which contains the following fields:
 
 - :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.observation_space`: The observation space of the environment (either processed or raw).
 - :py:attr:`~ray.rllib.core.rl_module.rl_module.RLModuleConfig.action_space`: The action space of the environment.
@@ -426,7 +512,11 @@ What your customization could have looked like before:
 
         return None, None, None
 
-All of the ``Policy.compute_***`` functions expect that `~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and `~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference` return a dictionary that contains the key "action_dist_inputs", whose value are the parameters (inputs) of a ``ray.rllib.models.distributions.Distribution`` class. Commonly used distribution implementations can be found under ``ray.rllib.models.tf.tf_distributions`` for tensorflow and ``ray.rllib.models.torch.torch_distributions`` for torch. 
You can choose to return determinstic actions, by creating a determinstic distribution instance. See `Writing Custom Single Agent RL Modules`_ for more details on how to implement your own custom RL Module.
+All of the ``Policy.compute_***`` functions expect that
+:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_exploration` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.forward_inference`
+return a dictionary that contains the key "actions", the key "action_dist_inputs", or both.
+
+See `Writing Custom Single Agent RL Modules`_ for more details on how to implement your own custom RL Modules.
 
 .. tab-set::
 
diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py
index d777623c19b0..8a4093fb0e2d 100644
--- a/rllib/policy/eager_tf_policy_v2.py
+++ b/rllib/policy/eager_tf_policy_v2.py
@@ -854,22 +854,36 @@ def _compute_actions_helper_rl_module_explore(
             input_dict[STATE_IN] = None
             input_dict[SampleBatch.SEQ_LENS] = None
 
-        action_dist_class = self.model.get_exploration_action_dist_cls()
         fwd_out = self.model.forward_exploration(input_dict)
-        action_dist = action_dist_class.from_logits(
-            fwd_out[SampleBatch.ACTION_DIST_INPUTS]
-        )
-        actions = action_dist.sample()
+
+        # ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
+        # Create a distribution object.
+        action_dist = None
+        if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
+            action_dist_class = self.model.get_exploration_action_dist_cls()
+            action_dist = action_dist_class.from_logits(
+                fwd_out[SampleBatch.ACTION_DIST_INPUTS]
+            )
+
+        # If `forward_exploration()` returned actions, use them here as-is.
+        if SampleBatch.ACTIONS in fwd_out:
+            actions = fwd_out[SampleBatch.ACTIONS]
+        # Otherwise, sample actions from the distribution.
+        else:
+            assert action_dist
+            actions = action_dist.sample()
 
         # Anything but action_dist and state_out is an extra fetch
         for k, v in fwd_out.items():
             if k not in [SampleBatch.ACTIONS, "state_out"]:
                 extra_fetches[k] = v
 
-        # Action-logp and action-prob.
-        logp = action_dist.logp(actions)
-        extra_fetches[SampleBatch.ACTION_LOGP] = logp
-        extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
+        # Compute action-logp and action-prob from the distribution and add them
+        # to `extra_fetches`, if possible.
+        if action_dist is not None:
+            logp = action_dist.logp(actions)
+            extra_fetches[SampleBatch.ACTION_LOGP] = logp
+            extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
 
         return actions, {}, extra_fetches
 
@@ -895,13 +909,25 @@ def _compute_actions_helper_rl_module_inference(
             input_dict[STATE_IN] = None
             input_dict[SampleBatch.SEQ_LENS] = None
 
-        action_dist_class = self.model.get_inference_action_dist_cls()
         fwd_out = self.model.forward_inference(input_dict)
-        action_dist = action_dist_class.from_logits(
-            fwd_out[SampleBatch.ACTION_DIST_INPUTS]
-        )
-        action_dist = action_dist.to_deterministic()
-        actions = action_dist.sample()
+
+        # ACTION_DIST_INPUTS field returned by `forward_inference()` ->
+        # Create a (deterministic) distribution object.
+        action_dist = None
+        if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
+            action_dist_class = self.model.get_inference_action_dist_cls()
+            action_dist = action_dist_class.from_logits(
+                fwd_out[SampleBatch.ACTION_DIST_INPUTS]
+            )
+            action_dist = action_dist.to_deterministic()
+
+        # If `forward_inference()` returned actions, use them here as-is.
+        if SampleBatch.ACTIONS in fwd_out:
+            actions = fwd_out[SampleBatch.ACTIONS]
+        # Otherwise, sample actions from the distribution.
+        else:
+            assert action_dist
+            actions = action_dist.sample()
 
         # Anything but action_dist and state_out is an extra fetch
         for k, v in fwd_out.items():
diff --git a/rllib/policy/torch_policy_v2.py b/rllib/policy/torch_policy_v2.py
index ad348b3b0aca..4165da80a1f8 100644
--- a/rllib/policy/torch_policy_v2.py
+++ b/rllib/policy/torch_policy_v2.py
@@ -1127,30 +1127,57 @@ def _compute_action_helper(
         if self.model:
             self.model.eval()
 
-        extra_fetches = None
+        extra_fetches = dist_inputs = logp = None
+
+        # New API stack: `self.model` is-a RLModule.
         if isinstance(self.model, RLModule):
             if explore:
-                action_dist_class = self.model.get_exploration_action_dist_cls()
                 fwd_out = self.model.forward_exploration(input_dict)
-                action_dist = action_dist_class.from_logits(
-                    fwd_out[SampleBatch.ACTION_DIST_INPUTS]
-                )
-                actions = action_dist.sample()
-                logp = action_dist.logp(actions)
+
+                # ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
+                # Create a distribution object.
+                action_dist = None
+                if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
+                    dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
+                    action_dist_class = self.model.get_exploration_action_dist_cls()
+                    action_dist = action_dist_class.from_logits(dist_inputs)
+
+                # If `forward_exploration()` returned actions, use them here as-is.
+                if SampleBatch.ACTIONS in fwd_out:
+                    actions = fwd_out[SampleBatch.ACTIONS]
+                # Otherwise, sample actions from the distribution.
+                else:
+                    assert action_dist
+                    actions = action_dist.sample()
+
+                # Compute action-logp from the distribution, if possible.
+                if action_dist is not None:
+                    logp = action_dist.logp(actions)
             else:
-                action_dist_class = self.model.get_inference_action_dist_cls()
                 fwd_out = self.model.forward_inference(input_dict)
-                action_dist = action_dist_class.from_logits(
-                    fwd_out[SampleBatch.ACTION_DIST_INPUTS]
-                )
-                action_dist = action_dist.to_deterministic()
-                actions = action_dist.sample()
-                logp = None
+
+                # ACTION_DIST_INPUTS field returned by `forward_inference()` ->
+                # Create a (deterministic) distribution object.
+                action_dist = None
+                if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
+                    dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
+                    action_dist_class = self.model.get_inference_action_dist_cls()
+                    action_dist = action_dist_class.from_logits(dist_inputs)
+                    action_dist = action_dist.to_deterministic()
+
+                # If `forward_inference()` returned actions, use them here as-is.
+                if SampleBatch.ACTIONS in fwd_out:
+                    actions = fwd_out[SampleBatch.ACTIONS]
+                # Otherwise, sample actions from the distribution.
+                else:
+                    assert action_dist
+                    actions = action_dist.sample()
 
             # Anything but actions and state_out is an extra fetch.
             state_out = fwd_out.pop(STATE_OUT, {})
             extra_fetches = fwd_out
-            dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
+
        elif is_overridden(self.action_sampler_fn):
             action_dist = None
             actions, logp, dist_inputs, state_out = self.action_sampler_fn(
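For reference, the following standalone sketch illustrates the two return styles documented above with concrete tensors. It is not part of this diff and not RLlib API: the class name ``TinyPolicyNet`` is made up for illustration, and a plain ``torch.nn.Module`` stands in for a real ``TorchRLModule`` subclass so the snippet runs on its own.

.. code-block:: python

    import torch
    import torch.nn as nn


    class TinyPolicyNet(nn.Module):
        """Illustrative only: mimics the dict contract of an RLModule's forward methods."""

        def __init__(self, obs_dim: int = 8, num_actions: int = 3):
            super().__init__()
            self.logits_net = nn.Linear(obs_dim, num_actions)

        def forward_exploration(self, batch: dict) -> dict:
            logits = self.logits_net(batch["obs"])
            # Only distribution inputs -> the caller samples actions and
            # computes probs/logps from the resulting distribution.
            return {"action_dist_inputs": logits}

        def forward_inference(self, batch: dict) -> dict:
            logits = self.logits_net(batch["obs"])
            # Ready-made actions -> the caller uses them as-is (no sampling step).
            return {"actions": logits.argmax(dim=-1)}


    net = TinyPolicyNet()
    out = net.forward_inference({"obs": torch.randn(5, 8)})
    assert out["actions"].shape == (5,)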
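The dispatch logic that the two policy helpers implement can also be summarized in a few lines. This is a simplified sketch under stated assumptions, not the actual RLlib code: ``torch.distributions.Categorical`` stands in for RLlib's ``Distribution`` classes, and the helper name ``select_actions`` is hypothetical.

.. code-block:: python

    import torch
    from torch.distributions import Categorical


    def select_actions(fwd_out: dict, explore: bool):
        # Build a distribution object only if the module returned dist inputs.
        dist = None
        if "action_dist_inputs" in fwd_out:
            dist = Categorical(logits=fwd_out["action_dist_inputs"])

        # Use the module's actions as-is if provided ...
        if "actions" in fwd_out:
            actions = fwd_out["actions"]
        # ... otherwise derive them from the distribution.
        else:
            assert dist is not None, "need 'actions' and/or 'action_dist_inputs'"
            if explore:
                actions = dist.sample()  # stochastic sample
            else:
                # Deterministic equivalent of the categorical: argmax over logits.
                actions = dist.logits.argmax(dim=-1)

        # logp can only be computed when a distribution is available.
        logp = dist.log_prob(actions) if (explore and dist is not None) else None
        return actions, logp


    # Module returned only distribution inputs -> sample + logp.
    actions, logp = select_actions(
        {"action_dist_inputs": torch.randn(4, 3)}, explore=True
    )

    # Module returned ready-made actions -> used as-is; no logp without a distribution.
    actions, logp = select_actions({"actions": torch.tensor([0, 2, 1])}, explore=False)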