From ea17f4ce501afe867c73954a86457f12a95fcf42 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 19 Mar 2024 16:02:09 +0000 Subject: [PATCH] backup wip --- lerobot/common/datasets/abstract.py | 6 +- lerobot/common/datasets/aloha.py | 6 +- lerobot/common/envs/factory.py | 55 +++++++++++++----- lerobot/common/policies/abstract.py | 2 +- lerobot/configs/default.yaml | 8 ++- lerobot/configs/policy/diffusion.yaml | 4 +- lerobot/scripts/eval.py | 11 +--- lerobot/scripts/train.py | 9 --- .../data/aloha_sim_insertion_human/stats.pth | Bin 4434 -> 4306 bytes tests/data/pusht/stats.pth | Bin 4306 -> 4242 bytes tests/test_policies.py | 16 ++++- 11 files changed, 71 insertions(+), 46 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 34b33c2e6..5db97497a 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -49,9 +49,9 @@ def __init__( @property def stats_patterns(self) -> dict: return { - ("observation", "state"): "b c -> 1 c", - ("observation", "image"): "b c h w -> 1 c 1 1", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("observation", "image"): "b c h w -> c", + ("action",): "b c -> c", } @property diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 52a5676ee..0637f8a37 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -113,11 +113,11 @@ def __init__( @property def stats_patterns(self) -> dict: d = { - ("observation", "state"): "b c -> 1 c", - ("action",): "b c -> 1 c", + ("observation", "state"): "b c -> c", + ("action",): "b c -> c", } for cam in CAMERAS[self.dataset_id]: - d[("observation", "image", cam)] = "b c h w -> 1 c 1 1" + d[("observation", "image", cam)] = "b c h w -> c" return d @property diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index d6b294ebe..de86b3ad8 100644 --- a/lerobot/common/envs/factory.py +++ 
b/lerobot/common/envs/factory.py @@ -1,17 +1,31 @@ from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv -def make_env(cfg, seed=None, transform=None): +def make_env(cfg, transform=None): """ Provide seed to override the seed in the cfg (useful for batched environments). """ + # assert cfg.rollout_batch_size == 1, \ + # """ + # For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not + # correctly handle terminated environments. If you really want to use a larger batch size, read on... + + # When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the + # first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first + # environment lose the opportunity to reach the goal. A possible workaround is to comment out `if any_done: break` + # in `EnvBase._rollout_stop_early`. One potential downside is that the environments' `step` function will continue + # to be called and the outputs will continue to be added to the rollout. + + # When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done. 
+ # """ + kwargs = { "frame_skip": cfg.env.action_repeat, "from_pixels": cfg.env.from_pixels, "pixels_only": cfg.env.pixels_only, "image_size": cfg.env.image_size, "num_prev_obs": cfg.n_obs_steps - 1, - "seed": seed if seed is not None else cfg.seed, + "seed": cfg.seed, } if cfg.env.name == "simxarm": @@ -33,22 +47,33 @@ def make_env(cfg, seed=None, transform=None): else: raise ValueError(cfg.env.name) - env = clsfunc(**kwargs) + def _make_env(seed): + nonlocal kwargs + kwargs["seed"] = seed + env = clsfunc(**kwargs) + + # limit rollout to max_steps + env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) - # limit rollout to max_steps - env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length)) + if transform is not None: + # useful to add normalization + if isinstance(transform, Compose): + for tf in transform: + env.append_transform(tf.clone()) + elif isinstance(transform, Transform): + env.append_transform(transform.clone()) + else: + raise NotImplementedError() - if transform is not None: - # useful to add normalization - if isinstance(transform, Compose): - for tf in transform: - env.append_transform(tf.clone()) - elif isinstance(transform, Transform): - env.append_transform(transform.clone()) - else: - raise NotImplementedError() + return env - return env + # return SerialEnv( + # cfg.rollout_batch_size, + # create_env_fn=_make_env, + # create_env_kwargs={ + # "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) + # }, + # ) # def make_env(env_name, frame_skip, device, is_test=False): diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 9c652c0a3..ca2d8570a 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -30,7 +30,7 @@ def select_action(self, observation) -> Tensor: Should return a (batch_size, n_action_steps, *) tensor of actions. 
""" - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> Tensor: """Inference step that makes multi-step policies compatible with their single-step environments. WARNING: In general, this should not be overriden. diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 5cc8acd24..27b75c88d 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -11,14 +11,16 @@ hydra: seed: 1337 # batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index -rollout_batch_size: 10 +# NOTE: batch sizes greater than 1 are not yet supported! This is just a placeholder for future support. See +# `lerobot.common.envs.factory.make_env` for more information. +rollout_batch_size: 1 device: cuda # cpu prefetch: 4 eval_freq: ??? save_freq: ??? eval_episodes: ??? save_video: false -save_model: false +save_model: true save_buffer: false train_steps: ??? fps: ??? @@ -31,7 +33,7 @@ env: ??? policy: ??? wandb: - enable: true + enable: false # Set to true to disable saving an artifact despite save_model == True disable_artifact: false project: lerobot diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 0dae5056d..ce8acbd47 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0 obs_as_global_cond: True eval_episodes: 1 -eval_freq: 10000 -save_freq: 100000 +eval_freq: 5000 +save_freq: 5000 log_freq: 250 offline_steps: 1344000 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7cfb796af..c0199c0c4 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -9,7 +9,7 @@ import torch import tqdm from tensordict.nn import TensorDictModule -from torchrl.envs import EnvBase, SerialEnv +from torchrl.envs import EnvBase from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer @@ -131,14 +131,7 @@ def 
eval(cfg: dict, out_dir=None): offline_buffer = make_offline_buffer(cfg) logging.info("make_env") - env = SerialEnv( - cfg.rollout_batch_size, - create_env_fn=make_env, - create_env_kwargs=[ - {"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform} - for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) + env = make_env(cfg, transform=offline_buffer.transform) if cfg.policy.pretrained_model_path: policy = make_policy(cfg) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 2c7bb5751..5ecd616d4 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -7,7 +7,6 @@ from tensordict.nn import TensorDictModule from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers import PrioritizedSliceSampler -from torchrl.envs import SerialEnv from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env @@ -149,14 +148,6 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_env") env = make_env(cfg, transform=offline_buffer.transform) - env = SerialEnv( - cfg.rollout_batch_size, - create_env_fn=make_env, - create_env_kwargs=[ - {"cfg": cfg, "seed": s, "transform": offline_buffer.transform} - for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size) - ], - ) logging.info("make_policy") policy = make_policy(cfg) diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth index 869d26cd0528611dc3ffb548c5b35d0cff50f6ef..f909ed075ce48cf7f677cc82a8859615b15924c1 100644 GIT binary patch delta 1186 zcmY*Ye@qis9KS-LKj_`HYi(7i45zlvA%*dy?cMbbNSsp3Dw82oXZW?aMF-e+h8Uu3 zYfxE`U?2@-xAa8-g;UG#GP`}1-Po0p?D zO^S)c44BVf;rC@!Hu!654^-D2*i@j!(C-OtT=+s#_`_{p;pdN=*d#{FPg7qo&6Jui z?QD2O)-Xni_GWes3EUgp*;1b}Tq_}sbnk1d7#rDJViN}HPO?=7Gjru5C2PHgxz~v-&ZLoAi z=A3HAaxWGJky~kh79&o{aoiRay*K$;p>gvm78BbyK2jQgGK5{gH%-`tv)4|tyB`{v 
ze7l<6w_fHNy~i;ivfGoJnE^P7?Sp{dE~hAWJa#qtKVOUVTTK0_29m*4nD~Zs;_d3p`P&Lb#CEm`eBwy zC-^ekB4Y_@cx7LytEsXp{Q6z5@T0GRHCuY=u`-40Q1>M!lBtpG!nL|uG@8V54Qr>V zWsVT7d3A;!Khi>_lpke+-A&Z6>pp#hvr&_lZu*a`Y5LZg>&_#16Z5!jmRi-Nr{PF% z=>|NA)BKQ}muJtl=RTj6?Qmq~WX3oEhZ20p3mUIceRmP$&e6?IEz!+|f zYfnH?e3e|Pr9Q%hj6^C9HDMRMf{r64VMcohom_fk2eA%yibQ&>7;Zu>1cAO61CsH5 z#9BmK{8 delta 1279 zcmY*Y3rt&87{0f(P-t&UEw&WOIyPY417VCuk$cX$YZY*HSqqGyVLEUNy0AE)F54t4 z;ZgQ5!Z?{5An#2>z=dh%vK4OYjES%WrHrA5x#2WRm=ea27&XTGm=11oZqE0A_k90< zzW=|0*uc7>1Ehkqt2`AtyUI>_a0R(j#3m7cMW%{)gMeEkPhgTrXGsQhW;ph#VOXR~q$df+nh5C9bNZ(xa~ZZ&&WyUzu&S5O_kCgzw2hagD5o zIecIs>(aM0t8J!{G|x1AseKkkJlEi*=GUO-e;yy1{;rUXPDwz;b=mN;--wPZMzYJN zZ$pPYM6Q{kdN32rFq*a9^shNvKx|tiSlir2%je@kL-sBDkJtaAiKubf@KzNFLuFvH zbDm!AIt=WKH2AS674RbU>wlTVeoiESdtvwJAXxCUp~CA;u>YbPm3=w`_ZVByV*CZz zTAc^IG41HwCn?A|*n(`%=5xLKx=a$C$NQT|QfFJG^?gkuBO7ja~@M!olJb$fLLbJ0IsmZD%`5@$W=K*DzY@AyMG7lR_mM{s}pY89BpD zJUqiT-FU|r6W#$!6Flsw{%`hS{}fAbPwu8AHtL6JUx(srFgEUB@7wRNk6XD)Zp04r zm3%*zQIGv{#^5<0Z#9p`S-ZUnR-U-x?V3??0!pk_#^ChO9R_n}Gi)?q_>sOv=*cVE z-y@!W@p-62TW%u=w~aro^>qzD<V!-|&YUjHC>Ta5>7$9_c%YlsgH0QI=_2`h zkkdCscf7F`RBOiRX2l_p@l6@X&t0Y`ZdQT0+ca?NDuwP_$h;`*S4>53O;1TnPRqKS#WBv*;`;ihQ*_GV6(Uq(K`|4$2chVRl?C^6u$2Z!*x1$;R? 
II#`VU3rHHKMF0Q* diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth index 037e02f0fbe9a84fb24d2119174addd4a5a913dd..8846b8f65ae9b52dc74d369c239d64f42ff214fb 100644 GIT binary patch delta 828 zcmZuvO=uHQ5PrMeWH-s~<_VJ8Ce5ZGSgkSD4@Dbk3JTJ68_`<9 zSPVV*7o9_kUM<*Ku~BKUqJp4aRNCMnhZZV$6;!M@!JU`VdT?Oa@0htC4ObgB#=s zlDWDCoHc;h5?T?TMLaF+L7YeYrG`flH#l0{^ADdrni2?^5`fBqdM^o*;d~}P96mLY z&kkDsSu2@nB(Usggo2|$zZw(frB~x-IweYa=4v?tC8wpn zFiByJ>3+(;DnLa`&6&NHK_UJ`(^uP$4X+C-9kqIF(c;zP8@9Jt;ne6LK=2|;vX5Yqe#qFcY*p_16 z^`U+&K}f|;$|l?|K>44(yNhb&NppM04c0yGDz^jm2Tv3ApS_X%`*j}{T z=nqnUMS^uT%Jo8n0mNMt>%s2}jc@j>i*9cqBW}8(`bq vuPD$XaeD@gZYz~pj>^3v^7(Yn1(l_ zH!~13mK54+Kv+QbS15k>G@f=;Vn!zD&NXlOOOXvVgz=< z+GH7Cc@}U;0l7>PY?H%z<)lHu!~t|12=jtm2?sMKPvBLM0fiZMc?tH(XMhHR!w+Qs zjLB?#@+{ym2Aj_@*^y6<1r%hH1^GeJRX}M__<*=fGMtmQ@>wv>nfwkY3JMj7DA#07 zemNF!H~?xx qlm%fKkTN(B;$fKl9vF270tSLRK=A+s0p4uvAi7|3p@1|SL=*sgDe|%a diff --git a/tests/test_policies.py b/tests/test_policies.py index 92324485d..ee5abdb79 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,5 @@ +from omegaconf import open_dict import pytest from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -7,7 +8,8 @@ from torchrl.envs import EnvBase from lerobot.common.policies.factory import make_policy - +from lerobot.common.envs.factory import make_env +from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.policies.abstract import AbstractPolicy from .utils import DEVICE, init_config @@ -30,7 +32,19 @@ def test_factory(env_name, policy_name): f"device={DEVICE}", ] ) + # Check that we can make the policy object. policy = make_policy(cfg) + # Check that we run select_action and get the appropriate output. + if env_name == "simxarm": + # TODO(rcadene): Not implemented + return + if policy_name == "tdmpc": + # TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this. 
+ with open_dict(cfg): + cfg['n_obs_steps'] = 1 + offline_buffer = make_offline_buffer(cfg) + env = make_env(cfg, transform=offline_buffer.transform) + policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE)) def test_abstract_policy_forward():