Add Aloha env and ACT policy #11

Merged
merged 7 commits into from
Mar 12, 2024
33 changes: 32 additions & 1 deletion .github/workflows/test.yml
@@ -18,6 +18,12 @@ jobs:
env:
POETRY_VERSION: 1.8.1
DATA_DIR: tests/data
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
PYOPENGL_PLATFORM: egl
MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python
@@ -26,11 +32,13 @@ jobs:
uses: actions/checkout@v4
with:
lfs: true

- name: Set up python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: '3.10'

#----------------------------------------------
# install & configure poetry
#----------------------------------------------
@@ -40,13 +48,15 @@ jobs:
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache

- name: Install Poetry
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
uses: snok/install-poetry@v1
with:
version: ${{ env.POETRY_VERSION }}
virtualenvs-create: true
installer-parallel: true

- name: Save cached Poetry installation
if: |
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
@@ -56,8 +66,10 @@ jobs:
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache

- name: Configure Poetry
run: poetry config virtualenvs.in-project true

#----------------------------------------------
# install dependencies
#----------------------------------------------
@@ -67,9 +79,21 @@ jobs:
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}

- name: Info
run: |
sudo du -sh /tmp
sudo df -h

- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
run: |
mkdir ~/tmp
echo $TMPDIR
echo $TEMP
echo $TMP
poetry install --no-interaction --no-root

- name: Save cached venv
if: |
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
@@ -79,18 +103,24 @@ jobs:
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}

- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev

#----------------------------------------------
# install project
#----------------------------------------------
- name: Install project
run: poetry install --no-interaction

#----------------------------------------------
# run tests
#----------------------------------------------
- name: Run tests
run: |
source .venv/bin/activate
pytest tests

- name: Test train pusht end-to-end
run: |
source .venv/bin/activate
@@ -104,6 +134,7 @@ jobs:
save_model=true \
save_freq=1 \
hydra.run.dir=tests/outputs/

- name: Test eval pusht end-to-end
run: |
source .venv/bin/activate
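The workflow's `env` block sets `TMPDIR`, `TEMP`, and `TMP` together. One plausible reason is that Python's standard `tempfile` module consults these three variables, in that order, and uses the first one that names a usable directory. The sketch below demonstrates that lookup; the helper name `resolve_tempdir` is hypothetical and nothing in it comes from the PR.

```python
import os
import tempfile

TEMP_VARS = ("TMPDIR", "TEMP", "TMP")


def resolve_tempdir(env_overrides):
    """Return the directory tempfile selects when only env_overrides are set."""
    saved_env = {name: os.environ.get(name) for name in TEMP_VARS}
    saved_cache = tempfile.tempdir
    try:
        for name in TEMP_VARS:
            os.environ.pop(name, None)
        os.environ.update(env_overrides)
        # gettempdir() caches its answer; clearing the cache forces a new lookup.
        tempfile.tempdir = None
        return tempfile.gettempdir()
    finally:
        # Restore the original environment and cache.
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        tempfile.tempdir = saved_cache


with tempfile.TemporaryDirectory() as scratch:
    custom = os.path.join(scratch, "tmp")
    os.mkdir(custom)
    picked = resolve_tempdir({"TMPDIR": custom})
    print(os.path.samefile(picked, custom))
```

`tempfile` does not expand `~` on its own, so whether a value such as `~/tmp` resolves to the home directory depends on how the runner passes it to the process. The `echo` lines added to the install step are one way to check that.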
17 changes: 4 additions & 13 deletions README.md
@@ -59,19 +59,10 @@ env=pusht

## TODO

- [x] priority update doesn't match FOWM or original paper
- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
- [ ] prefetch replay buffer to speedup training
- [ ] parallelize env to speed up eval
- [ ] clean checkpointing / loading
- [ ] clean logging
- [ ] clean config
- [ ] clean hyperparameter tuning
- [ ] add pusht
- [ ] add aloha
- [ ] add act
- [ ] add diffusion
- [ ] add aloha 2
If you are not sure how to contribute, or want to know which features we are working on next, see this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)

Ask [Remi Cadene](mailto:re.cadene@gmail.com) for access if needed.


## Profile

5 changes: 4 additions & 1 deletion lerobot/common/datasets/abstract.py
@@ -81,7 +81,10 @@ def transform(self):
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform
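The new branch in `set_transform` separates a list of transforms from a single transform. A minimal sketch of why that matters, assuming a variadic container: passing a list positionally stores it as one item, while unpacking it stores each transform. The `Compose` class below is a hypothetical stand-in written for this example, not torchrl's class, and `as_compose` only mirrors the branching shown in the diff.

```python
class Compose:
    """Hypothetical stand-in for a variadic transform container (not torchrl's Compose)."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def __len__(self):
        return len(self.transforms)

    def __call__(self, value):
        # Apply each transform in order.
        for transform in self.transforms:
            value = transform(value)
        return value


def as_compose(transform):
    """Mirror the branching in set_transform: unpack lists, wrap single callables."""
    if isinstance(transform, Compose):
        return transform
    if isinstance(transform, list):
        return Compose(*transform)
    return Compose(transform)


def scale(x):
    return x / 255


def shift(x):
    return x - 0.5


pipeline = as_compose([scale, shift])
print(len(pipeline))                 # → 2
print(pipeline(255))                 # → 0.5
print(len(Compose([scale, shift])))  # → 1, the list is stored as a single item
```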

9 changes: 3 additions & 6 deletions lerobot/common/datasets/aloha.py
@@ -73,11 +73,11 @@ def download(data_dir, dataset_id):

data_dir.mkdir(parents=True, exist_ok=True)

gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))

# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)


class AlohaExperienceReplay(AbstractExperienceReplay):
@@ -124,9 +124,6 @@ def stats_patterns(self) -> dict:
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

# def _is_downloaded(self) -> bool:
# return False

def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
if not raw_dir.is_dir():
28 changes: 20 additions & 8 deletions lerobot/common/datasets/factory.py
@@ -5,7 +5,7 @@
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler

from lerobot.common.envs.transforms import NormalizeTransform
from lerobot.common.envs.transforms import NormalizeTransform, Prod

DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))

@@ -84,6 +84,16 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)

if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys

transforms = [Prod(in_keys=img_keys, prod=1 / 255)]

if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
@@ -92,11 +102,10 @@ def make_offline_buffer(
in_keys = [("observation", "state"), ("action")]

if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
in_keys += img_keys
# TODO(rcadene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))

if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@@ -106,8 +115,11 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))

offline_buffer.set_transform(transforms)

if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)
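The factory change builds a two-stage pipeline: image values are first multiplied by `1 / 255`, then a normalization mode is chosen from the environment name. The sketch below reproduces that flow on plain Python lists. Several things are assumptions: the camera name `"top"` is illustrative, the helper names are hypothetical, and the `min_max` formula (rescaling to `[-1, 1]`) is a common convention that this diff does not confirm for `NormalizeTransform`.

```python
from statistics import mean, pstdev


def with_next_keys(image_keys):
    """Prefix each nested key with "next" and keep the originals, as in the tdmpc branch."""
    return [("next", *key) for key in image_keys] + list(image_keys)


def scale_pixels(values, prod=1 / 255):
    """Multiply raw pixel values by a constant, like Prod(prod=1 / 255)."""
    return [v * prod for v in values]


def pick_mode(env_name):
    """Same selection rule as the diff: mean_std for aloha, min_max otherwise."""
    return "mean_std" if env_name == "aloha" else "min_max"


def normalize(values, stats, mode):
    if mode == "mean_std":
        return [(v - stats["mean"]) / stats["std"] for v in values]
    if mode == "min_max":
        # Assumption: rescale to [-1, 1]; the real transform may differ.
        span = stats["max"] - stats["min"]
        return [2 * (v - stats["min"]) / span - 1 for v in values]
    raise ValueError(f"unknown normalization mode: {mode}")


print(with_next_keys([("observation", "image", "top")]))

pixels = scale_pixels([0, 51, 255])
stats = {
    "min": min(pixels),
    "max": max(pixels),
    "mean": mean(pixels),
    "std": pstdev(pixels),
}
print(normalize(pixels, stats, pick_mode("pusht")))
print(normalize(pixels, stats, pick_mode("aloha")))
```

Scaling before normalizing means any statistics applied to image keys must be expressed in the `[0, 1]` range rather than in raw `0..255` values.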
21 changes: 19 additions & 2 deletions lerobot/common/datasets/simxarm.py
@@ -1,4 +1,5 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable

@@ -15,6 +16,22 @@
from lerobot.common.datasets.abstract import AbstractExperienceReplay


def download():
raise NotImplementedError()
import gdown

url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()


class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
@@ -48,8 +65,8 @@ def __init__(
)

def _download_and_preproc(self):
# download
# TODO(rcadene)
# TODO(rcadene): finish download
download()

dataset_path = self.data_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
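The body of `download()` extracts only archive members that start with `data/xarm` and end with `.pkl`. That filter can be tried offline, without `gdown` or network access. The sketch below builds a small archive in a temporary directory and applies the same test; the function name `extract_matching` and the archive layout are illustrative assumptions, not the real contents of the dataset zip.

```python
import tempfile
import zipfile
from pathlib import Path


def extract_matching(zip_path, dest, prefix="data/xarm", suffix=".pkl"):
    """Extract only members matching prefix and suffix; return their names."""
    extracted = []
    with zipfile.ZipFile(zip_path, "r") as zip_f:
        for member in zip_f.namelist():
            if member.startswith(prefix) and member.endswith(suffix):
                zip_f.extract(member=member, path=str(dest))
                extracted.append(member)
    return extracted


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    archive = tmp / "data.zip"
    # Build a tiny archive with one matching and one non-matching member.
    with zipfile.ZipFile(archive, "w") as zip_f:
        zip_f.writestr("data/xarm_lift_medium/buffer.pkl", b"placeholder")
        zip_f.writestr("data/README.txt", "not extracted")

    names = extract_matching(archive, tmp / "out")
    print(names)  # → ['data/xarm_lift_medium/buffer.pkl']
    print((tmp / "out" / "data" / "xarm_lift_medium" / "buffer.pkl").is_file())  # → True
    print((tmp / "out" / "data" / "README.txt").exists())  # → False
```

Passing an explicit destination, as here, keeps the extraction out of the current working directory, which is where `zip_f.extract(member=member)` writes when no path is given.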
80 changes: 80 additions & 0 deletions lerobot/common/envs/abstract.py
@@ -0,0 +1,80 @@
import abc
from collections import deque
from typing import Optional

from tensordict import TensorDict
from torchrl.envs import EnvBase


class AbstractEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []

if pixels_only:
assert from_pixels
if from_pixels:
assert image_size

self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)

if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)

def register_rendering_hook(self, func):
self._rendering_hooks.append(func)

def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)

def reset_rendering_hooks(self):
self._rendering_hooks = []

@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()

@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()

@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()

@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()

@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()
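`AbstractEnv` combines two small mechanisms: a list of rendering hooks that are each called with the environment, and bounded queues that keep the last `num_prev_obs` observations. The sketch below isolates both in a plain class with no torchrl dependency. `TinyEnv`, its `step` method, and the use of a frame counter as the observation are hypothetical and exist only for illustration.

```python
from collections import deque


class TinyEnv:
    """Hypothetical stand-in showing the hook list and bounded observation queue."""

    def __init__(self, num_prev_obs=1):
        self._rendering_hooks = []
        # A deque with maxlen drops its oldest entry once it is full.
        self._prev_obs_queue = deque(maxlen=num_prev_obs)
        self.frame = 0

    def register_rendering_hook(self, func):
        self._rendering_hooks.append(func)

    def call_rendering_hooks(self):
        for func in self._rendering_hooks:
            func(self)

    def reset_rendering_hooks(self):
        self._rendering_hooks = []

    def step(self):
        self.frame += 1
        self._prev_obs_queue.append(self.frame)
        self.call_rendering_hooks()


frames = []
env = TinyEnv(num_prev_obs=2)
env.register_rendering_hook(lambda e: frames.append(e.frame))

for _ in range(3):
    env.step()

print(frames)                     # → [1, 2, 3]
print(list(env._prev_obs_queue))  # → [2, 3]

env.reset_rendering_hooks()
env.step()
print(frames)                     # → [1, 2, 3], no hook ran on the fourth step
```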