Skip to content

Commit

Permalink
Merge pull request #10 from Cadene/user/rcadene/2024_03_06_fix_tests
Browse files Browse the repository at this point in the history
Fix env tests
  • Loading branch information
aliberts authored Mar 8, 2024
2 parents c2c0ef9 + f1e2837 commit 4cc7e15
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 17 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ jobs:
#----------------------------------------------
# run tests
#----------------------------------------------
- name: Run tests
run: |
source .venv/bin/activate
pytest tests
- name: Test train pusht end-to-end
run: |
source .venv/bin/activate
Expand Down
6 changes: 3 additions & 3 deletions lerobot/common/datasets/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import torch
import torchrl
import tqdm
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer

from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip

Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
)

def _download_and_preproc(self):
raw_dir = self.data_dir / "raw"
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
Expand Down
1 change: 1 addition & 0 deletions lerobot/common/envs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}

if cfg.env.name == "simxarm":
Expand Down
18 changes: 10 additions & 8 deletions lerobot/common/envs/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import deque
from typing import Optional

import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
Expand All @@ -28,7 +29,7 @@ def __init__(
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_obs=0,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
Expand Down Expand Up @@ -65,7 +66,8 @@ def __init__(
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
self._prev_action_queue = deque(maxlen=self.num_prev_action)
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)

def render(self, mode="rgb_array", width=384, height=384):
if width != height:
Expand Down Expand Up @@ -133,7 +135,7 @@ def _step(self, tensordict: TensorDict):
sum_reward = 0

if action.ndim == 1:
action = action.repeat(self.frame_skip, 1)
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
Expand Down Expand Up @@ -172,7 +174,7 @@ def _make_spec(self):
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs, *image_shape)
image_shape = (self.num_prev_obs + 1, *image_shape)

obs["image"] = BoundedTensorSpec(
low=0,
Expand All @@ -184,24 +186,24 @@ def _make_spec(self):
if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
state_shape = (self.num_prev_obs + 1, *state_shape)

obs["state"] = BoundedTensorSpec(
low=0,
high=512,
shape=self._env.observation_space["agent_pos"].shape,
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs, *state_shape)
state_shape = (self.num_prev_obs + 1, *state_shape)

obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=self._env.observation_space["observation"].shape,
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
Expand Down
16 changes: 11 additions & 5 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@


@pytest.mark.parametrize(
"env_name",
"env_name,dataset_id",
[
"simxarm",
"pusht",
# TODO(rcadene): simxarm is depreciated for now
# ("simxarm", "lift"),
("pusht", "pusht"),
# TODO(aliberts): add aloha when dataset is available on hub
# ("aloha", "sim_insertion_human"),
# ("aloha", "sim_insertion_scripted"),
# ("aloha", "sim_transfer_cube_human"),
# ("aloha", "sim_transfer_cube_scripted"),
],
)
def test_factory(env_name):
cfg = init_config(overrides=[f"env={env_name}"])
def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}"])
offline_buffer = make_offline_buffer(cfg)
3 changes: 2 additions & 1 deletion tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def simple_rollout(steps=100):
print("data from rollout:", simple_rollout(100))


@pytest.mark.skip(reason="Simxarm is deprecated")
@pytest.mark.parametrize(
"task,from_pixels,pixels_only",
[
Expand Down Expand Up @@ -80,7 +81,7 @@ def test_pusht(from_pixels, pixels_only):
@pytest.mark.parametrize(
"env_name",
[
"simxarm",
# "simxarm",
"pusht",
],
)
Expand Down

0 comments on commit 4cc7e15

Please sign in to comment.