Skip to content

Commit

Permalink
backup wip
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare committed Mar 19, 2024
1 parent 8834796 commit ea17f4c
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 46 deletions.
6 changes: 3 additions & 3 deletions lerobot/common/datasets/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c",
("action",): "b c -> c",
}

@property
Expand Down
6 changes: 3 additions & 3 deletions lerobot/common/datasets/aloha.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def __init__(
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
d[("observation", "image", cam)] = "b c h w -> c"
return d

@property
Expand Down
55 changes: 40 additions & 15 deletions lerobot/common/envs/factory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


def make_env(cfg, seed=None, transform=None):
def make_env(cfg, transform=None):
"""
Provide seed to override the seed in the cfg (useful for batched environments).
"""
# assert cfg.rollout_batch_size == 1, \
# """
# For the time being, rollout batch sizes of > 1 are not supported. This is because the SerialEnv rollout does not
# correctly handle terminated environments. If you really want to use a larger batch size, read on...

# When calling `EnvBase.rollout` with `break_when_any_done == True` all environments stop rolling out as soon as the
# first is terminated or truncated. This almost certainly results in incorrect success metrics, as all but the first
# environment get an opportunity to reach the goal. A possible work around is to comment out `if any_done: break`
# inf `EnvBase._rollout_stop_early`. One potential downside is that the environments `step` function will continue
# to be called and the outputs will continue to be added to the rollout.

# When calling `EnvBase.rollout` with `break_when_any_done == False` environments are reset when done.
# """

kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
"num_prev_obs": cfg.n_obs_steps - 1,
"seed": seed if seed is not None else cfg.seed,
"seed": cfg.seed,
}

if cfg.env.name == "simxarm":
Expand All @@ -33,22 +47,33 @@ def make_env(cfg, seed=None, transform=None):
else:
raise ValueError(cfg.env.name)

env = clsfunc(**kwargs)
def _make_env(seed):
nonlocal kwargs
kwargs["seed"] = seed
env = clsfunc(**kwargs)

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()

if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env

return env
# return SerialEnv(
# cfg.rollout_batch_size,
# create_env_fn=_make_env,
# create_env_kwargs={
# "seed": env_seed for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
# },
# )


# def make_env(env_name, frame_skip, device, is_test=False):
Expand Down
2 changes: 1 addition & 1 deletion lerobot/common/policies/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def select_action(self, observation) -> Tensor:
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""

def forward(self, *args, **kwargs):
def forward(self, *args, **kwargs) -> Tensor:
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overriden.
Expand Down
8 changes: 5 additions & 3 deletions lerobot/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ hydra:

seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
rollout_batch_size: 10
# NOTE: batch size of 1 is not yet supported! This is just a placeholder for future support. See
# `lerobot.common.envs.factory.make_env` for more information.
rollout_batch_size: 1
device: cuda # cpu
prefetch: 4
eval_freq: ???
save_freq: ???
eval_episodes: ???
save_video: false
save_model: false
save_model: true
save_buffer: false
train_steps: ???
fps: ???
Expand All @@ -31,7 +33,7 @@ env: ???
policy: ???

wandb:
enable: true
enable: false
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot
Expand Down
4 changes: 2 additions & 2 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ keypoint_visible_rate: 1.0
obs_as_global_cond: True

eval_episodes: 1
eval_freq: 10000
save_freq: 100000
eval_freq: 5000
save_freq: 5000
log_freq: 250

offline_steps: 1344000
Expand Down
11 changes: 2 additions & 9 deletions lerobot/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import tqdm
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase, SerialEnv
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase

from lerobot.common.datasets.factory import make_offline_buffer
Expand Down Expand Up @@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg)

logging.info("make_env")
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
env = make_env(cfg, transform=offline_buffer.transform)

if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)
Expand Down
9 changes: 0 additions & 9 deletions lerobot/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from tensordict.nn import TensorDictModule
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from torchrl.envs import SerialEnv

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
Expand Down Expand Up @@ -149,14 +148,6 @@ def train(cfg: dict, out_dir=None, job_name=None):

logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)

logging.info("make_policy")
policy = make_policy(cfg)
Expand Down
Binary file modified tests/data/aloha_sim_insertion_human/stats.pth
Binary file not shown.
Binary file modified tests/data/pusht/stats.pth
Binary file not shown.
16 changes: 15 additions & 1 deletion tests/test_policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

from omegaconf import open_dict
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
Expand All @@ -7,7 +8,8 @@
from torchrl.envs import EnvBase

from lerobot.common.policies.factory import make_policy

from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy

from .utils import DEVICE, init_config
Expand All @@ -30,7 +32,19 @@ def test_factory(env_name, policy_name):
f"device={DEVICE}",
]
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that we run select_action and get the appropriate output.
if env_name == "simxarm":
# TODO(rcadene): Not implemented
return
if policy_name == "tdmpc":
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
with open_dict(cfg):
cfg['n_obs_steps'] = 1
offline_buffer = make_offline_buffer(cfg)
env = make_env(cfg, transform=offline_buffer.transform)
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))


def test_abstract_policy_forward():
Expand Down

0 comments on commit ea17f4c

Please sign in to comment.