diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 43d5bfe0d00..efb1a755d0e 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -815,6 +815,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
     DiscreteTensorSpec
     MultiDiscreteTensorSpec
     MultiOneHotDiscreteTensorSpec
+    NonTensorSpec
     OneHotDiscreteTensorSpec
     UnboundedContinuousTensorSpec
     UnboundedDiscreteTensorSpec
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index fe72ea89a56..5c39c5a1349 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -759,6 +759,75 @@ to always know what the latest available actions are. You can do this like so:
Recorders
---------
+
+.. _Environment-Recorders:
+
+Recording data during environment rollout execution is crucial for monitoring algorithm performance as well as for
+reporting results after training.
+
+TorchRL offers several tools to interact with the environment output: first and foremost, a ``callback`` callable
+can be passed to the :meth:`~torchrl.envs.EnvBase.rollout` method. This function is called on the collected
+tensordict at each iteration of the rollout (if some iterations are to be skipped, ``callback`` should keep an
+internal call count and return early on the skipped iterations), as sketched in the example below.
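+
+Below is a minimal sketch of such a callback; the environment name, the skip period and the logged entry are
+arbitrary choices for illustration:
+
+    >>> from torchrl.envs import GymEnv
+    >>>
+    >>> count = {"calls": 0}
+    >>> rewards = []
+    >>> def callback(td):
+    ...     # keep an internal call count and store the reward every other call
+    ...     count["calls"] += 1
+    ...     if count["calls"] % 2 == 0:
+    ...         rewards.append(td["next", "reward"].clone())
+    >>>
+    >>> env = GymEnv("CartPole-v1")
+    >>> rollout = env.rollout(10, callback=callback)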
+
+To save collected tensordicts on disk, the :class:`~torchrl.record.TensorDictRecorder` can be used.
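+
+For instance, a minimal sketch (assuming ``TensorDictRecorder`` takes the output file prefix as its first
+argument; the prefix and rollout length are illustrative):
+
+    >>> from torchrl.envs import GymEnv
+    >>> from torchrl.record import TensorDictRecorder
+    >>>
+    >>> env = GymEnv("CartPole-v1")
+    >>> # register the recorder as an environment transform
+    >>> env = env.append_transform(TensorDictRecorder("rollout_data"))
+    >>> env.rollout(10)
+    >>> env.transform.dump()  # write the collected tensordicts to disk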
+
+Recording videos
+~~~~~~~~~~~~~~~~
+
+Several backends offer the possibility of recording rendered images from the environment.
+If the pixels are already part of the environment output (e.g. Atari or other game simulators), a
+:class:`~torchrl.record.VideoRecorder` can be appended to the environment. This environment transform takes as input
+a logger capable of recording videos (e.g. :class:`~torchrl.record.loggers.CSVLogger`, :class:`~torchrl.record.loggers.WandbLogger`
+or :class:`~torchrl.record.loggers.TensorBoardLogger`) as well as a tag indicating where the video should be saved.
+For instance, to save mp4 videos on disk, one can use :class:`~torchrl.record.loggers.CSVLogger` with a
+``video_format="mp4"`` argument.
+
+The :class:`~torchrl.record.VideoRecorder` transform can handle batched images and automatically detects numpy or PyTorch
+formatted images (in HWC or CHW layout, i.e. channel-last or channel-first).
+
+    >>> from torchrl.envs import GymEnv
+    >>> from torchrl.record import VideoRecorder
+    >>> from torchrl.record.loggers import CSVLogger
+    >>>
+    >>> logger = CSVLogger("dummy-exp", video_format="mp4")
+    >>> env = GymEnv("ALE/Pong-v5")
+    >>> env = env.append_transform(VideoRecorder(logger, tag="rendered", in_keys=["pixels"]))
+    >>> env.rollout(10)
+    >>> env.transform.dump()  # Save the video and clear the cache
+
+Note that the cache of the transform will keep growing until ``dump`` is called. It is the user's responsibility to
+call ``dump`` when needed to avoid OOM issues, as sketched below.
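+
+A common pattern (a sketch; the dump period is arbitrary) is to flush the cache periodically from the rollout
+``callback``:
+
+    >>> count = {"steps": 0}
+    >>> def flush_periodically(td):
+    ...     # dump (and clear) the recorder cache every 100 steps
+    ...     count["steps"] += 1
+    ...     if count["steps"] % 100 == 0:
+    ...         env.transform.dump()
+    >>> env.rollout(1000, callback=flush_periodically)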
+
+In some cases, creating a testing environment where images can be collected is tedious or expensive, or simply impossible
+(some libraries only allow one environment instance per workspace).
+In these cases, assuming that a ``render`` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
+can be used to call ``render`` on the parent environment and save the images in the rollout data stream.
+This class works over single and batched environments alike:
+
+    >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+    >>> from torchrl.record.loggers import CSVLogger
+    >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+    >>>
+    >>> def make_env():
+    ...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+    ...     # Uncomment this line to execute per-env
+    ...     # env = env.append_transform(PixelRenderTransform())
+    ...     return env
+    >>>
+    >>> if __name__ == "__main__":
+    ...     logger = CSVLogger("dummy", video_format="mp4")
+    ...
+    ...     env = ParallelEnv(16, EnvCreator(make_env))
+    ...     env.start()
+    ...     # Comment this line to execute per-env
+    ...     env = env.append_transform(PixelRenderTransform())
+    ...
+    ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+    ...     env.rollout(3)
+    ...
+    ...     check_env_specs(env)
+    ...
+    ...     r = env.rollout(30)
+    ...     env.transform.dump()
+    ...     env.close()
+
+
.. currentmodule:: torchrl.record
Recorders are transforms that register data as they come in, for logging purposes.
@@ -769,6 +838,7 @@ Recorders are transforms that register data as they come in, for logging purpose
     TensorDictRecorder
     VideoRecorder
+    PixelRenderTransform
Helpers
diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst
index 04d4386c631..821902b2ee2 100644
--- a/docs/source/reference/trainers.rst
+++ b/docs/source/reference/trainers.rst
@@ -9,7 +9,7 @@ loop the optimization steps. We believe this fits multiple RL training schemes,
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
-More particular cases, such as meta-RL algorithms may have training schemes that differ substentially.
+More particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.
-The :obj:`trainer.train()` method can be sketched as follows:
+The ``trainer.train()`` method can be sketched as follows:
.. code-block::
:caption: Trainer loops
@@ -63,35 +63,35 @@ The :obj:`trainer.train()` method can be sketched as follows:
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
-There are 10 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
-:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
-:obj:`"post_steps_log"`, :obj:`"post_optim_log"` and :obj:`"optimizer"`. They are indicated in the comments where they are applied.
-Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
-**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operations** hook
-(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).
-
-- **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
- a :obj:`TensorDict` object as input and update it given some strategy.
- Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
- constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.
-
-- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
- some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
- logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the
- data to log. The key :obj:`"log_pbar"` is reserved to boolean values indicating if the logged value
+There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
+``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
+``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
+Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
+**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operation** hooks
+(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).
+
+- **Data processing** hooks update a tensordict of data. A hook's ``__call__`` method should accept
+  a ``TensorDict`` object as input and update it given some strategy.
+  Examples of such hooks include Replay Buffer extension (``ReplayBufferTrainer.extend``), data normalization (including
+  normalization constants update), data subsampling (:class:`~torchrl.trainers.BatchSubSampler`) and such.
+
+- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write some information
+  retrieved from that data in the logger. Examples include the ``Recorder`` hook, the reward
+  logger (``LogReward``) and such. Hooks should return a dictionary (or a ``None`` value) containing the
+  data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged value
-  should be displayed on the progression bar printed on the training log.
+  should be displayed on the progress bar printed on the training log.
- **Operation** hooks are hooks that execute specific operations over the models, data collectors,
- target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights`
- or update the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples
- of operation hooks. They are data-independent (they do not require a :obj:`TensorDict`
+  target network updates and such. For instance, syncing the weights of the collectors using ``UpdateWeights``
+  or updating the priority of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
+  of operation hooks. They are data-independent (they do not require a ``TensorDict``
input), they are just supposed to be executed once at every iteration (or every N iterations).
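+
+As a brief sketch, a hook instance is attached to a trainer through the trainer's registration API (the hook
+object and event name below are illustrative):
+
+.. code-block::
+
+    >>> trainer.register_op("post_steps", my_post_steps_hook)
+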
-The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`,
-and all implement three base methods: a :obj:`state_dict` and :obj:`load_state_dict` method for
-checkpointing and a :obj:`register` method that registers the hook at the default value in the
+The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
+and all implement three base methods: a ``state_dict`` and ``load_state_dict`` method for
+checkpointing and a ``register`` method that registers the hook at its default location in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
-hook is executed every 10 calls to :obj:`"post_optim_log"`:
+hook is executed every 10 calls to ``"post_optim_log"``:
.. code-block::
@@ -122,22 +122,22 @@ Checkpointing
-------------
The trainer class and hooks support checkpointing, which can be achieved either
-using the `torchsnapshot `_ backend or
-the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`:
+using the `torchsnapshot <https://github.com/pytorch/torchsnapshot>`_ backend or
+the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:
.. code-block::
$ CKPT_BACKEND=torch python script.py
-which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch
+which defaults to ``torchsnapshot``. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
to load tensors from and to a replay buffer that would otherwise not fit in memory.
When building a trainer, one can provide a file path where the checkpoints are to
-be written. With the :obj:`torchsnapshot` backend, a directory path is expected,
-whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).
+be written. With the ``torchsnapshot`` backend, a directory path is expected,
+whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).
.. code-block::
@@ -157,7 +157,7 @@ whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` fi
>>> # to load from a path
>>> trainer.load_from_file(filepath)
-The :obj:`Trainer.train()` method can be used to execute the above loop with all of
+The ``Trainer.train()`` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class for its checkpointing capability
only is also a perfectly valid use.
@@ -238,6 +238,8 @@ Loggers
Recording utils
---------------
+Recording utils are detailed :ref:`here <Environment-Recorders>`.
+
.. currentmodule:: torchrl.record
.. autosummary::
@@ -246,3 +248,4 @@ Recording utils
     VideoRecorder
     TensorDictRecorder
+    PixelRenderTransform
diff --git a/test/test_loggers.py b/test/test_loggers.py
index 98a330d0daf..f51b9d290ab 100644
--- a/test/test_loggers.py
+++ b/test/test_loggers.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
+import importlib.util
import os
import os.path
import pathlib
@@ -12,12 +13,14 @@
import pytest
import torch
-
from tensordict import MemoryMappedTensor
+
+from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
+from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
if _has_tv:
import torchvision
@@ -28,6 +31,11 @@
if _has_mlflow:
import mlflow
+_has_gym = (
+    importlib.util.find_spec("gym", None) is not None
+    or importlib.util.find_spec("gymnasium", None) is not None
+)
+
@pytest.fixture
def tb_logger(tmp_path_factory):
@@ -397,6 +405,36 @@ def test_log_hparams(self, mlflow_fixture, config):
logger.log_hparams(config)
+@pytest.mark.skipif(not _has_gym, reason="gym required to test rendering")
+class TestPixelRenderTransform:
+    @pytest.mark.parametrize("parallel", [False, True])
+    @pytest.mark.parametrize("in_key", ["pixels", ("nested", "pix")])
+    def test_pixel_render(self, parallel, in_key, tmpdir):
+        def make_env():
+            env = GymEnv("CartPole-v1", render_mode="rgb_array", device=None)
+            env = env.append_transform(PixelRenderTransform(out_keys=in_key))
+            return env
+
+        if parallel:
+            env = ParallelEnv(2, make_env, mp_start_method="spawn")
+        else:
+            env = make_env()
+        logger = CSVLogger("dummy", log_dir=tmpdir)
+        try:
+            env = env.append_transform(
+                VideoRecorder(logger=logger, in_keys=[in_key], tag="pixels_record")
+            )
+            check_env_specs(env)
+            env.rollout(10)
+            env.transform.dump()
+            assert os.path.isfile(
+                os.path.join(tmpdir, "dummy", "videos", "pixels_record_0.pt")
+            )
+        finally:
+            if not env.is_closed:
+                env.close()
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_specs.py b/test/test_specs.py
index 36f5aef65ca..058144c1a94 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -23,6 +23,7 @@
     LazyStackedCompositeSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
+    NonTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
@@ -1462,6 +1463,14 @@ def test_multionehot(self, shape1, shape2):
assert spec2.rand().shape == spec2.shape
assert spec2.zero().shape == spec2.shape
+    def test_non_tensor(self):
+        spec = NonTensorSpec((3, 4), device="cpu")
+        assert (
+            spec.expand(2, 3, 4)
+            == spec.expand((2, 3, 4))
+            == NonTensorSpec((2, 3, 4), device="cpu")
+        )
+
@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_onehot(self, shape1, shape2):
@@ -1675,6 +1684,11 @@ def test_multionehot(
assert spec == spec.clone()
assert spec is not spec.clone()
+    def test_non_tensor(self):
+        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        assert spec.clone() == spec
+        assert spec.clone() is not spec
+
@pytest.mark.parametrize("shape1", [None, (), (5,)])
def test_onehot(
self,
@@ -1840,6 +1854,11 @@ def test_multionehot(
with pytest.raises(ValueError):
spec.unbind(-1)
+    def test_non_tensor(self):
+        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        assert spec.unbind(1)[0] == spec[:, 0]
+        assert spec.unbind(1)[0] is not spec[:, 0]
+
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
@@ -2114,6 +2133,15 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
r = c.zero()
assert r.shape == c.shape
+    def test_stack_non_tensor(self, shape, stack_dim):
+        spec0 = NonTensorSpec(shape=shape, device="cpu")
+        spec1 = NonTensorSpec(shape=shape, device="cpu")
+        new_spec = torch.stack([spec0, spec1], stack_dim)
+        shape_insert = list(shape)
+        shape_insert.insert(stack_dim, 2)
+        assert new_spec.shape == torch.Size(shape_insert)
+        assert new_spec.device == torch.device("cpu")
+
def test_stack_onehot(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index cb84ce32586..bc512a585b7 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -51,6 +51,7 @@
     LazyStackedTensorSpec,
     MultiDiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
+    NonTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 71598938eab..c9d0683ad9c 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -31,7 +31,13 @@
 import numpy as np
 import torch

-from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase, unravel_key
+from tensordict import (
+    LazyStackedTensorDict,
+    NonTensorData,
+    TensorDict,
+    TensorDictBase,
+    unravel_key,
+)
 from tensordict.utils import _getitem_batch_size, NestedKey
from torchrl._utils import get_binary_env_var
@@ -715,8 +721,9 @@ def _flatten(self, start_dim, end_dim):
         shape = torch.zeros(self.shape, device="meta").flatten(start_dim, end_dim).shape
         return self._reshape(shape)

+    @abc.abstractmethod
     def _project(self, val: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
+        raise NotImplementedError(type(self))
@abc.abstractmethod
def is_in(self, val: torch.Tensor) -> bool:
@@ -1917,6 +1924,107 @@ def _is_nested_list(index, notuple=False):
return False
+class NonTensorSpec(TensorSpec):
+ """A spec for non-tensor data."""
+
+    def __init__(
+        self,
+        shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
+        device: Optional[DEVICE_TYPING] = None,
+        dtype: torch.dtype | None = None,
+        **kwargs,
+    ):
+        if isinstance(shape, int):
+            shape = torch.Size([shape])
+
+        _, device = _default_dtype_and_device(None, device)
+        domain = None
+        super().__init__(
+            shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
+        )
+
+    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
+        if isinstance(dest, torch.dtype):
+            dest_dtype = dest
+            dest_device = self.device
+        elif dest is None:
+            return self
+        else:
+            dest_dtype = self.dtype
+            dest_device = torch.device(dest)
+        if dest_device == self.device and dest_dtype == self.dtype:
+            return self
+        return self.__class__(shape=self.shape, device=dest_device, dtype=None)
+
+    def clone(self) -> NonTensorSpec:
+        return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
+
+    def rand(self, shape=None):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def zero(self, shape=None):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def one(self, shape=None):
+        return NonTensorData(data=None, shape=self.shape, device=self.device)
+
+    def is_in(self, val: torch.Tensor) -> bool:
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return (
+            isinstance(val, NonTensorData)
+            and val.shape == shape
+            and val.device == self.device
+            and val.dtype == self.dtype
+        )
+
+    def expand(self, *shape):
+        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
+            shape = shape[0]
+        shape = torch.Size(shape)
+        if not all(
+            (old == 1) or (old == new)
+            for old, new in zip(self.shape, shape[-len(self.shape) :])
+        ):
+            raise ValueError(
+                f"The last elements of the expanded shape must match the current ones. Got shape={shape} while self.shape={self.shape}."
+            )
+        return self.__class__(shape=shape, device=self.device, dtype=None)
+
+    def _reshape(self, shape):
+        return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
+
+    def _unflatten(self, dim, sizes):
+        shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
+        return self.__class__(
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def __getitem__(self, idx: SHAPE_INDEX_TYPING):
+        """Indexes the current TensorSpec based on the provided index."""
+        indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+        return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
+
+    def unbind(self, dim: int):
+        orig_dim = dim
+        if dim < 0:
+            dim = len(self.shape) + dim
+        if dim < 0:
+            raise ValueError(
+                f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
+            )
+        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
+        return tuple(
+            self.__class__(
+                shape=shape,
+                device=self.device,
+                dtype=self.dtype,
+            )
+            for i in range(self.shape[dim])
+        )
+
+
@dataclass(repr=False)
class UnboundedContinuousTensorSpec(TensorSpec):
"""An unbounded continuous tensor spec.
@@ -1954,7 +2062,9 @@ def __init__(
shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
)
-    def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+    def to(
+        self, dest: Union[torch.dtype, DEVICE_TYPING]
+    ) -> UnboundedContinuousTensorSpec:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -1979,7 +2089,11 @@ def rand(self, shape=None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype
+
+    def _project(self, val: torch.Tensor) -> torch.Tensor:
+        return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2130,7 +2244,8 @@ def rand(self, shape=None) -> torch.Tensor:
         return r.to(self.device)

     def is_in(self, val: torch.Tensor) -> bool:
-        return True
+        shape = torch.broadcast_shapes(self.shape, val.shape)
+        return val.shape == shape and val.dtype == self.dtype
def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 660aecb3fd8..b3026da35ca 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -1624,6 +1624,7 @@ def __getattr__(self, attr: str) -> Any:
             try:
                 # _ = getattr(self._dummy_env, attr)
                 if self.is_closed:
-                    raise RuntimeError(
-                        "Trying to access attributes of closed/non started "
-                        "environments. Check that the batched environment "
+                    # Auto-start the batched environment rather than raising
+                    # when attributes are accessed before it has been started.
+                    self.start()
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index f5d4625fd07..8712c74340a 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -2298,7 +2298,7 @@ def rollout(
         self,
         max_steps: int,
         policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool = True,
@@ -2320,7 +2320,10 @@ def rollout(
                 The policy can be any callable that reads either a tensordict or
                 the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
                 Defaults to `None`.
-            callback (callable, optional): function to be called at each iteration with the given TensorDict.
+            callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
+                TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected; it is the user's
+                responsibility to save any result within the callback call if data needs to be carried over beyond
+                the call to ``rollout``.
             auto_reset (bool, optional): if ``True``, resets automatically the environment
                 if it is in a done state when the rollout is initiated.
                 Default is ``True``.
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 49cf58f8103..cd51b4fd23b 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -16,6 +16,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Union

+import tensordict
 import torch

 from tensordict import (
@@ -25,6 +26,7 @@
     TensorDictBase,
     unravel_key,
 )
+from tensordict.base import _is_leaf_nontensor
 from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.nn.probabilistic import ( # noqa
# Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
@@ -183,10 +185,15 @@ def _is_reset(key: NestedKey):
return key == "_reset"
return key[-1] == "_reset"
- actual = {key for key in tensordict.keys(True, True) if not _is_reset(key)}
+ actual = {
+ key
+ for key in tensordict.keys(True, True, is_leaf=_is_leaf_nontensor)
+ if not _is_reset(key)
+ }
expected = set(expected)
self.validated = expected.intersection(actual) == expected
if not self.validated:
+ raise RuntimeError
warnings.warn(
"The expected key set and actual key set differ. "
"This will work but with a slower throughput than "
@@ -262,7 +269,7 @@ def _exclude(
                 cls._exclude(nested_key_dict, td, td_out)
             return out
         has_set = False
-        for key, value in data_in.items():
+        for key, value in data_in.items(is_leaf=tensordict.base._is_leaf_nontensor):
             subdict = nested_key_dict.get(key, NO_DEFAULT)
             if subdict is NO_DEFAULT:
                 value = value.copy() if is_tensor_collection(value) else value
diff --git a/torchrl/record/__init__.py b/torchrl/record/__init__.py
index 726d29ea051..f6c9bcdefbb 100644
--- a/torchrl/record/__init__.py
+++ b/torchrl/record/__init__.py
@@ -4,4 +4,4 @@
# LICENSE file in the root directory of this source tree.
from .loggers import CSVLogger, MLFlowLogger, TensorboardLogger, WandbLogger
-from .recorder import TensorDictRecorder, VideoRecorder
+from .recorder import PixelRenderTransform, TensorDictRecorder, VideoRecorder
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index a486b689feb..079c8b71e12 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -6,14 +6,20 @@
import importlib.util
from copy import copy
-from typing import Optional, Sequence
+from typing import Callable, List, Optional, Sequence, Union
+import numpy as np
import torch
-from tensordict import TensorDictBase
+from tensordict import NonTensorData, TensorDict, TensorDictBase
from tensordict.utils import NestedKey
+from torchrl._utils import _can_be_pickled
+from torchrl.data import TensorSpec
+from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.utils import CloudpickleWrapper
+from torchrl.envs import EnvBase
from torchrl.envs.transforms import ObservationTransform, Transform
from torchrl.record.loggers import Logger
@@ -155,20 +161,22 @@ def skip(self, value):
         self._skip = value

     def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
+        if isinstance(observation, NonTensorData):
+            # non-tensor frames (e.g. registered via set_non_tensor) are
+            # converted to tensors before being processed
+            observation_trsf = torch.tensor(observation.data)
+        else:
+            observation_trsf = observation
         self.count += 1
         if self.count % self.skip == 0:
             if (
-                observation.ndim >= 3
-                and observation.shape[-3] == 3
-                and observation.shape[-2] > 3
-                and observation.shape[-1] > 3
+                observation_trsf.ndim >= 3
+                and observation_trsf.shape[-3] == 3
+                and observation_trsf.shape[-2] > 3
+                and observation_trsf.shape[-1] > 3
             ):
                 # permute the channels to the last dim
-                observation_trsf = observation.permute(
-                    *range(observation.ndim - 3), -2, -1, -3
+                observation_trsf = observation_trsf.permute(
+                    *range(observation_trsf.ndim - 3), -2, -1, -3
                 )
-            else:
-                observation_trsf = observation
             if not (
                 observation_trsf.shape[-1] == 3 or observation_trsf.ndimension() == 2
             ):
@@ -321,3 +329,209 @@ def _reset(
) -> TensorDictBase:
self._call(tensordict_reset)
return tensordict_reset
+
+
+class PixelRenderTransform(Transform):
+    """A transform to call render on the parent environment and register the pixel observation in the tensordict.
+
+    This transform offers an alternative to the ``from_pixels`` syntactic sugar when instantiating an environment
+    where rendering is expensive, or when ``from_pixels`` is not implemented.
+    It can be used within a single environment or over batched environments alike.
+
+    Args:
+        out_keys (List[NestedKey] or NestedKey): List of keys where to register the pixel observations.
+            Defaults to ``["pixels"]``.
+        preproc (Callable, optional): a preproc function. Can be used to reshape the observation, or apply
+            any other transformation that makes it possible to register it in the output data.
+        as_non_tensor (bool, optional): if ``True``, the data will be written as a :class:`~tensordict.NonTensorData`,
+            thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the
+            input data type and shape.
+        render_method (str, optional): the name of the render method. Defaults to ``"render"``.
+        **kwargs: additional keyword arguments to pass to the render function (e.g. ``mode="rgb_array"``).
+
+    Examples:
+        >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
+        >>> from torchrl.record.loggers import CSVLogger
+        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+        >>>
+        >>> def make_env():
+        ...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
+        ...     env = env.append_transform(PixelRenderTransform())
+        ...     return env
+        >>>
+        >>> if __name__ == "__main__":
+        ...     logger = CSVLogger("dummy", video_format="mp4")
+        ...
+        ...     env = ParallelEnv(4, EnvCreator(make_env))
+        ...
+        ...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+        ...     env.rollout(3)
+        ...
+        ...     check_env_specs(env)
+        ...
+        ...     r = env.rollout(30)
+        ...     print(env)
+        ...     env.transform.dump()
+        ...     env.close()
+
+    This transform can also be used whenever a batched environment ``render()`` returns a single image:
+
+    Examples:
+        >>> from torchrl.envs import check_env_specs
+        >>> from torchrl.envs.libs.vmas import VmasEnv
+        >>> from torchrl.record.loggers import CSVLogger
+        >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
+        >>>
+        >>> env = VmasEnv(
+        ...     scenario="flocking",
+        ...     num_envs=32,
+        ...     continuous_actions=True,
+        ...     max_steps=200,
+        ...     device="cpu",
+        ...     seed=None,
+        ...     # Scenario kwargs
+        ...     n_agents=5,
+        ... )
+        >>>
+        >>> logger = CSVLogger("dummy", video_format="mp4")
+        >>>
+        >>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
+        >>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
+        >>>
+        >>> check_env_specs(env)
+        >>>
+        >>> r = env.rollout(30)
+        >>> env.transform[-1].dump()
+
+    The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which
+    turns the rendering on if it is off, and vice versa (an argument can also be passed to control this behaviour).
+    Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control
+    this behaviour:
+
+        >>> def switch(module):
+        ...     if isinstance(module, PixelRenderTransform):
+        ...         module.switch()
+        >>> env.apply(switch)
+
+    """
+
+    def __init__(
+        self,
+        out_keys: List[NestedKey] = None,
+        preproc: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ] = None,
+        as_non_tensor: bool = None,
+        render_method: str = "render",
+        **kwargs,
+    ) -> None:
+        if out_keys is None:
+            out_keys = ["pixels"]
+        elif isinstance(out_keys, (str, tuple)):
+            out_keys = [out_keys]
+        if len(out_keys) != 1:
+            raise RuntimeError(
+                f"Expected one and only one out_key, got out_keys={out_keys}"
+            )
+        if preproc is not None and not _can_be_pickled(preproc):
+            # wrap non-picklable preprocs so the transform can cross process boundaries
+            preproc = CloudpickleWrapper(preproc)
+        self.preproc = preproc
+        self.as_non_tensor = as_non_tensor
+        self.kwargs = kwargs
+        self.render_method = render_method
+        self._enabled = True
+        super().__init__(in_keys=[], out_keys=out_keys)
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._call(tensordict_reset)
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if not self._enabled:
+            return tensordict
+
+        array = getattr(self.parent, self.render_method)(**self.kwargs)
+        if self.preproc:
+            array = self.preproc(array)
+        if self.as_non_tensor is None:
+            if isinstance(array, list):
+                if isinstance(array[0], np.ndarray):
+                    array = np.asarray(array)
+                else:
+                    array = torch.as_tensor(array)
+            if (
+                array.ndim == 3
+                and array.shape[-1] == 3
+                and self.parent.batch_size != ()
+            ):
+                # a single (H, W, C) frame rendered for a batched env cannot be
+                # registered as a batched entry: store it as non-tensor data instead
+                self.as_non_tensor = True
+            else:
+                self.as_non_tensor = False
+        if not self.as_non_tensor:
+            try:
+                tensordict.set(self.out_keys[0], array)
+            except Exception as err:
+                raise RuntimeError(
+                    f"An exception was raised while writing the rendered array "
+                    f"(shape={getattr(array, 'shape', None)}, dtype={getattr(array, 'dtype', None)}) in the tensordict with shape {tensordict.shape}. "
+                    f"Consider adapting your preproc function in {type(self).__name__}. You can also "
+                    f"pass keyword arguments to the render function of the parent environment, or save "
+                    f"this observation as a non-tensor data with as_non_tensor=True."
+                ) from err
+        else:
+            tensordict.set_non_tensor(self.out_keys[0], array)
+        return tensordict
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        # Adds the pixel observation spec by calling render on the parent env
+        switch = False
+        if not self.enabled:
+            switch = True
+            self.switch()
+        parent = self.parent
+        td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device)
+        self._call(td_in)
+        obs = td_in.get(self.out_keys[0])
+        if isinstance(obs, NonTensorData):
+            spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape)
+        else:
+            spec = UnboundedContinuousTensorSpec(
+                device=obs.device, dtype=obs.dtype, shape=obs.shape
+            )
+        observation_spec[self.out_keys[0]] = spec
+        if switch:
+            self.switch()
+        return observation_spec
+
+    def switch(self, mode: str | bool = None):
+        """Sets the transform on or off.
+
+        Args:
+            mode (str or bool, optional): if provided, sets the switch to the desired mode.
+                ``"on"``, ``"off"``, ``True`` and ``False`` are accepted values.
+                By default, ``switch`` sets the mode to the opposite of the current one.
+
+        """
+        if mode is None:
+            mode = not self._enabled
+        if not isinstance(mode, bool):
+            if mode not in ("on", "off"):
+                raise ValueError("mode must be either 'on' or 'off', or a boolean.")
+            mode = mode == "on"
+        self._enabled = mode
+
+    @property
+    def enabled(self) -> bool:
+        """Whether the recorder is enabled."""
+        return self._enabled
+
+    def set_container(self, container: Union[Transform, EnvBase]) -> None:
+        out = super().set_container(container)
+        if isinstance(self.parent, EnvBase):
+            # Check that the render method exists on the parent env and is callable
+            method = getattr(self.parent, self.render_method, None)
+            if method is None or not callable(method):
+                raise ValueError(
+                    f"The render method must exist and be a callable. Got render={method}."
+                )
+        return out