diff --git a/.github/poetry/cpu/poetry.lock b/.github/poetry/cpu/poetry.lock index d224b6682..c07e34395 100644 --- a/.github/poetry/cpu/poetry.lock +++ b/.github/poetry/cpu/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -3123,4 +3123,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "66c60543d2f59ac3d0e1fcda298ea14c0c60a8c6bcea73902f4f6aa3dd47661b" +content-hash = "4aa6a1e3f29560dd4a1c24d493ee1154089da4aa8d2190ad1f786c125ab2b735" diff --git a/.github/poetry/cpu/pyproject.toml b/.github/poetry/cpu/pyproject.toml index 4880f61e2..fd7eb226a 100644 --- a/.github/poetry/cpu/pyproject.toml +++ b/.github/poetry/cpu/pyproject.toml @@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"} h5py = "^3.10.0" dm = "^1.3" dm-control = "^1.0.16" +huggingface-hub = "^0.21.4" [tool.poetry.group.dev.dependencies] diff --git a/README.md b/README.md index e30e1bd65..1051c8a6b 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,25 @@ Run tests DATA_DIR="tests/data" pytest -sx tests ``` +**Datasets** + +To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access: +``` +huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential +``` + +Then you can upload it to the hub with: +``` +HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET +``` + +For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used: +``` +HF_USER=cadene +DATASET=pusht +``` + + ## Acknowledgment - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) - Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 514fa0388..e96133103 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -1,4 +1,3 @@ -import abc import logging from pathlib import Path from typing import Callable @@ -7,8 +6,8 @@ import torch import torchrl import tqdm +from huggingface_hub import snapshot_download from tensordict import TensorDict -from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id @@ -23,7 +22,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -33,11 +32,8 @@ def __init__( ): self.dataset_id = dataset_id self.shuffle = shuffle - self.root = _get_root_dir(self.dataset_id) if root is None else root - self.root = Path(self.root) - self.data_dir = self.root / self.dataset_id - - storage = self._download_or_load_storage() + self.root = root + storage = self._download_or_load_dataset() super().__init__( storage=storage, @@ -98,19 +94,12 @@ def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict: torch.save(stats, stats_path) return stats - @abc.abstractmethod - def _download_and_preproc(self) -> torch.StorageBase: - raise NotImplementedError() - - def _download_or_load_storage(self): - if not self._is_downloaded(): - storage = self._download_and_preproc() + def _download_or_load_dataset(self) -> torch.StorageBase: + if self.root is None: + self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset") else: - storage = TensorStorage(TensorDict.load_memmap(self.data_dir)) - return storage - - def _is_downloaded(self) -> bool: - return self.data_dir.is_dir() + self.data_dir = self.root / self.dataset_id + return TensorStorage(TensorDict.load_memmap(self.data_dir)) def _compute_stats(self, num_batch=100, batch_size=32): rb = TensorDictReplayBuffer( diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 3b53fed1e..52a5676ee 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -87,7 +87,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -124,8 +124,9 @@ def stats_patterns(self) -> dict: def image_keys(self) -> list: return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] - def _download_and_preproc(self): - raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" + def _download_and_preproc_obsolete(self): + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" if not raw_dir.is_dir(): download(raw_dir, self.dataset_id) @@ -174,7 +175,7 @@ def _download_and_preproc(self): if ep_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir) + td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td idxtd = idxtd + len(ep_td) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index fd284ae2b..3f4772c40 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -7,7 +7,10 @@ from lerobot.common.envs.transforms import NormalizeTransform, Prod -DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) +# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and +# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data` +# to load a subset of our datasets for faster continuous integration. +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None def make_offline_buffer( @@ -77,9 +80,9 @@ def make_offline_buffer( offline_buffer = clsfunc( dataset_id=dataset_id, - root=DATA_DIR, sampler=sampler, batch_size=batch_size, + root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, ) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index ae987ad1e..f4f6d9aca 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -90,7 +90,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -111,8 +111,9 @@ def __init__( transform=transform, ) - def _download_and_preproc(self): - raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" + def _download_and_preproc_obsolete(self): + assert self.root is not None + raw_dir = self.root / f"{self.dataset_id}_raw" zarr_path = (raw_dir / PUSHT_ZARR).resolve() if not zarr_path.is_dir(): raw_dir.mkdir(parents=True, exist_ok=True) @@ -208,7 +209,7 @@ def _download_and_preproc(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir) + td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idxtd : idxtd + len(ep_td)] = ep_td diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 1d56850ec..7bcb03fbd 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -43,7 +43,7 @@ def __init__( batch_size: int = None, *, shuffle: bool = True, - root: Path = None, + root: Path | None = None, pin_memory: bool = False, prefetch: int = None, sampler: SliceSampler = None, @@ -64,11 +64,12 @@ def __init__( transform=transform, ) - def _download_and_preproc(self): + def _download_and_preproc_obsolete(self): + assert self.root is not None # TODO(rcadene): finish download download() - dataset_path = self.data_dir / "buffer.pkl" + dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -110,7 +111,7 @@ def _download_and_preproc(self): if episode_id == 0: # hack to initialize tensordict data structure to store episodes - td_data = episode[0].expand(total_frames).memmap_like(self.data_dir) + td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}") td_data[idx0:idx1] = episode diff --git a/poetry.lock b/poetry.lock index 91b313765..5e84ebefa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -914,6 +914,78 @@ files = [ [package.dependencies] numpy = ">=1.17.3" +[[package]] +name = "hf-transfer" +version = "0.1.6" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"}, + {file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"}, + {file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"}, + {file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"}, + {file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"}, + {file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"}, + {file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"}, + {file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"}, + {file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"}, + {file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"}, + {file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"}, + {file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"}, + {file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"}, + {file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"}, + {file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"}, + {file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"}, + {file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"}, + {file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"}, + {file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"}, + {file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"}, + {file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"}, + {file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"}, + {file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"}, + {file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"}, + {file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"}, + {file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"}, + {file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"}, +] + [[package]] name = "huggingface-hub" version = "0.21.4" @@ -928,6 +1000,7 @@ files = [ [package.dependencies] filelock = "*" fsspec = ">=2023.5.0" +hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""} packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -1270,13 +1343,13 @@ source = ["Cython (>=3.0.7)"] [[package]] name = "markdown" -version = "3.5.2" +version = "3.6" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" files = [ - {file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"}, - {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, ] [package.extras] @@ -2726,13 +2799,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentry-sdk" -version = "1.41.0" +version = "1.42.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"}, - {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"}, + {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"}, + {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"}, ] [package.dependencies] @@ -2756,6 +2829,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -2871,18 +2945,18 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "69.1.1" +version = "69.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"}, - {file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"}, + {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, + {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] @@ -3040,7 +3114,7 @@ protobuf = ">=3.20" [[package]] name = "tensordict" -version = "0.4.0+551331d" +version = "0.4.0+f1c833e" description = "" optional = false python-versions = "*" @@ -3061,7 +3135,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures type = "git" url = "https://github.com/pytorch/tensordict" reference = "HEAD" -resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a" +resolved_reference = "f1c833ecf495aa61f3f76bf09f94dd708db496ec" [[package]] name = "termcolor" @@ -3437,20 +3511,20 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"] [[package]] name = "zipp" -version = "3.17.0" +version = "3.18.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, - {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, + {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, + {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d7c551c44380ac26e784390f077d902287c40582127db91f9a85ab5d7707ad59" +content-hash = "3bf6532037cfea563819989806d5cd171e33ceb077b0d6afddf54710cbbb3c74" diff --git a/pyproject.toml b/pyproject.toml index 1131da8ae..5cef3ef45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ h5py = "^3.10.0" robomimic = "0.2.0" timm = "^0.9.16" dm-control = "1.0.14" +huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} [tool.poetry.group.dev.dependencies]