diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 43d5bfe0d00..efb1a755d0e 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
     DiscreteTensorSpec
     MultiDiscreteTensorSpec
     MultiOneHotDiscreteTensorSpec
+    NonTensorSpec
     OneHotDiscreteTensorSpec
     UnboundedContinuousTensorSpec
     UnboundedDiscreteTensorSpec
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index fe72ea89a56..5c39c5a1349 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -759,6 +759,87 @@ to always know what the latest available actions are. You can do this like so:
 Recorders
 ---------
+.. _Environment-Recorders:
+
+Recording data during environment rollout execution is crucial to keep an eye on the algorithm performance as well as
+to report results after training.
+
+TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable
+can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function will be called on the collected
+tensordict at each iteration of the rollout (if some iterations have to be skipped, ``callback`` should keep an
+internal count of its calls and act only on the desired iterations), as in the sketch below.
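+
+As an example, the following callback prints the average reward every 10 steps (a minimal sketch: the
+logging period is illustrative, and the rollout tensordict is assumed here to be the callback's last
+positional argument):
+
+    >>> calls = []
+    >>> def callback(*args):
+    ...     # keep an internal call count and skip all but every 10th call
+    ...     calls.append(1)
+    ...     if len(calls) % 10 == 0:
+    ...         print(args[-1].get(("next", "reward")).mean())
+    >>> env.rollout(100, callback=callback)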
+
+To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used.
+
+Recording videos
+~~~~~~~~~~~~~~~~
+
+Several backends offer the possibility of recording rendered images from the environment.
+If the pixels are already part of the environment output (e.g. Atari or other game simulators), a
+:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input
+a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger`
+or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved.
+For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a `video_format="mp4"`
+argument.
+
+The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch
+formatted images (HWC or CHW, respectively).
+
+    >>> logger = CSVLogger("dummy-exp", video_format="mp4")
+    >>> env = GymEnv("ALE/Pong-v5")
+    >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"]))
+    >>> env.rollout(10)
+    >>> env.transform.dump()  # Save the video and clear cache
+
+Note that the cache of the transform will keep on growing until ``dump`` is called. It is the user's responsibility to
+call ``dump`` when needed to avoid OOM issues.
+
+In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
+(some libraries only allow one environment instance per workspace).
+In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
+can be used to call `render` on the parent environment and save the images in the rollout data stream.
+
+This class works over single and batched environments alike:
+
+    >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+    >>> from torchrl.record.loggers import CSVLogger
+    >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+    >>>
+    >>> def make_env():
+    >>>     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+    >>>     # Uncomment this line to execute per-env
+    >>>     # env = env.append_transform(PixelRenderTransform())
+    >>>     return env
+    >>>
+    >>> if __name__ == "__main__":
+    ...     logger = CSVLogger("dummy", video_format="mp4")
+    ...
+    ...     env = ParallelEnv(16, EnvCreator(make_env))
+    ...     env.start()
+    ...     # Comment this line to execute per-env
+    ...     env = env.append_transform(PixelRenderTransform())
+    ...
+    ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+    ...     env.rollout(3)
+    ...
+    ...     check_env_specs(env)
+    ...
+    ...     r = env.rollout(30)
+    ...     env.transform.dump()
+    ...     env.close()
+
+
 .. currentmodule:: torchrl.record

 Recorders are transforms that register data as they come in, for logging purposes.
@@ -769,6 +850,7 @@ Recorders are transforms that register data as they come in, for logging purpose

     TensorDictRecorder
     VideoRecorder
+    PixelRenderTransform


 Helpers
diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index 04d4386c631..821902b2ee2 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes,
 on-policy, off-policy, model-based and model-free solutions, offline RL and others.
 More particular cases, such as meta-RL algorithms may have training schemes that differ substentially.

-The :obj:`trainer.train()` method can be sketched as follows:
+The ``trainer.train()`` method can be sketched as follows:

 .. code-block::
   :caption: Trainer loops

 ...         self._post_steps_hook()  # "post_steps"
 ...         self._post_steps_log_hook(batch)  # "post_steps_log"

-There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
-:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
-:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
-Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
-**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook
-(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).
-
-- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
-  a :obj:`TensorDict` object as input and update it given some strategy.
-  Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
-  constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.
-
-- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
-  some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
-  logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the
-  data to log. The key :obj:`"log_pbar"` is reserved to boolean values indicating if the logged value
+There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
+``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
+``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
+Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
+**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hooks
+(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).
+
+- **Data processing** hooks update a tensordict of data. A hook's ``__call__`` method should accept
+  a ``TensorDict`` object as input and update it given some strategy.
+  Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including normalization
+  constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.
+
+- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
+  some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
+  logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
+  data to log. The key ``"log_pbar"`` is reserved for boolean values indicating if the logged value
 should be displayed on the progression bar printed on the training log.

 - **Operation** hooks are hooks that execute specific operations over the models, data collectors,
-  target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights`
-  or update the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples
-  of operation hooks. They are data-independent (they do not require a :obj:`TensorDict`
+  target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights``
+  or updating the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
+  of operation hooks. They are data-independent (they do not require a ``TensorDict``
   input), they are just supposed to be executed once at every iteration (or every N iterations).

-The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`,
-and all implement three base methods: a :obj:`state_dict` and :obj:`load_state_dict` method for
-checkpointing and a :obj:`register` method that registers the hook at the default value in the
+The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
+and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for
+checkpointing and a ``register`` method that registers the hook at the default value in the
 trainer. This method takes a trainer and a module name as input. For instance, the following logging
-hook is executed every 10 calls to :obj:`"post_optim_log"`:
+hook is executed every 10 calls to ``"post_optim_log"``:

 .. code-block::

@@ -122,22 +122,22 @@ Checkpointing
 -------------

 The trainer class and hooks support checkpointing, which can be achieved either
 using the `torchsnapshot `_ backend or
-the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`:
+the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:

 .. code-block::

   $ CKPT_BACKEND=torch python script.py

-which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch
+which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
 is that it is a more flexible API, which supports distributed checkpointing and
 also allows users to load tensors from a file stored on disk to a tensor with a
 physical storage (which pytorch currently does not support). This allows, for
 instance, to load tensors from and to a replay buffer that would otherwise not
 fit in memory.

 When building a trainer, one can provide a file path where the checkpoints are to
-be written. With the :obj:`torchsnapshot` backend, a directory path is expected,
-whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).
+be written. With the ``torchsnapshot`` backend, a directory path is expected,
+whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).

 .. code-block::

@@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` fi
     >>> # to load from a path
     >>> trainer.load_from_file(filepath)

-The :obj:`Trainer.train()` method can be used to execute the above loop with all of
+The ``Trainer.train()`` method can be used to execute the above loop with all of
 its hooks, although using the :obj:`Trainer` class for its checkpointing capability
 only is also a perfectly valid use.

@@ -238,6 +238,8 @@ Loggers
 Recording utils
 ---------------

+Recording utils are detailed :ref:`here `.
+
 .. currentmodule:: torchrl.record

 .. autosummary::
@@ -246,3 +248,4 @@ Recording utils

     VideoRecorder
     TensorDictRecorder
+    PixelRenderTransform
diff --git a/test/test_loggers.py b/test/test_loggers.py
index 98a330d0daf..f51b9d290ab 100644
--- a/test/test_loggers.py
+++ b/test/test_loggers.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
+import importlib.util
 import os
 import os.path
 import pathlib
@@ -12,12 +13,14 @@
 import pytest
 import torch
-
 from tensordict import MemoryMappedTensor
+
+from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
 from torchrl.record.loggers.csv import CSVLogger
 from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
 from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
 from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
+from torchrl.record.recorder import PixelRenderTransform, VideoRecorder

 if _has_tv:
     import torchvision
@@ -28,6 +31,11 @@
 if _has_mlflow:
     import mlflow

+_has_gym = (
+    importlib.util.find_spec("gym", None) is not None
+    or importlib.util.find_spec("gymnasium", None) is not None
+)
+

 @pytest.fixture
 def tb_logger(tmp_path_factory):
@@ -397,6 +405,36 @@ def test_log_hparams(self, mlflow_fixture, config):
         logger.log_hparams(config)


+@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering")
+class TestPixelRenderTransform:
+    @pytest.mark.parametrize("parallel", [False, True])
+    @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")])
+    def test_pixel_render(self, parallel, in_key, tmpdir):
+        def make_env():
+            env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
+            env = env.append_transform(PixelRenderTransform(out_keys=in_key))
+            return env
+
+        if parallel:
+            env = ParallelEnv(2, make_env, mp_start_method="spawn")
+        else:
+            env = make_env()
+        logger = CSVLogger("dummy", log_dir=tmpdir)
+        try:
+            env = env.append_transform(
+                VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")
+            )
+            check_env_specs(env)
+            env.rollout(10)
+            env.transform.dump()
+            assert os.path.isfile(
+                os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")
+            )
+        finally:
+            if not env.is_closed:
+                env.close()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_specs.py b/test/test_specs.py
index 36f5aef65ca..058144c1a94 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -23,6 +23,7 @@
     LazyStackedCompositeSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
+    NonTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
@@ -1462,6 +1463,14 @@ def test_multionehot(self, shape1, shape2):
         assert spec2.rand().shape == spec2.shape
         assert spec2.zero().shape == spec2.shape

+    def test_non_tensor(self):
+        spec = NonTensorSpec((3, 4), device="cpu")
+        assert (
+            spec.expand(2, 3, 4)
+            == spec.expand((2, 3, 4))
+            == NonTensorSpec((2, 3, 4), device="cpu")
+        )
+
     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     @pytest.mark.parametrize("shape2", [(), (10,)])
     def test_onehot(self, shape1, shape2):
@@ -1675,6 +1684,11 @@ def test_multionehot(
         assert spec == spec.clone()
         assert spec is not spec.clone()

+    def test_non_tensor(self):
+        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        assert spec.clone() == spec
+        assert spec.clone() is not spec
+
     @pytest.mark.parametrize("shape1", [None, (), (5,)])
     def test_onehot(
         self,
@@ -1840,6 +1854,11 @@ def test_multionehot(
         with pytest.raises(ValueError):
             spec.unbind(-1)

+    def test_non_tensor(self):
+        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        assert spec.unbind(1)[0] == spec[:, 0]
+        assert spec.unbind(1)[0] is not spec[:, 0]
+
     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
     def test_onehot(
         self,
@@ -2114,6 +2133,15 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
         r = c.zero()
         assert r.shape == c.shape

+    def test_stack_non_tensor(self, shape, stack_dim):
+        spec0 = NonTensorSpec(shape=shape, device="cpu")
+        spec1 = NonTensorSpec(shape=shape, device="cpu")
+        new_spec = torch.stack([spec0, spec1], stack_dim)
+        shape_insert = list(shape)
+        shape_insert.insert(stack_dim, 2)
+        assert new_spec.shape == torch.Size(shape_insert)
+        assert new_spec.device == torch.device("cpu")
+
     def test_stack_onehot(self, shape, stack_dim):
         n = 5
         shape = (*shape, 5)
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index cb84ce32586..bc512a585b7 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -51,6 +51,7 @@
     LazyStackedTensorSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
+    NonTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 71598938eab..c9d0683ad9c 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -31,7 +31,13 @@
 import numpy as np
 import torch

-from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key
+from tensordict import (
+    LazyStackedTensorDict,
+    NonTensorData,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
 from tensordict.utils import _getitem_batch_size, NestedKey

 from torchrl._utils import get_binary_env_var

@@ -715,8 +721,9 @@ def _flatten(self, start_dim, end_dim):
         shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape
         return self._reshape(shape)

+    @abc.abstractmethod
     def _project(self, val: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
+        raise NotImplementedError(type(self))

     @abc.abstractmethod
     def is_in(self, val: torch.Tensor) -> bool:
@@ -1917,6 +1924,114 @@ def _is_nested_list(index, notuple=False):
     return False


+class NonTensorSpec(TensorSpec):
+    """A spec for non-tensor data.
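+
+    A minimal usage sketch (mirroring the tests added in this diff):
+
+        >>> spec = NonTensorSpec((3, 4), device="cpu")
+        >>> spec.expand(2, 3, 4) == NonTensorSpec((2, 3, 4), device="cpu")
+        True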
+    """
+
+    def __init__(
+        self,
+        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: torch.dtype | None = None,
+        **kwargs,
+    ):
+        if isinstance(shape, int):
+            shape = torch.Size([shape])
+
+        _, device = _default_dtype_and_device(None, device)
+        domain = None
+        super().__init__(
+            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
+        )
+
+    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
+        if isinstance(dest, torch.dtype):
+            dest_dtype = dest
+            dest_device = self.device
+        elif dest is None:
+            return self
+        else:
+            dest_dtype = self.dtype
+            dest_device = torch.device(dest)
+        if dest_device == self.device and dest_dtype == self.dtype:
+            return self
+        return self.__class__(shape=self.shape, device=dest_device, dtype=None)
+
+    def clone(self) -> NonTensorSpec:
+        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
+
+    def rand(self, shape):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def zero(self, shape):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def one(self, shape):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def is_in(self, val: torch.Tensor) -> bool:
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return (
+            isinstance(val, NonTensorData)
+            and val.shape == shape
+            and val.device == self.device
+            and val.dtype == self.dtype
+        )
+
+    def expand(self, *shape):
+        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
+            shape = shape[0]
+        shape = torch.Size(shape)
+        if not all(
+            (old == 1) or (old == new)
+            for old, new in zip(self.shape, shape[-len(self.shape) :])
+        ):
+            raise ValueError(
+                f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
+            )
+        return self.__class__(shape=shape, device=self.device, dtype=None)
+
+    def _reshape(self, shape):
+        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
+
+    def _unflatten(self, dim, sizes):
+        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
+        return self.__class__(
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
+        """Indexes the current TensorSpec based on the provided index."""
+        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
+
+    def unbind(self, dim: int):
+        orig_dim = dim
+        if dim < 0:
+            dim = len(self.shape) + dim
+        if dim < 0:
+            raise ValueError(
+                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
+            )
+        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
+        return tuple(
+            self.__class__(
+                shape=shape,
+                device=self.device,
+                dtype=self.dtype,
+            )
+            for i in range(self.shape[dim])
+        )
+
+
 @dataclass(repr=False)
 class UnboundedContinuousTensorSpec(TensorSpec):
     """An unbounded continuous tensor spec.
@@ -1954,7 +2069,9 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )

-    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+    def to(
+        self, dest: Union[torch.dtype, DEVICE_TYPING]
+    ) -> UnboundedContinuousTensorSpec:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
             dest_device = self.device
@@ -1979,7 +2096,11 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype
+
+    def _project(self, val: torch.Tensor) -> torch.Tensor:
+        return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2130,7 +2251,8 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype

     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 660aecb3fd8..b3026da35ca 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1624,6 +1624,7 @@ def __getattr__(self, attr: str) -> Any:
         try:
             # _ = getattr(self._dummy_env, attr)
             if self.is_closed:
+                self.start()
                 raise RuntimeError(
                     "Trying to access attributes of closed/non started "
                     "environments. Check that the batched environment "
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index f5d4625fd07..8712c74340a 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -2298,7 +2298,7 @@ def rollout(
         self,
         max_steps: int,
         policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool = True,
@@ -2320,7 +2320,10 @@ def rollout(
             The policy can be any callable that reads either a tensordict or
             the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
             Defaults to `None`.
-        callback (callable, optional): function to be called at each iteration with the given TensorDict.
+        callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
+            TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected; it is the user's
+            responsibility to save any result within the callback call if data needs to be carried over beyond
+            the call to ``rollout``.
         auto_reset (bool, optional): if ``True``, resets automatically the environment
             if it is in a done state when the rollout is initiated.
             Default is ``True``.
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 49cf58f8103..cd51b4fd23b 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -16,6 +16,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Union

+import tensordict
 import torch

 from tensordict import (
@@ -25,6 +26,7 @@
     TensorDictBase,
     unravel_key,
 )
+from tensordict.base import _is_leaf_nontensor
 from tensordict.nn import TensorDictModule, TensorDictModuleBase
 from tensordict.nn.probabilistic import (  # noqa
     # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
@@ -183,10 +185,14 @@ def _is_reset(key: NestedKey):
                 return key == "_reset"
             return key[-1] == "_reset"

-        actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)}
+        actual = {
+            key
+            for key in tensordict.keys(True, True, is_leaf=_is_leaf_nontensor)
+            if not _is_reset(key)
+        }
         expected = set(expected)
         self.validated = expected.intersection(actual) == expected
         if not self.validated:
             warnings.warn(
                 "The expected key set and actual key set differ. "
                 "This will work but with a slower throughput than "
@@ -262,7 +268,7 @@ def _exclude(
                 cls._exclude(nested_key_dict, td, td_out)
             return out
         has_set = False
-        for key, value in data_in.items():
+        for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor):
             subdict = nested_key_dict.get(key, NO_DEFAULT)
             if subdict is NO_DEFAULT:
                 value = value.copy() if is_tensor_collection(value) else value
diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py
index 726d29ea051..f6c9bcdefbb 100644
--- a/torchrl/record/__init__.py
+++ b/torchrl/record/__init__.py
@@ -4,4 +4,4 @@
 # LICENSE file in the root directory of this source tree.

 from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger
-from .recorder import TensorDictRecorder, VideoRecorder
+from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index a486b689feb..079c8b71e12 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -6,14 +6,20 @@
 import importlib.util
 from copy import copy
-from typing import Optional, Sequence
+from typing import Callable, List, Optional, Sequence, Union

+import numpy as np
 import torch
-from tensordict import TensorDictBase
+from tensordict import NonTensorData, TensorDict, TensorDictBase
 from tensordict.utils import NestedKey

+from torchrl._utils import _can_be_pickled
+from torchrl.data import TensorSpec
+from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.utils import CloudpickleWrapper
+from torchrl.envs import EnvBase
 from torchrl.envs.transforms import ObservationTransform, Transform
 from torchrl.record.loggers import Logger

@@ -155,20 +161,22 @@ def skip(self, value):
         self._skip = value

     def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
+        if isinstance(observation, NonTensorData):
+            observation_trsf = torch.tensor(observation.data)
+        else:
+            observation_trsf = observation
         self.count += 1
         if self.count % self.skip == 0:
             if (
-                observation.ndim >= 3
-                and observation.shape[-3] == 3
-                and observation.shape[-2] > 3
-                and observation.shape[-1] > 3
+                observation_trsf.ndim >= 3
+                and observation_trsf.shape[-3] == 3
+                and observation_trsf.shape[-2] > 3
+                and observation_trsf.shape[-1] > 3
             ):
                 # permute the channels to the last dim
-                observation_trsf = observation.permute(
-                    *range(observation.ndim - 3), -2, -1, -3
+                observation_trsf = observation_trsf.permute(
+                    *range(observation_trsf.ndim - 3), -2, -1, -3
                 )
-            else:
-                observation_trsf = observation
             if not (
                 observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
             ):
@@ -321,3 +329,225 @@ def _reset(
     ) -> TensorDictBase:
         self._call(tensordict_reset)
         return tensordict_reset
+
+
+class PixelRenderTransform(Transform):
+    """A transform to call render on the parent environment and register the pixel observation in the tensordict.
+
+    This transform offers an alternative to the ``from_pixels`` syntactic sugar when instantiating an environment
+    where rendering is expensive, or when ``from_pixels`` is not implemented.
+    It can be used within a single environment or over batched environments alike.
+
+    Args:
+        out_keys (List[NestedKey] or NestedKey): List of keys where to register the pixel observations.
+        preproc (Callable, optional): a preproc function. Can be used to reshape the observation, or apply
+            any other transformation that makes it possible to register it in the output data.
+        as_non_tensor (bool, optional): if ``True``, the data will be written as a :class:`~tensordict.NonTensorData`,
+            thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the
+            input data type and shape.
+        render_method (str, optional): the name of the render method. Defaults to ``"render"``.
+        **kwargs: additional keyword arguments to pass to the render function (e.g. ``mode="rgb_array"``).
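+
+    For instance, a ``preproc`` hook can cast or reshape the rendered array before it is registered
+    (a minimal sketch; the casting choice is illustrative and assumes ``render()`` returns a numpy array):
+
+        >>> import numpy as np
+        >>> t = PixelRenderTransform(
+        ...     # make the rendered array contiguous and uint8-typed before it is written
+        ...     preproc=lambda arr: np.ascontiguousarray(arr, dtype=np.uint8),
+        ... )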
+
+    Examples:
+        >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+        >>> from torchrl.record.loggers import CSVLogger
+        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+        >>>
+        >>> def make_env():
+        >>>     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+        >>>     env = env.append_transform(PixelRenderTransform())
+        >>>     return env
+        >>>
+        >>> if __name__ == "__main__":
+        ...     logger = CSVLogger("dummy", video_format="mp4")
+        ...
+        ...     env = ParallelEnv(4, EnvCreator(make_env))
+        ...
+        ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+        ...     env.rollout(3)
+        ...
+        ...     check_env_specs(env)
+        ...
+        ...     r = env.rollout(30)
+        ...     print(env)
+        ...     env.transform.dump()
+        ...     env.close()
+
+    This transform can also be used whenever a batched environment ``render()`` returns a single image:
+
+    Examples:
+        >>> from torchrl.envs import check_env_specs
+        >>> from torchrl.envs.libs.vmas import VmasEnv
+        >>> from torchrl.record.loggers import CSVLogger
+        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+        >>>
+        >>> env = VmasEnv(
+        ...     scenario="flocking",
+        ...     num_envs=32,
+        ...     continuous_actions=True,
+        ...     max_steps=200,
+        ...     device="cpu",
+        ...     seed=None,
+        ...     # Scenario kwargs
+        ...     n_agents=5,
+        ... )
+        >>>
+        >>> logger = CSVLogger("dummy", video_format="mp4")
+        >>>
+        >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
+        >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+        >>>
+        >>> check_env_specs(env)
+        >>>
+        >>> r = env.rollout(30)
+        >>> env.transform[-1].dump()
+
+    The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which will
+    turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behaviour).
+    Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control
+    this behaviour:
+
+        >>> def switch(module):
+        ...     if isinstance(module, PixelRenderTransform):
+        ...         module.switch()
+        >>> env.apply(switch)
+
+    """
+
+    def __init__(
+        self,
+        out_keys: List[NestedKey] = None,
+        preproc: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ] = None,
+        as_non_tensor: bool = None,
+        render_method: str = "render",
+        **kwargs,
+    ) -> None:
+        if out_keys is None:
+            out_keys = ["pixels"]
+        elif isinstance(out_keys, (str, tuple)):
+            out_keys = [out_keys]
+        if len(out_keys) != 1:
+            raise RuntimeError(
+                f"Expected one and only one out_key, got out_keys={out_keys}"
+            )
+        if preproc is not None and not _can_be_pickled(preproc):
+            preproc = CloudpickleWrapper(preproc)
+        self.preproc = preproc
+        self.as_non_tensor = as_non_tensor
+        self.kwargs = kwargs
+        self.render_method = render_method
+        self._enabled = True
+        super().__init__(in_keys=[], out_keys=out_keys)
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._call(tensordict_reset)
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if not self._enabled:
+            return tensordict
+
+        array = getattr(self.parent, self.render_method)(**self.kwargs)
+        if self.preproc:
+            array = self.preproc(array)
+        if self.as_non_tensor is None:
+            if isinstance(array, list):
+                if isinstance(array[0], np.ndarray):
+                    array = np.asarray(array)
+                else:
+                    array = torch.as_tensor(array)
+            if (
+                array.ndim == 3
+                and array.shape[-1] == 3
+                and self.parent.batch_size != ()
+            ):
+                self.as_non_tensor = True
+            else:
+                self.as_non_tensor = False
+        if not self.as_non_tensor:
+            try:
+                tensordict.set(self.out_keys[0], array)
+            except Exception:
+                raise RuntimeError(
+                    f"An exception was raised while writing the rendered array "
+                    f"(shape={getattr(array, 'shape', None)}, dtype={getattr(array, 'dtype', None)}) in the tensordict with shape {tensordict.shape}. "
+                    f"Consider adapting your preproc function in {type(self).__name__}. You can also "
+                    f"pass keyword arguments to the render function of the parent environment, or save "
+                    f"this observation as a non-tensor data with as_non_tensor=True."
+                )
+        else:
+            tensordict.set_non_tensor(self.out_keys[0], array)
+        return tensordict
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        # Adds the pixel observation spec by calling render on the parent env
+        switch = False
+        if not self.enabled:
+            switch = True
+            self.switch()
+        parent = self.parent
+        td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device)
+        self._call(td_in)
+        obs = td_in.get(self.out_keys[0])
+        if isinstance(obs, NonTensorData):
+            spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape)
+        else:
+            spec = UnboundedContinuousTensorSpec(
+                device=obs.device, dtype=obs.dtype, shape=obs.shape
+            )
+        observation_spec[self.out_keys[0]] = spec
+        if switch:
+            self.switch()
+        return observation_spec
+
+    def switch(self, mode: str | bool = None):
+        """Sets the transform on or off.
+
+        Args:
+            mode (str or bool, optional): if provided, sets the switch to the desired mode.
+                ``"on"``, ``"off"``, ``True`` and ``False`` are accepted values.
+                By default, ``switch`` sets the mode to the opposite of the current one.
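+
+        A short sketch (assuming ``t`` is a :class:`PixelRenderTransform` instance):
+
+            >>> t.switch("off")   # disable rendering
+            >>> assert not t.enabled
+            >>> t.switch()        # toggle back on
+            >>> assert t.enabled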
+ + """ + if mode is None: + mode = not self._enabled + if not isinstance(mode, bool): + if mode not in ("on", "off"): + raise ValueError("mode must be either 'on' or 'off', or a boolean.") + mode = mode == "on" + self._enabled = mode + + @property + def enabled(self) -> bool: + """Whether the recorder is enabled.""" + return self._enabled + + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) + if isinstance(self.parent, EnvBase): + # Start the env if needed + method = getattr(self.parent, self.render_method, None) + if method is None or not callable(method): + raise ValueError( + f"The render method must exist and be a callable. Got render={method}." + ) + return out