From 19730b34127a6a39e31a25a8edb3bd711df9ed90 Mon Sep 17 00:00:00 2001 From: Cadene Date: Thu, 14 Mar 2024 16:59:37 +0000 Subject: [PATCH 1/5] Add pusht on hf dataset (WIP) --- lerobot/common/datasets/pusht.py | 5 +++++ poetry.lock | 2 +- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ae987ad1e..c72bc9c1b 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -8,6 +8,7 @@ import torch import torchrl import tqdm +from huggingface_hub import snapshot_download from tensordict import TensorDict from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage @@ -112,6 +113,10 @@ def __init__( ) def _download_and_preproc(self): + snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir) + return TensorStorage(TensorDict.load_memmap(self.data_dir)) + + def _download_and_preproc_obsolete(self): raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): diff --git a/poetry.lock b/poetry.lock index 59de0ec55..b2c8cc368 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3254,4 +3254,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8" +content-hash = "0794a87fd309dffa0ad2982b6902bed7f35ae9e2a82433420516798da04c7197" diff --git a/pyproject.toml b/pyproject.toml index 85af7f825..8542383e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" dm-control = "1.0.14" +huggingface-hub = "^0.21.4" [tool.poetry.group.dev.dependencies] From a311d387969be88aa0429f698c758e5c1f5486c6 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 00:30:11 +0000 Subject: [PATCH 2/5] Add aloha + improve readme --- README.md | 19 ++++++++ lerobot/common/datasets/abstract.py | 27 +++-------- lerobot/common/datasets/aloha.py | 8 +-- lerobot/common/datasets/factory.py | 6 +-- lerobot/common/datasets/pusht.py | 9 +--- lerobot/common/datasets/simxarm.py | 6 +-- poetry.lock | 75 ++++++++++++++++++++++++++++- pyproject.toml | 2 +- 8 files changed, 115 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index e30e1bd65..1051c8a6b 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,25 @@ Run tests DATA_DIR="tests/data" pytest -sx tests ``` +**Datasets** + +To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: +``` +huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential +``` + +Then you can upload it to the hub with: +``` +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET +``` + +For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used: +``` +HF_USER=cadene +DATASET=pusht +``` + + ## Acknowledgment - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) - Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 514fa0388..0e8fcc2b0 100644 --- a/lerobot/common/datasets/abstract.py +++ 
b/lerobot/common/datasets/abstract.py @@ -1,4 +1,3 @@ -import abc import logging from pathlib import Path from typing import Callable @@ -7,8 +6,8 @@ import torch import torchrl import tqdm +from huggingface_hub import snapshot_download from tensordict import TensorDict -from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id @@ -33,11 +32,8 @@ def __init__( ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = _get_root_dir(self.dataset_id) if root is None else root - self.root = Path(self.root) - self.data_dir = self.root / self.dataset_id - - storage = self._download_or_load_storage() + self.root = root + storage = self._download_or_load_dataset() super().__init__( storage=storage, @@ -98,19 +94,12 @@ def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict: torch.save(stats, stats_path) return stats - @abc.abstractmethod - def _download_and_preproc(self) -> torch.StorageBase: - raise NotImplementedError() - - def _download_or_load_storage(self): - if not self._is_downloaded(): - storage = self._download_and_preproc() + def _download_or_load_dataset(self) -> torch.StorageBase: + if self.root is None: + data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - storage = TensorStorage(TensorDict.load_memmap(self.data_dir)) - return storage - - def _is_downloaded(self) -> bool: - return self.data_dir.is_dir() + data_dir = Path(self.root) / self.dataset_id + return TensorStorage(TensorDict.load_memmap(data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 3b53fed1e..68a3aa82d 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -124,8 +124,8 @@ def stats_patterns(self) -> dict: def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - def _download_and_preproc(self): - raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" + def _download_and_preproc_obsolete(self, data_dir="data"): + raw_dir = Path(data_dir) / f"{self.dataset_id}_raw" if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) @@ -174,7 +174,9 @@ def _download_and_preproc(self): if ep_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir) + td_data = ( + ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") + ) td_data[idxtd : idxtd + len(ep_td)] = ep_td idxtd = idxtd + len(ep_td) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index fd284ae2b..876b6a50a 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,13 +1,13 @@ import logging import os -from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from lerobot.common.envs.transforms import NormalizeTransform, Prod -DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) +# used for unit tests +DATA_DIR = os.environ.get("DATA_DIR", None) def make_offline_buffer( @@ -77,9 +77,9 @@ def make_offline_buffer( offline_buffer = clsfunc( dataset_id=dataset_id, - root=DATA_DIR, sampler=sampler, batch_size=batch_size, + root=DATA_DIR, 
pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, ) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index c72bc9c1b..ed2ec4eed 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -8,7 +8,6 @@ import torch import torchrl import tqdm -from huggingface_hub import snapshot_download from tensordict import TensorDict from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage @@ -112,12 +111,8 @@ def __init__( transform=transform, ) - def _download_and_preproc(self): - snapshot_download(repo_id="cadene/pusht", local_dir=self.data_dir) - return TensorStorage(TensorDict.load_memmap(self.data_dir)) - def _download_and_preproc_obsolete(self): - raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" + raw_dir = Path(self.root) / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) @@ -213,7 +208,7 @@ def _download_and_preproc_obsolete(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir) + td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 1d56850ec..1d620c358 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -64,11 +64,11 @@ def __init__( transform=transform, ) - def _download_and_preproc(self): + def _download_and_preproc_obsolete(self): # TODO(rcadene): finish download download() - dataset_path = self.data_dir / "buffer.pkl" + dataset_path = Path(self.root) / "data" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -110,7 +110,7 @@ def _download_and_preproc(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(self.data_dir) + td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") td_data[idx0:idx1] = episode diff --git a/poetry.lock b/poetry.lock index b2c8cc368..a76858bd0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -838,6 +838,78 @@ files = [ [package.dependencies] numpy = ">=1.17.3" +[[package]] +name = "hf-transfer" +version = "0.1.6" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, + {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, + {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, + {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, + {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, + {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, + {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, + {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, + {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, + {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, + {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, + {file = 
"hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, + {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, + {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, + {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, +] + [[package]] name = "huggingface-hub" version = "0.21.4" @@ -852,6 +924,7 @@ files = [ [package.dependencies] filelock = "*" fsspec = ">=2023.5.0" +hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""} packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -3254,4 +3327,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "0794a87fd309dffa0ad2982b6902bed7f35ae9e2a82433420516798da04c7197" +content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9" diff --git a/pyproject.toml b/pyproject.toml index 8542383e7..2e818a440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" dm-control = "1.0.14" -huggingface-hub = "^0.21.4" +huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} [tool.poetry.group.dev.dependencies] From b10c9507d4f2df0984b77abcc2948f2cbbb31e9b Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 00:36:55 +0000 Subject: [PATCH 3/5] Small fix --- lerobot/common/datasets/abstract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 0e8fcc2b0..61a0d25bc 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -96,10 +96,10 @@ def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict: def _download_or_load_dataset(self) -> torch.StorageBase: if self.root is None: - data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") + self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - data_dir = Path(self.root) / self.dataset_id - return TensorStorage(TensorDict.load_memmap(data_dir)) + self.data_dir = Path(self.root) / self.dataset_id + return TensorStorage(TensorDict.load_memmap(self.data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( From 41521f7e962ba52d668ef87f21f3bf306b9db2f6 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 10:56:46 +0000 Subject: [PATCH 4/5] self.root is Path or None --- .github/poetry/cpu/poetry.lock | 4 ++-- .github/poetry/cpu/pyproject.toml | 1 + lerobot/common/datasets/abstract.py | 4 ++-- lerobot/common/datasets/aloha.py | 9 ++++----- lerobot/common/datasets/pusht.py | 5 +++-- lerobot/common/datasets/simxarm.py | 5 +++-- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index d224b6682..c07e34395 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
[[package]] name = "absl-py" @@ -3123,4 +3123,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "66c60543d2f59ac3d0e1fcda298ea14c0c60a8c6bcea73902f4f6aa3dd47661b" +content-hash = "4aa6a1e3f29560dd4a1c24d493ee1154089da4aa8d2190ad1f786c125ab2b735" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index 4880f61e2..fd7eb226a 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" dm = "^1.3" dm-control = "^1.0.16" +huggingface-hub = "^0.21.4" [tool.poetry.group.dev.dependencies] diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 61a0d25bc..3e0e2c320 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -32,7 +32,7 @@ def __init__( ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = root + self.root = root if root is None else Path(root) storage = self._download_or_load_dataset() super().__init__( @@ -98,7 +98,7 @@ def _download_or_load_dataset(self) -> torch.StorageBase: if self.root is None: self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - self.data_dir = Path(self.root) / self.dataset_id + self.data_dir = self.root / self.dataset_id return TensorStorage(TensorDict.load_memmap(self.data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 68a3aa82d..2ea4b831a 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -124,8 +124,9 @@ def stats_patterns(self) -> dict: def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - def _download_and_preproc_obsolete(self, data_dir="data"): - raw_dir = Path(data_dir) / f"{self.dataset_id}_raw" + def _download_and_preproc_obsolete(self): + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) @@ -174,9 +175,7 @@ def _download_and_preproc_obsolete(self, data_dir="data"): if ep_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ( - ep_td[0].expand(total_num_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") - ) + td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td idxtd = idxtd + len(ep_td) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ed2ec4eed..bac742d99 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -112,7 +112,8 @@ def __init__( ) def _download_and_preproc_obsolete(self): - raw_dir = Path(self.root) / f"{self.dataset_id}_raw" + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) @@ -208,7 +209,7 @@ def _download_and_preproc_obsolete(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") + td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td 
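The `_download_and_preproc_obsolete` hunks above (and the aloha/simxarm ones elsewhere in this series) all rely on the same trick: the first frame of the first episode is expanded to the total number of frames and memory-mapped to disk once, then every episode is written into its slice of that memmap. A minimal standalone sketch of the pattern, with made-up shapes and an illustrative output path (not part of the patch):

```
# Sketch of the "expand + memmap_like" episode-storage pattern (assumes the
# tensordict package is installed; shapes and output path are illustrative).
import torch
from tensordict import TensorDict

episodes = [
    TensorDict({"observation": torch.zeros(4, 3), "action": torch.zeros(4, 2)}, batch_size=[4]),
    TensorDict({"observation": torch.ones(2, 3), "action": torch.ones(2, 2)}, batch_size=[2]),
]
total_frames = sum(len(ep) for ep in episodes)

td_data = None
idx = 0
for ep_td in episodes:
    if td_data is None:
        # initialize the on-disk tensordict from the first frame's structure
        td_data = ep_td[0].expand(total_frames).memmap_like("data/pusht_example")
    # copy this episode into its slice of the memory-mapped storage
    td_data[idx : idx + len(ep_td)] = ep_td
    idx += len(ep_td)
```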
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 1d620c358..b4dd824f9 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -65,10 +65,11 @@ def __init__( ) def _download_and_preproc_obsolete(self): + assert self.root is not None # TODO(rcadene): finish download download() - dataset_path = Path(self.root) / "data" / "buffer.pkl" + dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -110,7 +111,7 @@ def _download_and_preproc_obsolete(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(Path(self.root) / f"{self.dataset_id}") + td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idx0:idx1] = episode From 5805a7ffb110281e63db0463c1fb9c9b57bca885 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 15 Mar 2024 12:44:52 +0000 Subject: [PATCH 5/5] small fix in type + comments --- lerobot/common/datasets/abstract.py | 4 ++-- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/datasets/factory.py | 7 +++++-- lerobot/common/datasets/pusht.py | 2 +- lerobot/common/datasets/simxarm.py | 2 +- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 3e0e2c320..e96133103 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -22,7 +22,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -32,7 +32,7 @@ def __init__( ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = root if root is None else Path(root) + self.root = root storage = self._download_or_load_dataset() super().__init__( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 2ea4b831a..52a5676ee 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -87,7 +87,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 876b6a50a..3f4772c40 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,13 +1,16 @@ import logging import os +from pathlib import Path import torch from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler from lerobot.common.envs.transforms import NormalizeTransform, Prod -# used for unit tests -DATA_DIR = os.environ.get("DATA_DIR", None) +# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and +# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` +# to load a subset of our datasets for faster continuous integration. 
+DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_offline_buffer( diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index bac742d99..f4f6d9aca 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -90,7 +90,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index b4dd824f9..7bcb03fbd 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -43,7 +43,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None,
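Taken together, the series converges on a single loading rule: when no local root is given (the default, `DATA_DIR` unset), the preprocessed memmap dataset is pulled from the hub repo `cadene/<dataset_id>` via `snapshot_download`; when a root is given (e.g. `DATA_DIR=tests/data` in CI), it is read from `<root>/<dataset_id>`. A self-contained sketch of that behaviour follows; the helper name `load_storage` is ours, not part of the patch, but the calls mirror `_download_or_load_dataset` and the `DATA_DIR` handling in `factory.py`:

```
# Sketch of the dataset-loading behaviour introduced by this series
# (helper name is hypothetical; repo id and layout follow the patches above).
import os
from pathlib import Path

from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import TensorStorage


def load_storage(dataset_id: str, root: Path | None = None) -> TensorStorage:
    if root is None:
        # download (or reuse) the hub snapshot under the huggingface cache
        data_dir = snapshot_download(repo_id=f"cadene/{dataset_id}", repo_type="dataset")
    else:
        # local layout used by the unit tests: <root>/<dataset_id>
        data_dir = Path(root) / dataset_id
    return TensorStorage(TensorDict.load_memmap(data_dir))


# mirroring the DATA_DIR convention from factory.py:
root = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
storage = load_storage("pusht", root=root)
```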