Skip to content

Commit

Permalink
Merge pull request #12 from Cadene/user/aliberts/2024_03_08_test_data
Browse files Browse the repository at this point in the history
Add pusht test artifact
  • Loading branch information
aliberts authored Mar 9, 2024
2 parents 7dbdbb0 + 450e32e commit fa7a947
Show file tree
Hide file tree
Showing 21 changed files with 109 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.memmap filter=lfs diff=lfs merge=lfs -text
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
runs-on: ubuntu-latest
env:
POETRY_VERSION: 1.8.1
DATA_DIR: tests/data
steps:
#----------------------------------------------
# check-out repo and set-up python
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
!tests/data
htmlcov/
.tox/
.nox/
Expand Down
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,31 @@ pre-commit run -a
```

**Tests**

Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```
brew install git-lfs
git lfs install
```

On Ubuntu:
```
sudo apt-get install git-lfs
git lfs install
```

Pull the artifacts if they are not already present in [tests/data](tests/data):
```
git lfs pull
```

When adding a new dataset, mock it with
```
python tests/scripts/mock_dataset.py --in-data-dir data/<dataset_id> --out-data-dir tests/data/<dataset_id>
```

Run tests
```
pytest -sx tests
DATA_DIR="tests/data" pytest -sx tests
```
5 changes: 5 additions & 0 deletions lerobot/common/datasets/pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _download_and_preproc(self):
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
Expand All @@ -142,6 +145,8 @@ def _download_and_preproc(self):
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51

num_frames = idx1 - idx0

Expand Down
3 changes: 3 additions & 0 deletions tests/data/pusht/action.memmap
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/pusht/episode.memmap
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/pusht/frame_id.memmap
Git LFS file not shown
1 change: 1 addition & 0 deletions tests/data/pusht/meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"action": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
3 changes: 3 additions & 0 deletions tests/data/pusht/next/done.memmap
Git LFS file not shown
1 change: 1 addition & 0 deletions tests/data/pusht/next/meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"reward": {"device": "cpu", "shape": [50, 1], "dtype": "torch.float32"}, "done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "success": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
3 changes: 3 additions & 0 deletions tests/data/pusht/next/observation/image.memmap
Git LFS file not shown
1 change: 1 addition & 0 deletions tests/data/pusht/next/observation/meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
3 changes: 3 additions & 0 deletions tests/data/pusht/next/observation/state.memmap
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/pusht/next/reward.memmap
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/pusht/next/success.memmap
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/pusht/observation/image.memmap
Git LFS file not shown
1 change: 1 addition & 0 deletions tests/data/pusht/observation/meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"image": {"device": "cpu", "shape": [50, 3, 96, 96], "dtype": "torch.float32"}, "state": {"device": "cpu", "shape": [50, 2], "dtype": "torch.float32"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
3 changes: 3 additions & 0 deletions tests/data/pusht/observation/state.memmap
Git LFS file not shown
Binary file added tests/data/pusht/stats.pth
Binary file not shown.
41 changes: 41 additions & 0 deletions tests/scripts/mock_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
usage: `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
"""

import argparse
import shutil

from tensordict import TensorDict
from pathlib import Path


def mock_dataset(in_data_dir: str, out_data_dir: str, num_frames: int = 50):
    """Create a small test artifact from a full dataset.

    Loads the memmap-backed TensorDict at ``in_data_dir``, writes a copy
    containing only the first ``num_frames`` frames to ``out_data_dir``,
    and copies the dataset statistics file alongside it.

    Args:
        in_data_dir: Path to the full input dataset (memmap TensorDict layout).
        out_data_dir: Path where the truncated test dataset is written.
        num_frames: Number of leading frames to keep in the artifact.
    """
    in_path = Path(in_data_dir)
    out_path = Path(out_data_dir)
    # make sure the destination exists before any file is written into it
    out_path.mkdir(parents=True, exist_ok=True)

    # load full dataset as a tensor dict
    in_td_data = TensorDict.load_memmap(in_path)

    # use 1 frame to know the specification of the dataset
    # and copy it over `num_frames` frames in the test artifact directory
    out_td_data = in_td_data[0].expand(num_frames).memmap_like(out_path)

    # copy the first `num_frames` frames so that we have real data
    out_td_data[:num_frames] = in_td_data[:num_frames].clone()

    # lock to make sure everything has been properly written
    out_td_data.lock_()

    # copy the full statistics of dataset since it's pretty small
    shutil.copy(in_path / "stats.pth", out_path / "stats.pth")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create a mock dataset for tests")

    # required=True gives a clear argparse error instead of a cryptic
    # downstream failure when a path is omitted (default would be None)
    parser.add_argument("--in-data-dir", type=str, required=True, help="Path to input data")
    parser.add_argument("--out-data-dir", type=str, required=True, help="Path to save the output data")
    parser.add_argument(
        "--num-frames",
        type=int,
        default=50,
        help="Number of frames to keep in the mock dataset (default: 50)",
    )

    args = parser.parse_args()

    mock_dataset(args.in_data_dir, args.out_data_dir, args.num_frames)

0 comments on commit fa7a947

Please sign in to comment.