
Commit

Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env
alexander-soare committed Mar 15, 2024
2 parents a222c88 + 9c88071 commit a45896d
Showing 10 changed files with 144 additions and 54 deletions.
4 changes: 2 additions & 2 deletions .github/poetry/cpu/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions .github/poetry/cpu/pyproject.toml
@@ -51,6 +51,7 @@ torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
dm = "^1.3"
dm-control = "^1.0.16"
huggingface-hub = "^0.21.4"


[tool.poetry.group.dev.dependencies]
19 changes: 19 additions & 0 deletions README.md
@@ -146,6 +146,25 @@ Run tests
DATA_DIR="tests/data" pytest -sx tests
```

**Datasets**

To add a PyTorch RL dataset to the hub, first log in using a write-access token generated from [Hugging Face settings](https://huggingface.co/settings/tokens):
```
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
```

Then you can upload it to the hub with:
```
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
```

For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
```
HF_USER=cadene
DATASET=pusht
```
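
As a quick way to confirm the upload, the dataset snapshot can be pulled back through `huggingface_hub` — a minimal sketch, assuming the `$HF_USER/$DATASET` repo id convention used above:
```
# Sketch: verify the uploaded dataset is reachable on the hub.
# Assumes the repo id convention $HF_USER/$DATASET from the commands above.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="cadene/pusht", repo_type="dataset")
print(local_dir)  # path to the cached snapshot under the local huggingface cache
```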


## Acknowledgment
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
29 changes: 9 additions & 20 deletions lerobot/common/datasets/abstract.py
@@ -1,4 +1,3 @@
import abc
import logging
from pathlib import Path
from typing import Callable
@@ -7,8 +6,8 @@
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -23,7 +22,7 @@ def __init__(
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -33,11 +32,8 @@
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root
self.root = Path(self.root)
self.data_dir = self.root / self.dataset_id

storage = self._download_or_load_storage()
self.root = root
storage = self._download_or_load_dataset()

super().__init__(
storage=storage,
@@ -98,19 +94,12 @@ def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
torch.save(stats, stats_path)
return stats

@abc.abstractmethod
def _download_and_preproc(self) -> torch.StorageBase:
raise NotImplementedError()

def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
return storage

def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir))

def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
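
For reference, the new loading path boils down to the following standalone sketch (a hypothetical helper mirroring `_download_or_load_dataset`; the `cadene/<dataset_id>` repo id follows the diff above):
```
# Simplified sketch of the new loading logic (hypothetical standalone helper).
from pathlib import Path

from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.storages import TensorStorage


def load_dataset_storage(dataset_id: str, root: Path | None = None) -> TensorStorage:
    if root is None:
        # No local root: fetch (or reuse) the dataset snapshot from the hub cache.
        data_dir = Path(snapshot_download(repo_id=f"cadene/{dataset_id}", repo_type="dataset"))
    else:
        # Local root given (e.g. DATA_DIR=tests/data): load from disk instead.
        data_dir = root / dataset_id
    # Datasets are stored as memory-mapped TensorDicts; wrap them in a TensorStorage.
    return TensorStorage(TensorDict.load_memmap(data_dir))
```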
9 changes: 5 additions & 4 deletions lerobot/common/datasets/aloha.py
@@ -87,7 +87,7 @@ def __init__(
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -124,8 +124,9 @@ def stats_patterns(self) -> dict:
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)

@@ -174,7 +175,7 @@ def _download_and_preproc(self):

if ep_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")

td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td)
7 changes: 5 additions & 2 deletions lerobot/common/datasets/factory.py
@@ -7,7 +7,10 @@

from lerobot.common.envs.transforms import NormalizeTransform, Prod

DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
# DATA_DIR specifies the location from which datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


def make_offline_buffer(
@@ -77,9 +80,9 @@ def make_offline_buffer(

offline_buffer = clsfunc(
dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler,
batch_size=batch_size,
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
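
As a usage sketch of the new `DATA_DIR` default (illustrative only; the helper name is not from the diff): leaving `DATA_DIR` unset makes the factory pass `root=None`, so datasets resolve to the hub cache, while setting it (as the unit tests do) points loading at a local directory.
```
import os
from pathlib import Path


def resolve_data_dir() -> Path | None:
    # None -> datasets come from the huggingface hub cache;
    # a set DATA_DIR (e.g. "tests/data" in CI) -> local directory.
    return Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


print(resolve_data_dir())              # None by default
os.environ["DATA_DIR"] = "tests/data"  # as done for the unit tests
print(resolve_data_dir())              # tests/data
```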
9 changes: 5 additions & 4 deletions lerobot/common/datasets/pusht.py
@@ -90,7 +90,7 @@ def __init__(
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -111,8 +111,9 @@ def __init__(
transform=transform,
)

def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ def _download_and_preproc(self):

if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")

td_data[idxtd : idxtd + len(ep_td)] = ep_td

9 changes: 5 additions & 4 deletions lerobot/common/datasets/simxarm.py
@@ -43,7 +43,7 @@ def __init__(
batch_size: int = None,
*,
shuffle: bool = True,
root: Path = None,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
@@ -64,11 +64,12 @@ def __init__(
transform=transform,
)

def _download_and_preproc(self):
def _download_and_preproc_obsolete(self):
assert self.root is not None
# TODO(rcadene): finish download
download()

dataset_path = self.data_dir / "buffer.pkl"
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -110,7 +111,7 @@ def _download_and_preproc(self):

if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")

td_data[idx0:idx1] = episode
