diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py
index 34b33c2e6..5db97497a 100644
--- a/lerobot/common/datasets/abstract.py
+++ b/lerobot/common/datasets/abstract.py
@@ -49,9 +49,9 @@ def __init__(
     @property
     def stats_patterns(self) -> dict:
         return {
-            ("observation", "state"): "b c -> 1 c",
-            ("observation", "image"): "b c h w -> 1 c 1 1",
-            ("action",): "b c -> 1 c",
+            ("observation", "state"): "b c -> c",
+            ("observation", "image"): "b c h w -> c",
+            ("action",): "b c -> c",
         }

     @property
diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py
index 52a5676ee..0637f8a37 100644
--- a/lerobot/common/datasets/aloha.py
+++ b/lerobot/common/datasets/aloha.py
@@ -113,11 +113,11 @@ def __init__(
     @property
     def stats_patterns(self) -> dict:
         d = {
-            ("observation", "state"): "b c -> 1 c",
-            ("action",): "b c -> 1 c",
+            ("observation", "state"): "b c -> c",
+            ("action",): "b c -> c",
         }
         for cam in CAMERAS[self.dataset_id]:
-            d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
+            d[("observation", "image", cam)] = "b c h w -> c"
         return d

     @property
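Note on the stats pattern change above: the reduction target drops its singleton dimensions, so aggregated stats become flat per-channel vectors instead of broadcast-ready tensors. A minimal sketch of the difference, assuming the patterns feed `einops.reduce` (which their syntax suggests):

```python
import torch
from einops import reduce

batch = torch.randn(8, 3, 96, 96)  # b c h w

# Old pattern: keeps broadcastable singleton dims -> shape (1, 3, 1, 1).
old_mean = reduce(batch, "b c h w -> 1 c 1 1", "mean")

# New pattern: collapses everything but channels -> shape (3,).
new_mean = reduce(batch, "b c h w -> c", "mean")

assert old_mean.shape == (1, 3, 1, 1)
assert new_mean.shape == (3,)
```

The old shapes broadcast directly against `(b, c, h, w)` batches during normalization; the flat form needs an explicit reshape at the point of use.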
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index d6b294ebe..de86b3ad8 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,17 +1,31 @@
 from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


-def make_env(cfg, seed=None, transform=None):
+def make_env(cfg, transform=None):
     """
     Provide seed to override the seed in the cfg (useful for batched environments).
     """
+    # assert cfg.rollout_batch_size == 1, \
+    #     """
+    #     For the time being, rollout batch sizes > 1 are not supported. This is because the SerialEnv rollout
+    #     does not correctly handle terminated environments. If you really want to use a larger batch size, read on...
+
+    #     When calling `EnvBase.rollout` with `break_when_any_done == True`, all environments stop rolling out
+    #     as soon as the first is terminated or truncated. This almost certainly results in incorrect success
+    #     metrics, as all environments but the first to finish are cut off before they get an opportunity to
+    #     reach the goal. A possible workaround is to comment out `if any_done: break` in
+    #     `EnvBase._rollout_stop_early`. One potential downside is that the environments' `step` functions will
+    #     continue to be called and the outputs will continue to be added to the rollout.
+
+    #     When calling `EnvBase.rollout` with `break_when_any_done == False`, environments are reset when done.
+    #     """
     kwargs = {
         "frame_skip": cfg.env.action_repeat,
         "from_pixels": cfg.env.from_pixels,
         "pixels_only": cfg.env.pixels_only,
         "image_size": cfg.env.image_size,
         "num_prev_obs": cfg.n_obs_steps - 1,
-        "seed": seed if seed is not None else cfg.seed,
+        "seed": cfg.seed,
     }

     if cfg.env.name == "simxarm":
@@ -33,22 +47,36 @@ def make_env(cfg, seed=None, transform=None):
     else:
         raise ValueError(cfg.env.name)

-    env = clsfunc(**kwargs)
+    def _make_env(seed):
+        nonlocal kwargs
+        kwargs["seed"] = seed
+        env = clsfunc(**kwargs)

-    # limit rollout to max_steps
-    env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
+        # limit rollout to max_steps
+        env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

-    if transform is not None:
-        # useful to add normalization
-        if isinstance(transform, Compose):
-            for tf in transform:
-                env.append_transform(tf.clone())
-        elif isinstance(transform, Transform):
-            env.append_transform(transform.clone())
-        else:
-            raise NotImplementedError()
+        if transform is not None:
+            # useful to add normalization
+            if isinstance(transform, Compose):
+                for tf in transform:
+                    env.append_transform(tf.clone())
+            elif isinstance(transform, Transform):
+                env.append_transform(transform.clone())
+            else:
+                raise NotImplementedError()

-    return env
+        return env
+
+    # Batched rollouts are not yet supported (see the note above), so return a single seeded env.
+    return _make_env(seed=cfg.seed)
+
+    # return SerialEnv(
+    #     cfg.rollout_batch_size,
+    #     create_env_fn=_make_env,
+    #     create_env_kwargs=[
+    #         {"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+    #     ],
+    # )


# def make_env(env_name, frame_skip, device, is_test=False):
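To make the rollout caveat in the docstring concrete, here is a sketch of the two `EnvBase.rollout` modes it describes. `cfg` and `offline_buffer` are assumed to come from the surrounding scripts (see `lerobot/scripts/eval.py`); `break_when_any_done` is a standard argument of torchrl's `EnvBase.rollout`, defaulting to `True`.

```python
from lerobot.common.envs.factory import make_env

# `cfg` and `offline_buffer` as built in lerobot/scripts/eval.py.
env = make_env(cfg, transform=offline_buffer.transform)

# break_when_any_done=True (torchrl's default): in a batched env, every sub-env
# stops as soon as the first one is terminated or truncated, which is what
# skews the success metrics described in the docstring above.
rollout = env.rollout(max_steps=cfg.env.episode_length, break_when_any_done=True)

# break_when_any_done=False: finished sub-envs are instead reset in place and
# keep stepping until max_steps, so one rollout can span several episodes.
rollout = env.rollout(max_steps=cfg.env.episode_length, break_when_any_done=False)
```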
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 9c652c0a3..ca2d8570a 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -30,7 +30,7 @@ def select_action(self, observation) -> Tensor:
         Should return a (batch_size, n_action_steps, *) tensor of actions.
         """

-    def forward(self, *args, **kwargs):
+    def forward(self, *args, **kwargs) -> Tensor:
         """Inference step that makes multi-step policies compatible with their single-step environments.

         WARNING: In general, this should not be overriden.
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index 5cc8acd24..27b75c88d 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -11,14 +11,16 @@ hydra:
 seed: 1337
 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
-rollout_batch_size: 10
+# NOTE: rollout batch sizes > 1 are not yet supported! This is just a placeholder for future support. See
+# `lerobot.common.envs.factory.make_env` for more information.
+rollout_batch_size: 1
 device: cuda  # cpu
 prefetch: 4
 eval_freq: ???
 save_freq: ???
 eval_episodes: ???
 save_video: false
-save_model: false
+save_model: true
 save_buffer: false
 train_steps: ???
 fps: ???

@@ -31,7 +33,7 @@ env: ???
 policy: ???

 wandb:
-  enable: true
+  enable: false
   # Set to true to disable saving an artifact despite save_model == True
   disable_artifact: false
   project: lerobot
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 0dae5056d..ce8acbd47 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
 obs_as_global_cond: True

 eval_episodes: 1
-eval_freq: 10000
-save_freq: 100000
+eval_freq: 5000
+save_freq: 5000
 log_freq: 250

 offline_steps: 1344000
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 7cfb796af..c0199c0c4 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -9,7 +9,7 @@
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs import EnvBase
 from torchrl.envs.batched_envs import BatchedEnvBase

 from lerobot.common.datasets.factory import make_offline_buffer
@@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)

     logging.info("make_env")
-    env = SerialEnv(
-        cfg.rollout_batch_size,
-        create_env_fn=make_env,
-        create_env_kwargs=[
-            {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
-            for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-        ],
-    )
+    env = make_env(cfg, transform=offline_buffer.transform)

     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 2c7bb5751..5ecd616d4 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -7,7 +7,6 @@
 from tensordict.nn import TensorDictModule
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
-from torchrl.envs import SerialEnv

 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -149,14 +148,6 @@ def train(cfg: dict, out_dir=None, job_name=None):

     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)
-    env = SerialEnv(
-        cfg.rollout_batch_size,
-        create_env_fn=make_env,
-        create_env_kwargs=[
-            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
-            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
-        ],
-    )

     logging.info("make_policy")
     policy = make_policy(cfg)
diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth
index 869d26cd0..f909ed075 100644
Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and b/tests/data/aloha_sim_insertion_human/stats.pth differ
diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth
index 037e02f0f..8846b8f65 100644
Binary files a/tests/data/pusht/stats.pth and b/tests/data/pusht/stats.pth differ
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 92324485d..ee5abdb79 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,4 +1,5 @@
+from omegaconf import open_dict
 import pytest
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule

@@ -7,7 +8,8 @@
 from torchrl.envs import EnvBase

 from lerobot.common.policies.factory import make_policy
-
+from lerobot.common.envs.factory import make_env
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.policies.abstract import AbstractPolicy

 from .utils import DEVICE, init_config

@@ -30,7 +32,19 @@ def test_factory(env_name, policy_name):
             f"device={DEVICE}",
         ]
     )
+    # Check that we can make the policy object.
     policy = make_policy(cfg)
+    # Check that we can run select_action on a dummy observation.
+    if env_name == "simxarm":
+        # TODO(rcadene): Not implemented
+        return
+    if policy_name == "tdmpc":
+        # TODO(alexander-soare): TDMPC does not use n_obs_steps, but the environment factory requires it.
+        with open_dict(cfg):
+            cfg["n_obs_steps"] = 1
+    offline_buffer = make_offline_buffer(cfg)
+    env = make_env(cfg, transform=offline_buffer.transform)
+    policy.select_action(env.observation_spec.rand()["observation"].to(DEVICE), torch.tensor(0, device=DEVICE))


def test_abstract_policy_forward():
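For context on the `open_dict` usage in the new test: Hydra composes configs in struct mode, so assigning a key that is not already in the schema raises an error unless the struct lock is temporarily lifted. A self-contained sketch with a toy config (not the repo's actual schema):

```python
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"policy": {"name": "tdmpc"}})
OmegaConf.set_struct(cfg, True)  # Hydra delivers configs in struct mode

# Direct assignment of an unknown key would raise here; open_dict lifts the
# lock for the duration of the block.
with open_dict(cfg):
    cfg.n_obs_steps = 1

assert cfg.n_obs_steps == 1
```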