diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index aec538773..529bf6db3 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -57,9 +57,9 @@ def __init__( @property def stats_patterns(self) -> dict: return { - ("observation", "state"): "b c -> 1 c", - ("observation", "image"): "b c h w -> 1 c 1 1", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("observation", "image"): "b c h w -> c 1 1", + ("action",): "b c -> c", } @property diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 829896592..ec58e7010 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -115,11 +115,11 @@ def __init__( @property def stats_patterns(self) -> dict: d = { - ("observation", "state"): "b c -> 1 c", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> 1 c 1 1" + d[("observation", "image", cam)] = "b c h w -> c 1 1" return d @property diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py index 0754fb768..a449e23fe 100644 --- a/lerobot/common/envs/abstract.py +++ b/lerobot/common/envs/abstract.py @@ -1,4 +1,3 @@ -import abc from collections import deque from typing import Optional @@ -27,7 +26,6 @@ def __init__( self.image_size = image_size self.num_prev_obs = num_prev_obs self.num_prev_action = num_prev_action - self._rendering_hooks = [] if pixels_only: assert from_pixels @@ -45,36 +43,20 @@ def __init__( raise NotImplementedError() # self._prev_action_queue = deque(maxlen=self.num_prev_action) - def register_rendering_hook(self, func): - self._rendering_hooks.append(func) - - def call_rendering_hooks(self): - for func in self._rendering_hooks: - func(self) - - def reset_rendering_hooks(self): - self._rendering_hooks = [] - - @abc.abstractmethod def render(self, mode="rgb_array", width=640, height=480): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _reset(self, tensordict: Optional[TensorDict] = None): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _step(self, tensordict: TensorDict): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _make_env(self): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _make_spec(self): - raise NotImplementedError() + raise NotImplementedError("Abstract method") - @abc.abstractmethod def _set_seed(self, seed: Optional[int]): - raise NotImplementedError() + raise NotImplementedError("Abstract method") diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 1211a37a8..e09564fbf 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -35,6 +35,8 @@ class AlohaEnv(AbstractEnv): + _reset_warning_issued = False + def __init__( self, task, @@ -120,90 +122,76 @@ def _format_raw_obs(self, raw_obs): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
- self._current_seed += 1 - self.set_seed(self._current_seed) - - # TODO(rcadene): do not use global variable for this - if "sim_transfer_cube" in self.task: - BOX_POSE[0] = sample_box_pose() # used in sim reset - elif "sim_insertion" in self.task: - BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset - - raw_obs = self._env.reset() - # TODO(rcadene): add assert - # assert self._current_seed == self._env._seed - - obs = self._format_raw_obs(raw_obs.observation) + if tensordict is not None and not AlohaEnv._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + AlohaEnv._reset_warning_issued = True + + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. + self._current_seed += 1 + self.set_seed(self._current_seed) + + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + + raw_obs = self._env.reset() + # TODO(rcadene): add assert + # assert self._current_seed == self._env._seed + + obs = self._format_raw_obs(raw_obs.observation) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary + assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? 
- sum_reward = 0 - - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() - num_action_steps = action.shape[0] - for i in range(num_action_steps): - _, reward, discount, raw_obs = self._env.step(action[i]) - del discount # not used + _, reward, _, raw_obs = self._env.step(action) - # TOOD(rcadene): add an enum - success = done = reward == 4 - sum_reward += reward - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]["top"]) - stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + # TODO(rcadene): add an enum + success = done = reward == 4 + obs = self._format_raw_obs(raw_obs) - self.call_rendering_hooks() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]["top"]) + stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))} + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([sum_reward], dtype=torch.float32), + "reward": torch.tensor([reward], dtype=torch.float32), # succes and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([success], dtype=torch.bool), diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 921cbad74..e187d7131 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,14 +1,18 @@ +from torchrl.envs import SerialEnv from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv def make_env(cfg, transform=None): + """ + Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying + environments. The env therefore returns batches.` + """ + kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, - # TODO(rcadene): do we want a specific eval_env_seed? 
- "seed": cfg.seed, "num_prev_obs": cfg.n_obs_steps - 1, } @@ -31,22 +35,33 @@ def make_env(cfg, transform=None): else: raise ValueError(cfg.env.name) - env = clsfunc(**kwargs) + def _make_env(seed): + nonlocal kwargs + kwargs["seed"] = seed + env = clsfunc(**kwargs) + + # limit rollout to max_steps + env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) + if transform is not None: + # useful to add normalization + if isinstance(transform, Compose): + for tf in transform: + env.append_transform(tf.clone()) + elif isinstance(transform, Transform): + env.append_transform(transform.clone()) + else: + raise NotImplementedError() - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() + return env - return env + return SerialEnv( + cfg.rollout_batch_size, + create_env_fn=_make_env, + create_env_kwargs=[ + {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + ], + ) # def make_env(env_name, frame_skip, device, is_test=False): diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 4a7ccb2c8..f440d443d 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -1,8 +1,8 @@ import importlib +import logging from collections import deque from typing import Optional -import einops import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -20,6 +20,8 @@ class PushtEnv(AbstractEnv): + _reset_warning_issued = False + def __init__( self, task="pusht", @@ -80,80 +82,67 @@ def _format_raw_obs(self, raw_obs): return obs def _reset(self, tensordict: Optional[TensorDict] = None): - td = tensordict - if td is None or td.is_empty(): - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) - raw_obs = self._env.reset() - assert self._current_seed == self._env._seed - - obs = self._format_raw_obs(raw_obs) + if tensordict is not None and not PushtEnv._reset_warning_issued: + logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") + PushtEnv._reset_warning_issued = True + + # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
+ self._current_seed += 1 + self.set_seed(self._current_seed) + raw_obs = self._env.reset() + assert self._current_seed == self._env._seed + + obs = self._format_raw_obs(raw_obs) + + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue = deque( + [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue = deque( + [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) + ) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue = deque( - [obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue = deque( - [obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1) - ) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - td = TensorDict( - { - "observation": TensorDict(obs, batch_size=[]), - "done": torch.tensor([False], dtype=torch.bool), - }, - batch_size=[], - ) - else: - raise NotImplementedError() + td = TensorDict( + { + "observation": TensorDict(obs, batch_size=[]), + "done": torch.tensor([False], dtype=torch.bool), + }, + batch_size=[], + ) - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary + assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? - sum_reward = 0 - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() - - num_action_steps = action.shape[0] - for i in range(num_action_steps): - raw_obs, reward, done, info = self._env.step(action[i]) - sum_reward += reward + raw_obs, reward, done, info = self._env.step(action) - obs = self._format_raw_obs(raw_obs) - - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs + obs = self._format_raw_obs(raw_obs) - self.call_rendering_hooks() + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([sum_reward], dtype=torch.float32), - # succes and done are true when coverage > self.success_threshold in env + "reward": torch.tensor([reward], dtype=torch.float32), + # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([done], dtype=torch.bool), }, diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index d06126257..eac3666d7 100644 --- a/lerobot/common/envs/simxarm.py +++ 
b/lerobot/common/envs/simxarm.py @@ -118,7 +118,6 @@ def _reset(self, tensordict: Optional[TensorDict] = None): else: raise NotImplementedError() - self.call_rendering_hooks() return td def _step(self, tensordict: TensorDict): @@ -152,8 +151,6 @@ def _step(self, tensordict: TensorDict): stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) obs = stacked_obs - self.call_rendering_hooks() - td = TensorDict( { "observation": self._format_raw_obs(raw_obs), diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py new file mode 100644 index 000000000..e9c331a0e --- /dev/null +++ b/lerobot/common/policies/abstract.py @@ -0,0 +1,70 @@ +from collections import deque + +import torch +from torch import Tensor, nn + + +class AbstractPolicy(nn.Module): + """Base policy which all policies should be derived from. + + The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its + documentation for more information. + """ + + def __init__(self, n_action_steps: int | None): + """ + n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single + action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then + adds that dimension. + """ + super().__init__() + self.n_action_steps = n_action_steps + self.clear_action_queue() + + def update(self, replay_buffer, step): + """One step of the policy's learning algorithm.""" + raise NotImplementedError("Abstract method") + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + def select_actions(self, observation) -> Tensor: + """Select an action (or trajectory of actions) based on an observation during rollout. + + If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of + actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions. + """ + raise NotImplementedError("Abstract method") + + def clear_action_queue(self): + """This should be called whenever the environment is reset.""" + if self.n_action_steps is not None: + self._action_queue = deque([], maxlen=self.n_action_steps) + + def forward(self, *args, **kwargs) -> Tensor: + """Inference step that makes multi-step policies compatible with their single-step environments. + + WARNING: In general, this should not be overriden. + + Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit + into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an + observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment + observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that + the subclass doesn't have to. + + This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made: + 1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is + the action trajectory horizon and * is the action dimensions. + 2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined. 
+ """ + if self.n_action_steps is None: + return self.select_actions(*args, **kwargs) + if len(self._action_queue) == 0: + # `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape + # (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1)) + return self._action_queue.popleft() diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index d011cb762..539cdcf54 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -2,10 +2,10 @@ import time import torch -import torch.nn as nn import torch.nn.functional as F # noqa: N812 import torchvision.transforms as transforms +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.act.detr_vae import build @@ -40,9 +40,9 @@ def kl_divergence(mu, logvar): return total_kld, dimension_wise_kld, mean_kld -class ActionChunkingTransformerPolicy(nn.Module): +class ActionChunkingTransformerPolicy(AbstractPolicy): def __init__(self, cfg, device, n_action_steps=1): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg self.n_action_steps = n_action_steps self.device = device @@ -147,16 +147,15 @@ def compute_loss(self, batch): return loss @torch.no_grad() - def forward(self, observation, step_count): + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") + # TODO(rcadene): remove unused step_count del step_count self.eval() - # TODO(rcadene): remove unsqueeze hack to add bsize=1 - observation["image", "top"] = observation["image", "top"].unsqueeze(0) - # observation["state"] = observation["state"].unsqueeze(0) - # TODO(rcadene): remove hack # add 1 camera dimension observation["image", "top"] = observation["image", "top"].unsqueeze(1) @@ -180,11 +179,8 @@ def forward(self, observation, step_count): # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) - # remove bsize=1 - action = action.squeeze(0) - # take first predicted action or n first actions - action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] + action = action[: self.n_action_steps] return action def _forward(self, qpos, image, actions=None, is_pad=None): diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 3df76aa4a..2c47f172e 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -3,14 +3,14 @@ import hydra import torch -import torch.nn as nn +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder -class DiffusionPolicy(nn.Module): +class DiffusionPolicy(AbstractPolicy): def __init__( self, cfg, @@ -34,7 +34,7 @@ def __init__( # parameters passed to step **kwargs, ): - super().__init__() + super().__init__(n_action_steps) self.cfg = cfg noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) @@ -93,21 +93,16 @@ def __init__( ) @torch.no_grad() - def forward(self, observation, step_count): + def select_actions(self, observation, step_count): # 
TODO(rcadene): remove unused step_count del step_count - # TODO(rcadene): remove unsqueeze hack to add bsize=1 - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) - obs_dict = { "image": observation["image"], "agent_pos": observation["state"], } out = self.diffusion.predict_action(obs_dict) - - action = out["action"].squeeze(0) + action = out["action"] return action def update(self, replay_buffer, step): diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index c5e45300c..085baab58 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,4 +1,7 @@ def make_policy(cfg): + if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1: + raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.") + if cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPC diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index ae9888a50..320f6f2bd 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -9,6 +9,7 @@ import torch.nn as nn import lerobot.common.policies.tdmpc.helper as h +from lerobot.common.policies.abstract import AbstractPolicy FIRST_FRAME = 0 @@ -85,11 +86,11 @@ def Q(self, z, a, return_type): # noqa: N802 return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2 -class TDMPC(nn.Module): +class TDMPC(AbstractPolicy): """Implementation of TD-MPC learning + inference.""" def __init__(self, cfg, device): - super().__init__() + super().__init__(None) self.action_dim = cfg.action_dim self.cfg = cfg @@ -124,20 +125,19 @@ def load(self, fp): self.model_target.load_state_dict(d["model_target"]) @torch.no_grad() - def forward(self, observation, step_count): - t0 = step_count.item() == 0 + def select_actions(self, observation, step_count): + if observation["image"].shape[0] != 1: + raise NotImplementedError("Batch size > 1 not handled") - # TODO(rcadene): remove unsqueeze hack... - if observation["image"].ndim == 3: - observation["image"] = observation["image"].unsqueeze(0) - observation["state"] = observation["state"].unsqueeze(0) + t0 = step_count.item() == 0 obs = { # TODO(rcadene): remove contiguous hack... "rgb": observation["image"].contiguous(), "state": observation["state"].contiguous(), } - action = self.act(obs, t0=t0, step=self.step.item()) + # Note: unsqueeze needed because `act` still uses non-batch logic. + action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0) return action @torch.no_grad() diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 9a97b50d1..52fd1d601 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -10,6 +10,9 @@ hydra: name: default seed: 1337 +# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index +# NOTE: only diffusion policy supports rollout_batch_size > 1 +rollout_batch_size: 1 device: cuda # cpu prefetch: 4 eval_freq: ??? 
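The queueing mechanism added in lerobot/common/policies/abstract.py is the conceptual core of this change: a policy plans a trajectory of n_action_steps actions once, then serves them one at a time until the queue is empty. Below is a minimal, self-contained sketch of that behaviour; ToyQueuedPolicy and the fake trajectory it emits are invented for illustration, and only the queue handling mirrors AbstractPolicy.forward.

from collections import deque

import torch


class ToyQueuedPolicy:
    """Illustrative restatement of the action-queue logic in AbstractPolicy.forward."""

    def __init__(self, n_action_steps: int):
        self.n_action_steps = n_action_steps
        self._action_queue = deque([], maxlen=n_action_steps)

    def select_actions(self, observation: torch.Tensor) -> torch.Tensor:
        # Stand-in planner: returns a (batch_size, n_action_steps, action_dim) trajectory.
        batch_size = observation.shape[0]
        return torch.arange(self.n_action_steps, dtype=torch.float32).reshape(1, -1, 1).repeat(batch_size, 1, 1)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Only re-plan when the queue is empty; otherwise serve the cached actions.
        if len(self._action_queue) == 0:
            # select_actions returns (B, H, *) but the queue holds H items of shape (B, *), hence the transpose.
            self._action_queue.extend(self.select_actions(observation).transpose(0, 1))
        return self._action_queue.popleft()


obs = torch.zeros(2, 4)  # batch of 2 dummy observations
policy = ToyQueuedPolicy(n_action_steps=3)
actions = [policy.forward(obs) for _ in range(5)]
# The planner only runs on steps 0 and 3; steps 1, 2 and 4 are served from the queue.
print([a[:, 0].tolist() for a in actions])  # [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [0.0, 0.0], [1.0, 1.0]]
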
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index dd2d68afd..41d58b914 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -3,6 +3,7 @@ import time from pathlib import Path +import einops import hydra import imageio import numpy as np @@ -10,10 +11,12 @@ import tqdm from tensordict.nn import TensorDictModule from torchrl.envs import EnvBase +from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir +from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy from lerobot.common.utils import init_logging, set_seed @@ -23,8 +26,8 @@ def write_video(video_path, stacked_frames, fps): def eval_policy( - env: EnvBase, - policy: TensorDictModule = None, + env: BatchedEnvBase, + policy: AbstractPolicy, num_episodes: int = 10, max_steps: int = 30, save_video: bool = False, @@ -37,55 +40,75 @@ def eval_policy( sum_rewards = [] max_rewards = [] successes = [] - threads = [] - for i in tqdm.tqdm(range(num_episodes)): + threads = [] # for video saving threads + episode_counter = 0 # for saving the correct number of videos + + # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than + # needed as I'm currently taking a ceil. + for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))): ep_frames = [] - if save_video or (return_first_video and i == 0): - def render_frame(env): + def maybe_render_frame(env: EnvBase, _): + if save_video or (return_first_video and i == 0): # noqa: B023 ep_frames.append(env.render()) # noqa: B023 - env.register_rendering_hook(render_frame) - with torch.inference_mode(): + # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all + # envs are done the first time. But we only use the first rollout. This is a waste of compute. + policy.clear_action_queue() rollout = env.rollout( max_steps=max_steps, policy=policy, auto_cast_to_device=True, + callback=maybe_render_frame, + break_when_any_done=env.batch_size[0] == 1, ) - # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()])) - ep_sum_reward = rollout["next", "reward"].sum() - ep_max_reward = rollout["next", "reward"].max() - ep_success = rollout["next", "success"].any() - sum_rewards.append(ep_sum_reward.item()) - max_rewards.append(ep_max_reward.item()) - successes.append(ep_success.item()) + # Figure out where in each rollout sequence the first done condition was encountered (results after this won't + # be included). + # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1). + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. 
+ rollout_steps = rollout["next", "done"].shape[1] + done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps) + mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1) + batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum") + batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max") + batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any") + sum_rewards.extend(batch_sum_reward.tolist()) + max_rewards.extend(batch_max_reward.tolist()) + successes.extend(batch_success.tolist()) if save_video or (return_first_video and i == 0): - stacked_frames = np.stack(ep_frames) + batch_stacked_frames = np.stack(ep_frames) # (t, b, *) + batch_stacked_frames = batch_stacked_frames.transpose( + 1, 0, *range(2, batch_stacked_frames.ndim) + ) # (b, t, *) if save_video: - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{i}.mp4" - thread = threading.Thread( - target=write_video, - args=(str(video_path), stacked_frames, fps), - ) - thread.start() - threads.append(thread) + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if episode_counter >= num_episodes: + continue + video_dir.mkdir(parents=True, exist_ok=True) + video_path = video_dir / f"eval_episode_{episode_counter}.mp4" + thread = threading.Thread( + target=write_video, + args=(str(video_path), stacked_frames[:done_index], fps), + ) + thread.start() + threads.append(thread) + episode_counter += 1 if return_first_video and i == 0: - first_video = stacked_frames.transpose(0, 3, 1, 2) - - env.reset_rendering_hooks() + first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2) for thread in threads: thread.join() info = { - "avg_sum_reward": np.nanmean(sum_rewards), - "avg_max_reward": np.nanmean(max_rewards), - "pc_success": np.nanmean(successes) * 100, + "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]), + "avg_max_reward": np.nanmean(max_rewards[:num_episodes]), + "pc_success": np.nanmean(successes[:num_episodes]) * 100, "eval_s": time.time() - start, "eval_ep_s": (time.time() - start) / num_episodes, } @@ -139,7 +162,7 @@ def eval(cfg: dict, out_dir=None): save_video=True, video_dir=Path(out_dir) / "eval", fps=cfg.env.fps, - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, num_episodes=cfg.eval_episodes, ) print(metrics) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 7af75391c..242c77bc6 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -112,6 +112,8 @@ def train(cfg: dict, out_dir=None, job_name=None): raise NotImplementedError() if job_name is None: raise NotImplementedError() + if cfg.online_steps > 0: + assert cfg.rollout_batch_size == 1, "rollout_batch_size > 1 not supported for online training steps" init_logging() @@ -192,7 +194,7 @@ def train(cfg: dict, out_dir=None, job_name=None): env, td_policy, num_episodes=cfg.eval_episodes, - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, return_first_video=True, video_dir=Path(out_dir) / "eval", save_video=True, @@ -218,11 +220,11 @@ def train(cfg: dict, out_dir=None, job_name=None): # TODO: add configurable number of rollout? 
(default=1) with torch.no_grad(): rollout = env.rollout( - max_steps=cfg.env.episode_length // cfg.n_action_steps, + max_steps=cfg.env.episode_length, policy=td_policy, auto_cast_to_device=True, ) - assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps + assert len(rollout) <= cfg.env.episode_length # set same episode index for all time steps contained in this rollout rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int) online_buffer.extend(rollout) diff --git a/poetry.lock b/poetry.lock index a76858bd0..ddb0a0e31 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -658,13 +658,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -1468,32 +1468,32 @@ setuptools = "*" [[package]] name = "numba" -version = "0.59.0" +version = "0.59.1" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" files = [ - {file = "numba-0.59.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d061d800473fb8fef76a455221f4ad649a53f5e0f96e3f6c8b8553ee6fa98fa"}, - {file = "numba-0.59.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c086a434e7d3891ce5dfd3d1e7ee8102ac1e733962098578b507864120559ceb"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e20736bf62e61f8353fb71b0d3a1efba636c7a303d511600fc57648b55823ed"}, - {file = "numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e86e6786aec31d2002122199486e10bbc0dc40f78d76364cded375912b13614c"}, - {file = "numba-0.59.0-cp310-cp310-win_amd64.whl", hash = "sha256:0307ee91b24500bb7e64d8a109848baf3a3905df48ce142b8ac60aaa406a0400"}, - {file = "numba-0.59.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d540f69a8245fb714419c2209e9af6104e568eb97623adc8943642e61f5d6d8e"}, - {file = "numba-0.59.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1192d6b2906bf3ff72b1d97458724d98860ab86a91abdd4cfd9328432b661e31"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90efb436d3413809fcd15298c6d395cb7d98184350472588356ccf19db9e37c8"}, - {file = "numba-0.59.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd3dac45e25d927dcb65d44fb3a973994f5add2b15add13337844afe669dd1ba"}, - {file = "numba-0.59.0-cp311-cp311-win_amd64.whl", hash = "sha256:753dc601a159861808cc3207bad5c17724d3b69552fd22768fddbf302a817a4c"}, - {file = "numba-0.59.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ce62bc0e6dd5264e7ff7f34f41786889fa81a6b860662f824aa7532537a7bee0"}, - {file = "numba-0.59.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:8cbef55b73741b5eea2dbaf1b0590b14977ca95a13a07d200b794f8f6833a01c"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:70d26ba589f764be45ea8c272caa467dbe882b9676f6749fe6f42678091f5f21"}, - {file = "numba-0.59.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e125f7d69968118c28ec0eed9fbedd75440e64214b8d2eac033c22c04db48492"}, - {file = "numba-0.59.0-cp312-cp312-win_amd64.whl", hash = "sha256:4981659220b61a03c1e557654027d271f56f3087448967a55c79a0e5f926de62"}, - {file = "numba-0.59.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe4d7562d1eed754a7511ed7ba962067f198f86909741c5c6e18c4f1819b1f47"}, - {file = "numba-0.59.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6feb1504bb432280f900deaf4b1dadcee68812209500ed3f81c375cbceab24dc"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:944faad25ee23ea9dda582bfb0189fb9f4fc232359a80ab2a028b94c14ce2b1d"}, - {file = "numba-0.59.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5516a469514bfae52a9d7989db4940653a5cbfac106f44cb9c50133b7ad6224b"}, - {file = "numba-0.59.0-cp39-cp39-win_amd64.whl", hash = "sha256:32bd0a41525ec0b1b853da244808f4e5333867df3c43c30c33f89cf20b9c2b63"}, - {file = "numba-0.59.0.tar.gz", hash = "sha256:12b9b064a3e4ad00e2371fc5212ef0396c80f41caec9b5ec391c8b04b6eaf2a8"}, + {file = "numba-0.59.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97385a7f12212c4f4bc28f648720a92514bee79d7063e40ef66c2d30600fd18e"}, + {file = "numba-0.59.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0b77aecf52040de2a1eb1d7e314497b9e56fba17466c80b457b971a25bb1576d"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3476a4f641bfd58f35ead42f4dcaf5f132569c4647c6f1360ccf18ee4cda3990"}, + {file = "numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24"}, + {file = "numba-0.59.1-cp310-cp310-win_amd64.whl", hash = "sha256:990e395e44d192a12105eca3083b61307db7da10e093972ca285c85bef0963d6"}, + {file = "numba-0.59.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:43727e7ad20b3ec23ee4fc642f5b61845c71f75dd2825b3c234390c6d8d64051"}, + {file = "numba-0.59.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:411df625372c77959570050e861981e9d196cc1da9aa62c3d6a836b5cc338966"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2801003caa263d1e8497fb84829a7ecfb61738a95f62bc05693fcf1733e978e4"}, + {file = "numba-0.59.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dd2842fac03be4e5324ebbbd4d2d0c8c0fc6e0df75c09477dd45b288a0777389"}, + {file = "numba-0.59.1-cp311-cp311-win_amd64.whl", hash = "sha256:0594b3dfb369fada1f8bb2e3045cd6c61a564c62e50cf1f86b4666bc721b3450"}, + {file = "numba-0.59.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1cce206a3b92836cdf26ef39d3a3242fec25e07f020cc4feec4c4a865e340569"}, + {file = "numba-0.59.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8c8b4477763cb1fbd86a3be7050500229417bf60867c93e131fd2626edb02238"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d80bce4ef7e65bf895c29e3889ca75a29ee01da80266a01d34815918e365835"}, + {file = "numba-0.59.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:f7ad1d217773e89a9845886401eaaab0a156a90aa2f179fdc125261fd1105096"}, + {file = "numba-0.59.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bf68f4d69dd3a9f26a9b23548fa23e3bcb9042e2935257b471d2a8d3c424b7f"}, + {file = "numba-0.59.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4e0318ae729de6e5dbe64c75ead1a95eb01fabfe0e2ebed81ebf0344d32db0ae"}, + {file = "numba-0.59.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0f68589740a8c38bb7dc1b938b55d1145244c8353078eea23895d4f82c8b9ec1"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:649913a3758891c77c32e2d2a3bcbedf4a69f5fea276d11f9119677c45a422e8"}, + {file = "numba-0.59.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9712808e4545270291d76b9a264839ac878c5eb7d8b6e02c970dc0ac29bc8187"}, + {file = "numba-0.59.1-cp39-cp39-win_amd64.whl", hash = "sha256:8d51ccd7008a83105ad6a0082b6a2b70f1142dc7cfd76deb8c5a862367eb8c86"}, + {file = "numba-0.59.1.tar.gz", hash = "sha256:76f69132b96028d2774ed20415e8c528a34e3299a40581bae178f0994a2f370b"}, ] [package.dependencies] @@ -2684,13 +2684,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.41.0" +version = "1.42.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"}, - {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"}, + {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, + {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, ] [package.dependencies] @@ -2714,6 +2714,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -2829,18 +2830,18 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "69.1.1" +version = "69.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"}, - {file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"}, + {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, + {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", 
"pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -2949,7 +2950,7 @@ mpmath = ">=0.19" [[package]] name = "tensordict" -version = "0.4.0+551331d" +version = "0.4.0+ca4256e" description = "" optional = false python-versions = "*" @@ -2970,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures type = "git" url = "https://github.com/pytorch/tensordict" reference = "HEAD" -resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a" +resolved_reference = "b4c91e8828c538ca0a50d8383fd99311a9afb078" [[package]] name = "termcolor" @@ -3311,18 +3312,18 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] [[package]] name = "zipp" -version = "3.17.0" +version = "3.18.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, - {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, + {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, + {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [metadata] lock-version = "2.0" diff 
--git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index 869d26cd0..d41ac18cd 100644 Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 037e02f0f..039d5db3d 100644 Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ diff --git a/tests/test_policies.py b/tests/test_policies.py index f00429bcb..e6cfdfbc9 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,25 +1,138 @@ +from omegaconf import open_dict import pytest +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +import torch +from torchrl.data import UnboundedContinuousTensorSpec +from torchrl.envs import EnvBase from lerobot.common.policies.factory import make_policy +from lerobot.common.envs.factory import make_env +from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.policies.abstract import AbstractPolicy from .utils import DEVICE, init_config @pytest.mark.parametrize( - "env_name,policy_name", + "env_name,policy_name,extra_overrides", [ - ("simxarm", "tdmpc"), - ("pusht", "tdmpc"), - ("simxarm", "diffusion"), - ("pusht", "diffusion"), + ("simxarm", "tdmpc", ["policy.mpc=true"]), + ("pusht", "tdmpc", ["policy.mpc=false"]), + ("simxarm", "diffusion", []), + ("pusht", "diffusion", []), + ("aloha", "act", ["env.task=sim_insertion_scripted"]), ], ) -def test_factory(env_name, policy_name): +def test_concrete_policy(env_name, policy_name, extra_overrides): + """ + Tests: + - Making the policy object. + - Updating the policy. + - Using the policy to select actions at inference time. + """ cfg = init_config( overrides=[ f"env={env_name}", f"policy={policy_name}", f"device={DEVICE}", ] + + extra_overrides ) + # Check that we can make the policy object. policy = make_policy(cfg) + # Check that we run select_actions and get the appropriate output. + if env_name == "simxarm": + # TODO(rcadene): Not implemented + return + if policy_name == "tdmpc": + # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. + with open_dict(cfg): + cfg["n_obs_steps"] = 1 + offline_buffer = make_offline_buffer(cfg) + env = make_env(cfg, transform=offline_buffer.transform) + + if env_name != "aloha": + # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError: + # seq_length as a list is not supported for now. + policy.update(offline_buffer, torch.tensor(0, device=DEVICE)) + + action = policy( + env.observation_spec.rand()["observation"].to(DEVICE), + torch.tensor(0, device=DEVICE), + ) + assert action.shape == env.action_spec.shape + + +def test_abstract_policy_forward(): + """ + Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that: + - The policy is invoked the expected number of times during a rollout. + - The environment's termination condition is respected even when part way through an action trajectory. + - The observations are returned correctly. + """ + + n_action_steps = 8 # our test policy will output 8 action step horizons + terminate_at = 10 # some number that is more than n_action_steps but not a multiple + rollout_max_steps = terminate_at + 1 # some number greater than terminate_at + + # A minimal environment for testing. 
+ class StubEnv(EnvBase): + + def __init__(self): + super().__init__() + self.action_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + + def _step(self, tensordict: TensorDict) -> TensorDict: + self.invocation_count += 1 + return TensorDict( + { + "observation": torch.tensor([self.invocation_count]), + "reward": torch.tensor([self.invocation_count]), + "terminated": torch.tensor( + tensordict["action"].item() == terminate_at + ), + } + ) + + def _reset(self, tensordict: TensorDict) -> TensorDict: + self.invocation_count = 0 + return TensorDict( + { + "observation": torch.tensor([self.invocation_count]), + "reward": torch.tensor([self.invocation_count]), + } + ) + + def _set_seed(self, seed: int | None): + return + + class StubPolicy(AbstractPolicy): + def __init__(self): + super().__init__(n_action_steps) + self.n_policy_invocations = 0 + + def update(self): + pass + + def select_actions(self): + self.n_policy_invocations += 1 + return torch.stack( + [torch.tensor([i]) for i in range(self.n_action_steps)] + ).unsqueeze(0) + + env = StubEnv() + policy = StubPolicy() + policy = TensorDictModule( + policy, + in_keys=[], + out_keys=["action"], + ) + + # Keep track to make sure the policy is called the expected number of times + rollout = env.rollout(rollout_max_steps, policy) + + assert len(rollout) == terminate_at + 1 # +1 for the reset observation + assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1 + assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
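
For reference, the done-masking arithmetic introduced in eval_policy can be checked in isolation. The tensors below are made up for illustration (two rollouts of five steps, unit rewards); the real code applies the same argmax/mask construction to rollout["next", "done"] of shape (batch_size, max_steps, 1).

import torch

# Toy "done" flags for 2 rollouts of 5 steps each, shape (batch_size, steps, 1).
done = torch.tensor([[0, 0, 1, 0, 1],
                     [0, 0, 0, 0, 0]]).unsqueeze(-1)
reward = torch.ones_like(done, dtype=torch.float32)

# torch.argmax returns the first maximal element, i.e. the first step where done is set.
# If a rollout contains no done at all (second row), argmax returns 0 and only step 0 is kept.
done_indices = torch.argmax(done.to(int), dim=1)            # (batch_size, 1)
steps = done.shape[1]
mask = (torch.arange(steps) <= done_indices).unsqueeze(-1)  # (batch_size, steps, 1)

sum_reward = (reward * mask).sum(dim=(1, 2))
print(done_indices.flatten().tolist())  # [2, 0]
print(sum_reward.tolist())              # [3.0, 1.0]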