diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
index 1211a37a8..7ef24f2de 100644
--- a/lerobot/common/envs/aloha/env.py
+++ b/lerobot/common/envs/aloha/env.py
@@ -168,42 +168,31 @@ def _reset(self, tensordict: Optional[TensorDict] = None):
     def _step(self, tensordict: TensorDict):
         td = tensordict
         action = td["action"].numpy()
-        # step expects shape=(4,) so we pad if necessary
+        assert action.ndim == 1
         # TODO(rcadene): add info["is_success"] and info["success"] ?
-        sum_reward = 0
-        if action.ndim == 1:
-            action = einops.repeat(action, "c -> t c", t=self.frame_skip)
-        else:
-            if self.frame_skip > 1:
-                raise NotImplementedError()
-
-        num_action_steps = action.shape[0]
-        for i in range(num_action_steps):
-            _, reward, discount, raw_obs = self._env.step(action[i])
-            del discount  # not used
+        _, reward, _, raw_obs = self._env.step(action)
 
-            # TOOD(rcadene): add an enum
-            success = done = reward == 4
-            sum_reward += reward
-            obs = self._format_raw_obs(raw_obs)
+        # TODO(rcadene): add an enum
+        success = done = reward == 4
+        obs = self._format_raw_obs(raw_obs)
 
-            if self.num_prev_obs > 0:
-                stacked_obs = {}
-                if "image" in obs:
-                    self._prev_obs_image_queue.append(obs["image"]["top"])
-                    stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
-                if "state" in obs:
-                    self._prev_obs_state_queue.append(obs["state"])
-                    stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
-                obs = stacked_obs
+        if self.num_prev_obs > 0:
+            stacked_obs = {}
+            if "image" in obs:
+                self._prev_obs_image_queue.append(obs["image"]["top"])
+                stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
+            if "state" in obs:
+                self._prev_obs_state_queue.append(obs["state"])
+                stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
+            obs = stacked_obs
 
-            self.call_rendering_hooks()
+        self.call_rendering_hooks()
 
         td = TensorDict(
             {
                 "observation": TensorDict(obs, batch_size=[]),
-                "reward": torch.tensor([sum_reward], dtype=torch.float32),
+                "reward": torch.tensor([reward], dtype=torch.float32),
                 # succes and done are true when coverage > self.success_threshold in env
                 "done": torch.tensor([done], dtype=torch.bool),
                 "success": torch.tensor([success], dtype=torch.bool),
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 921cbad74..d6b294ebe 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,15 +1,17 @@
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
 
 
-def make_env(cfg, transform=None):
+def make_env(cfg, seed=None, transform=None):
+    """
+    Provide seed to override the seed in the cfg (useful for batched environments).
+    """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
        "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
-        # TODO(rcadene): do we want a specific eval_env_seed?
- "seed": cfg.seed, "num_prev_obs": cfg.n_obs_steps - 1, + "seed": seed if seed is not None else cfg.seed, } if cfg.env.name == "simxarm": diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 4a7ccb2c8..2fe05233d 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -2,7 +2,6 @@ from collections import deque from typing import Optional -import einops import torch from tensordict import TensorDict from torchrl.data.tensor_specs import ( @@ -120,40 +119,30 @@ def _reset(self, tensordict: Optional[TensorDict] = None): def _step(self, tensordict: TensorDict): td = tensordict action = td["action"].numpy() - # step expects shape=(4,) so we pad if necessary + assert action.ndim == 1 # TODO(rcadene): add info["is_success"] and info["success"] ? - sum_reward = 0 - if action.ndim == 1: - action = einops.repeat(action, "c -> t c", t=self.frame_skip) - else: - if self.frame_skip > 1: - raise NotImplementedError() + raw_obs, reward, done, info = self._env.step(action) - num_action_steps = action.shape[0] - for i in range(num_action_steps): - raw_obs, reward, done, info = self._env.step(action[i]) - sum_reward += reward + obs = self._format_raw_obs(raw_obs) - obs = self._format_raw_obs(raw_obs) + if self.num_prev_obs > 0: + stacked_obs = {} + if "image" in obs: + self._prev_obs_image_queue.append(obs["image"]) + stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) + if "state" in obs: + self._prev_obs_state_queue.append(obs["state"]) + stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) + obs = stacked_obs - if self.num_prev_obs > 0: - stacked_obs = {} - if "image" in obs: - self._prev_obs_image_queue.append(obs["image"]) - stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue)) - if "state" in obs: - self._prev_obs_state_queue.append(obs["state"]) - stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue)) - obs = stacked_obs - - self.call_rendering_hooks() + self.call_rendering_hooks() td = TensorDict( { "observation": TensorDict(obs, batch_size=[]), - "reward": torch.tensor([sum_reward], dtype=torch.float32), - # succes and done are true when coverage > self.success_threshold in env + "reward": torch.tensor([reward], dtype=torch.float32), + # success and done are true when coverage > self.success_threshold in env "done": torch.tensor([done], dtype=torch.bool), "success": torch.tensor([done], dtype=torch.bool), }, diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py new file mode 100644 index 000000000..4956530a2 --- /dev/null +++ b/lerobot/common/policies/abstract.py @@ -0,0 +1,54 @@ +from abc import abstractmethod +from collections import deque + +import torch +from torch import Tensor, nn + + +class AbstractPolicy(nn.Module): + @abstractmethod + def update(self, replay_buffer, step): + """One step of the policy's learning algorithm.""" + pass + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + self.load_state_dict(d) + + @abstractmethod + def select_action(self, observation) -> Tensor: + """Select an action (or trajectory of actions) based on an observation during rollout. + + Should return a (batch_size, n_action_steps, *) tensor of actions. + """ + pass + + def forward(self, *args, **kwargs): + """Inference step that makes multi-step policies compatible with their single-step environments. + + WARNING: In general, this should not be overriden. 
+
+        Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
+        into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
+        observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
+        observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
+        the subclass doesn't have to.
+
+        This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
+        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
+           the action trajectory horizon and * is the action dimensions.
+        2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute defined.
+        """
+        n_action_steps_attr = "n_action_steps"
+        if not hasattr(self, n_action_steps_attr):
+            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
+        if not hasattr(self, "_action_queue"):
+            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
+        if len(self._action_queue) == 0:
+            # Each element in the queue has shape (B, *).
+            self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
+
+        return self._action_queue.popleft()
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index d011cb762..e87f155e6 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -2,10 +2,10 @@
 import time
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 import torchvision.transforms as transforms
 
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.act.detr_vae import build
@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
     return total_kld, dimension_wise_kld, mean_kld
 
 
-class ActionChunkingTransformerPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(AbstractPolicy):
     def __init__(self, cfg, device, n_action_steps=1):
         super().__init__()
         self.cfg = cfg
@@ -147,7 +147,7 @@ def compute_loss(self, batch):
         return loss
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 3df76aa4a..db004a719 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -3,14 +3,14 @@
 
 import hydra
 import torch
-import torch.nn as nn
 
+from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
 from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
 from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
 
 
-class DiffusionPolicy(nn.Module):
+class DiffusionPolicy(AbstractPolicy):
     def __init__(
         self,
         cfg,
@@ -44,6 +44,7 @@ def __init__(
             **cfg_obs_encoder,
         )
 
+        self.n_action_steps = n_action_steps  # needed for the parent class
         self.diffusion = DiffusionUnetImagePolicy(
             shape_meta=shape_meta,
             noise_scheduler=noise_scheduler,
@@ -93,21 +94,16 @@ def __init__(
         )
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         # TODO(rcadene): remove unused step_count
         del step_count
 
-        # TODO(rcadene): remove unsqueeze hack to add bsize=1
-        observation["image"] = observation["image"].unsqueeze(0)
-        observation["state"] = observation["state"].unsqueeze(0)
-
         obs_dict = {
             "image": observation["image"],
             "agent_pos": observation["state"],
         }
         out = self.diffusion.predict_action(obs_dict)
-
-        action = out["action"].squeeze(0)
+        action = out["action"]
         return action
 
     def update(self, replay_buffer, step):
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index ae9888a50..48955459f 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 
 import lerobot.common.policies.tdmpc.helper as h
+from lerobot.common.policies.abstract import AbstractPolicy
 
 FIRST_FRAME = 0
@@ -85,7 +86,7 @@ def Q(self, z, a, return_type):  # noqa: N802
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
 
-class TDMPC(nn.Module):
+class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""
 
     def __init__(self, cfg, device):
@@ -124,7 +125,7 @@ def load(self, fp):
         self.model_target.load_state_dict(d["model_target"])
 
     @torch.no_grad()
-    def forward(self, observation, step_count):
+    def select_action(self, observation, step_count):
         t0 = step_count.item() == 0
 
         # TODO(rcadene): remove unsqueeze hack...
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 6841cb828..2a7aab6c3 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -10,6 +10,8 @@ hydra:
     name: default
 
 seed: 1337
+# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
+rollout_batch_size: 10
 device: cuda  # cpu
 prefetch: 4
 eval_freq: ???
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 7ba2812e2..e9d57cba1 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -9,7 +9,8 @@
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
+from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs.batched_envs import BatchedEnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps):
 
 
 def eval_policy(
-    env: EnvBase,
+    env: BatchedEnvBase,
     policy: TensorDictModule = None,
     num_episodes: int = 10,
     max_steps: int = 30,
@@ -36,45 +37,55 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
-    threads = []
-    for i in tqdm.tqdm(range(num_episodes)):
+    threads = []  # for video saving threads
+    episode_counter = 0  # for saving the correct number of videos
+
+    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
+    # needed as I'm currently taking a ceil.
+    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
         ep_frames = []
-        if save_video or (return_first_video and i == 0):
-
-            def render_frame(env):
+        def maybe_render_frame(env: EnvBase, _):
+            if save_video or (return_first_video and i == 0):  # noqa: B023
                 ep_frames.append(env.render())  # noqa: B023
 
-            env.register_rendering_hook(render_frame)
-
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
+                callback=maybe_render_frame,
             )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        sum_rewards.append(ep_sum_reward.item())
-        max_rewards.append(ep_max_reward.item())
-        successes.append(ep_success.item())
+        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        sum_rewards.extend(batch_sum_reward.tolist())
+        max_rewards.extend(batch_max_reward.tolist())
+        successes.extend(batch_success.tolist())
 
         if save_video or (return_first_video and i == 0):
-            stacked_frames = np.stack(ep_frames)
+            batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
+            batch_stacked_frames = batch_stacked_frames.transpose(
+                1, 0, *range(2, batch_stacked_frames.ndim)
+            )  # (b, t, *)
 
             if save_video:
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{i}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames, fps),
-                )
-                thread.start()
-                threads.append(thread)
+                for stacked_frames in batch_stacked_frames:
+                    if episode_counter >= num_episodes:
+                        continue
+                    video_dir.mkdir(parents=True, exist_ok=True)
+                    video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+                    thread = threading.Thread(
+                        target=write_video,
+                        args=(str(video_path), stacked_frames, fps),
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    episode_counter += 1
 
             if return_first_video and i == 0:
-                first_video = stacked_frames.transpose(0, 3, 1, 2)
+                first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
 
     env.reset_rendering_hooks()
 
@@ -82,9 +93,9 @@ def render_frame(env):
         thread.join()
 
     info = {
-        "avg_sum_reward": np.nanmean(sum_rewards),
-        "avg_max_reward": np.nanmean(max_rewards),
-        "pc_success": np.nanmean(successes) * 100,
+        "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
+        "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
+        "pc_success": np.nanmean(successes[:num_episodes]) * 100,
         "eval_s": time.time() - start,
         "eval_ep_s": (time.time() - start) / num_episodes,
     }
@@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
@@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length // cfg.n_action_steps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index c063caf87..579f5a585 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -7,6 +7,7 @@
 from tensordict.nn import TensorDictModule
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
+from torchrl.envs import SerialEnv
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
 
     logging.info("make_policy")
     policy = make_policy(cfg)
@@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
-                max_steps=cfg.env.episode_length // cfg.n_action_steps,
+                max_steps=cfg.env.episode_length,
                 return_first_video=True,
                 video_dir=Path(out_dir) / "eval",
                 save_video=True,
diff --git a/tests/test_policies.py b/tests/test_policies.py
index f00429bcb..7d9a4dced 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,7 +1,15 @@
+
 import pytest
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModule
+import torch
+from torchrl.data import UnboundedContinuousTensorSpec
+from torchrl.envs import EnvBase
 
 from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.abstract import AbstractPolicy
+
 from .utils import DEVICE, init_config
@@ -23,3 +31,75 @@ def test_factory(env_name, policy_name):
         ]
     )
     policy = make_policy(cfg)
+
+
+def test_abstract_policy_forward():
+    """
+    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
+    - The policy is invoked the expected number of times during a rollout.
+    - The environment's termination condition is respected even when part way through an action trajectory.
+    - The observations are returned correctly.
+    """
+
+    n_action_steps = 8  # our test policy will output 8 action step horizons
+    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
+    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at
+
+    # A minimal environment for testing.
+    class StubEnv(EnvBase):
+
+        def __init__(self):
+            super().__init__()
+            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
+            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+
+        def _step(self, tensordict: TensorDict) -> TensorDict:
+            self.invocation_count += 1
+            return TensorDict(
+                {
+                    "observation": torch.tensor([self.invocation_count]),
+                    "reward": torch.tensor([self.invocation_count]),
+                    "terminated": torch.tensor(
+                        tensordict["action"].item() == terminate_at
+                    ),
+                }
+            )
+
+        def _reset(self, tensordict: TensorDict) -> TensorDict:
+            self.invocation_count = 0
+            return TensorDict(
+                {
+                    "observation": torch.tensor([self.invocation_count]),
+                    "reward": torch.tensor([self.invocation_count]),
+                }
+            )
+
+        def _set_seed(self, seed: int | None):
+            return
+
+
+    class StubPolicy(AbstractPolicy):
+        def __init__(self):
+            super().__init__()
+            self.n_action_steps = n_action_steps
+            self.n_policy_invocations = 0
+
+        def select_action(self):
+            self.n_policy_invocations += 1
+            return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
+
+
+    env = StubEnv()
+    policy = StubPolicy()
+    policy = TensorDictModule(
+        policy,
+        in_keys=[],
+        out_keys=["action"],
+    )
+
+    # Keep track to make sure the policy is called the expected number of times
+    rollout = env.rollout(rollout_max_steps, policy)
+
+    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
+    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
+    assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))