Merge pull request #11 from Cadene/user/rcadene/2024_03_06_aloha_env
[WIP] Add Aloha env and ACT policy
Cadene authored Mar 12, 2024
2 parents 060bac7 + 998dd2b commit 8c56770
Showing 119 changed files with 3,699 additions and 307 deletions.
33 changes: 32 additions & 1 deletion .github/workflows/test.yml
@@ -18,6 +18,12 @@ jobs:
env:
POETRY_VERSION: 1.8.1
DATA_DIR: tests/data
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
PYOPENGL_PLATFORM: egl
MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python
@@ -26,11 +32,13 @@ jobs:
uses: actions/checkout@v4
with:
lfs: true

- name: Set up python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: '3.10'

#----------------------------------------------
# install & configure poetry
#----------------------------------------------
@@ -40,13 +48,15 @@
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache

- name: Install Poetry
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
uses: snok/install-poetry@v1
with:
version: ${{ env.POETRY_VERSION }}
virtualenvs-create: true
installer-parallel: true

- name: Save cached Poetry installation
if: |
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
@@ -56,8 +66,10 @@
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache

- name: Configure Poetry
run: poetry config virtualenvs.in-project true

#----------------------------------------------
# install dependencies
#----------------------------------------------
@@ -67,9 +79,21 @@
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}

- name: Info
run: |
sudo du -sh /tmp
sudo df -h
- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
run: |
mkdir -p ~/tmp
echo $TMPDIR
echo $TEMP
echo $TMP
poetry install --no-interaction --no-root
- name: Save cached venv
if: |
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
@@ -79,18 +103,24 @@
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}

- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

#----------------------------------------------
# install project
#----------------------------------------------
- name: Install project
run: poetry install --no-interaction

#----------------------------------------------
# run tests
#----------------------------------------------
- name: Run tests
run: |
source .venv/bin/activate
pytest tests
- name: Test train pusht end-to-end
run: |
source .venv/bin/activate
@@ -104,6 +134,7 @@ jobs:
save_model=true \
save_freq=1 \
hydra.run.dir=tests/outputs/
- name: Test eval pusht end-to-end
run: |
source .venv/bin/activate
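For reference, `MUJOCO_GL=egl` and `PYOPENGL_PLATFORM=egl` (together with the `libegl1-mesa-dev` install step) let MuJoCo render offscreen on the headless CI runner. A minimal local check, assuming a recent `mujoco` Python package is installed (the inline XML model is purely illustrative):

```python
import os

# Match the CI environment before importing mujoco.
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import mujoco

# A trivial scene: one sphere in an empty world.
model = mujoco.MjModel.from_xml_string(
    "<mujoco><worldbody><geom type='sphere' size='0.1'/></worldbody></mujoco>"
)
data = mujoco.MjData(model)

renderer = mujoco.Renderer(model, height=480, width=640)
mujoco.mj_forward(model, data)
renderer.update_scene(data)
pixels = renderer.render()  # uint8 array of shape (480, 640, 3)
print(pixels.shape)
```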
17 changes: 4 additions & 13 deletions README.md
@@ -59,19 +59,10 @@ env=pusht

## TODO

- [x] priority update doesn't match FOWM or original paper
- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
- [ ] prefetch replay buffer to speedup training
- [ ] parallelize env to speed up eval
- [ ] clean checkpointing / loading
- [ ] clean logging
- [ ] clean config
- [ ] clean hyperparameter tuning
- [ ] add pusht
- [ ] add aloha
- [ ] add act
- [ ] add diffusion
- [ ] add aloha 2
If you are not sure how to contribute or want to know which features we are working on next, take a look at this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)

Ask [Remi Cadene](mailto:re.cadene@gmail.com) for access if needed.


## Profile

5 changes: 4 additions & 1 deletion lerobot/common/datasets/abstract.py
@@ -81,7 +81,10 @@ def transform(self):
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform

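A hypothetical usage sketch of the fixed `set_transform` (buffer construction elided; `Prod`, `NormalizeTransform`, `img_keys`, `in_keys`, and `stats` are the names used in `factory.py` below): passing a list now yields a single `Compose` whose length torchrl can query.

```python
from torchrl.envs.transforms import Compose

# `offline_buffer` is an AbstractExperienceReplay subclass instance.
transforms = [
    Prod(in_keys=img_keys, prod=1 / 255),
    NormalizeTransform(stats, in_keys, mode="min_max"),
]
offline_buffer.set_transform(transforms)  # stored internally as Compose(*transforms)
assert isinstance(offline_buffer.transform, Compose)
assert len(offline_buffer.transform) == 2  # len() now works downstream in torchrl
```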
9 changes: 3 additions & 6 deletions lerobot/common/datasets/aloha.py
@@ -73,11 +73,11 @@ def download(data_dir, dataset_id):

data_dir.mkdir(parents=True, exist_ok=True)

gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))

# because of gdown's 50-file limit per folder download, episodes 48 and 49 are missing and are fetched separately
gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)


class AlohaExperienceReplay(AbstractExperienceReplay):
@@ -124,9 +124,6 @@ def stats_patterns(self) -> dict:
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

# def _is_downloaded(self) -> bool:
# return False

def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
if not raw_dir.is_dir():
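The `str()` casts above matter because gdown expects plain string output paths; a `pathlib.Path` can trip up its internal string handling. A brief sketch (file id hypothetical):

```python
from pathlib import Path

import gdown

out = Path("data/aloha_sim_insertion_human") / "episode_48.hdf5"
# Cast Path -> str before handing it to gdown:
gdown.download("https://drive.google.com/uc?id=<hypothetical-id>", output=str(out), fuzzy=True)
```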
28 changes: 20 additions & 8 deletions lerobot/common/datasets/factory.py
@@ -5,7 +5,7 @@
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler

from lerobot.common.envs.transforms import NormalizeTransform
from lerobot.common.envs.transforms import NormalizeTransform, Prod

DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))

@@ -84,6 +84,16 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)

if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys

transforms = [Prod(in_keys=img_keys, prod=1 / 255)]

if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
@@ -92,11 +102,10 @@
in_keys = [("observation", "state"), ("action")]

if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
in_keys += img_keys
# TODO(rcadene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))

if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@@ -106,8 +115,11 @@
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))

offline_buffer.set_transform(transforms)

if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)
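As a sanity check of the pipeline assembled above, one could sample a batch and confirm the image scaling ran (the camera key and buffer construction are assumed for illustration):

```python
import torch

batch = offline_buffer.sample(4)
img = batch["observation", "image", "top"]  # hypothetical camera key
# Prod(1/255) converts raw uint8 frames to floats in [0, 1];
# NormalizeTransform (when normalize=True) then shifts/scales them using `stats`.
assert img.dtype == torch.float32
```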
21 changes: 19 additions & 2 deletions lerobot/common/datasets/simxarm.py
@@ -1,4 +1,5 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable

@@ -15,6 +16,22 @@
from lerobot.common.datasets.abstract import AbstractExperienceReplay


def download():
raise NotImplementedError()  # guard kept until the download flow is finished (see TODO in _download_and_preproc)
import gdown

url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()


class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
@@ -48,8 +65,8 @@ def __init__(
)

def _download_and_preproc(self):
# download
# TODO(rcadene)
# TODO(rcadene): finish download
download()

dataset_path = self.data_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
80 changes: 80 additions & 0 deletions lerobot/common/envs/abstract.py
@@ -0,0 +1,80 @@
import abc
from collections import deque
from typing import Optional

from tensordict import TensorDict
from torchrl.envs import EnvBase


class AbstractEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []

if pixels_only:
assert from_pixels
if from_pixels:
assert image_size

self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)

if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)

def register_rendering_hook(self, func):
self._rendering_hooks.append(func)

def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)

def reset_rendering_hooks(self):
self._rendering_hooks = []

@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()

@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()

@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()

@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()

@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()
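The rendering hooks above let callers collect frames without subclass-specific plumbing. A hypothetical usage sketch (`env` is any concrete `AbstractEnv` subclass, e.g. the Aloha or PushT envs added in this PR):

```python
frames = []

def grab_frame(env):
    # Each hook receives the env and can call its render() method.
    frames.append(env.render(mode="rgb_array", width=640, height=480))

env.register_rendering_hook(grab_frame)

# ... run a rollout; subclasses invoke call_rendering_hooks() at each render point ...

env.reset_rendering_hooks()  # detach all hooks when done
```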
(Diff truncated: the remaining changed files are not shown.)
