Add Aloha env and ACT policy
WIP Aloha env tests pass

Rendering works (fps looks too fast, though? TODO: action bounds are too wide [-1,1])

Update README

Copy-paste from act repo

Remove download.py add a WIP for Simxarm

Add act yaml (TODO: try train.py)

Training can run (TODO: eval)

Add tasks without end_effector that are compatible with dataset, Eval can run (TODO: training and pretrained model)

Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

poetry lock

fix bug in compute_stats for action normalization

fix more bugs in normalization

fix training

fix import

PushtEnv inherits from AbstractEnv, Improve factory Normalization

Add _make_env to AbstractEnv

Add call_rendering_hooks to pusht env

SimxarmEnv inherits from AbstractEnv (NOT TESTED)

Add aloha tests artifacts + update pusht stats

fix image normalization: before env was in [0,1] but dataset in [0,255], and now both in [0,255]

Small fix on simxarm

Add next to obs

Add top camera to Aloha env (TODO: make it compatible with set of cameras)
Cadene committed Mar 12, 2024
1 parent 060bac7 commit 9d00203
Showing 116 changed files with 3,658 additions and 301 deletions.
17 changes: 4 additions & 13 deletions README.md
@@ -59,19 +59,10 @@ env=pusht

 ## TODO

-- [x] priority update doesn't match FOWM or original paper
-- [x] self.step=100000 should be updated at every step to adjust to the horizon of the planner
-- [ ] prefetch replay buffer to speedup training
-- [ ] parallelize env to speed up eval
-- [ ] clean checkpointing / loading
-- [ ] clean logging
-- [ ] clean config
-- [ ] clean hyperparameter tuning
-- [ ] add pusht
-- [ ] add aloha
-- [ ] add act
-- [ ] add diffusion
-- [ ] add aloha 2
+If you are not sure how to contribute, or want to know which features we are working on next, see this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
+
+Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
+

 ## Profile

5 changes: 4 additions & 1 deletion lerobot/common/datasets/abstract.py
@@ -81,7 +81,10 @@ def transform(self):
     def set_transform(self, transform):
         if not isinstance(transform, Compose):
             # required since torchrl calls `len(self._transform)` downstream
-            self._transform = Compose(transform)
+            if isinstance(transform, list):
+                self._transform = Compose(*transform)
+            else:
+                self._transform = Compose(transform)
         else:
             self._transform = transform

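The `set_transform` change above can be read as a small normalization step: whatever the caller passes ends up as a single `Compose`. The sketch below illustrates that branching with a hypothetical stand-in `Compose` class (it is not torchrl's implementation) and a hypothetical helper `as_compose`, so that it runs without torchrl installed.

```python
class Compose:
    """Hypothetical stand-in for torchrl's Compose: chains callables in order."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def __len__(self):
        # The diff comment notes that `len(self._transform)` is called downstream.
        return len(self.transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


def as_compose(transform):
    """Mirror the branching in set_transform: unpack lists, wrap single transforms."""
    if isinstance(transform, Compose):
        return transform
    if isinstance(transform, list):
        return Compose(*transform)
    return Compose(transform)


def scale(x):
    return x / 255


def shift(x):
    return x - 0.5


pipeline = as_compose([scale, shift])  # a list becomes a two-step Compose
single = as_compose(scale)             # a single callable becomes a one-step Compose
```

Without the list branch, a list would be wrapped as one opaque element, which is presumably why the commit unpacks it with `*transform`.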
9 changes: 3 additions & 6 deletions lerobot/common/datasets/aloha.py
@@ -73,11 +73,11 @@ def download(data_dir, dataset_id):

     data_dir.mkdir(parents=True, exist_ok=True)

-    gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
+    gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))

     # because of the 50 files limit per directory, two files episode 48 and 49 were missing
-    gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
-    gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
+    gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
+    gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)


 class AlohaExperienceReplay(AbstractExperienceReplay):
@@ -124,9 +124,6 @@ def stats_patterns(self) -> dict:
     def image_keys(self) -> list:
         return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]

-    # def _is_downloaded(self) -> bool:
-    # return False
-
     def _download_and_preproc(self):
         raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
         if not raw_dir.is_dir():
28 changes: 20 additions & 8 deletions lerobot/common/datasets/factory.py
@@ -5,7 +5,7 @@
 import torch
 from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler

-from lerobot.common.envs.transforms import NormalizeTransform
+from lerobot.common.envs.transforms import NormalizeTransform, Prod

 DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))

@@ -84,6 +84,16 @@ def make_offline_buffer(
         prefetch=prefetch if isinstance(prefetch, int) else None,
     )

+    if cfg.policy.name == "tdmpc":
+        img_keys = []
+        for key in offline_buffer.image_keys:
+            img_keys.append(("next", *key))
+        img_keys += offline_buffer.image_keys
+    else:
+        img_keys = offline_buffer.image_keys
+
+    transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
+
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
@@ -92,11 +102,10 @@ def make_offline_buffer(
         in_keys = [("observation", "state"), ("action")]

         if cfg.policy.name == "tdmpc":
-            for key in offline_buffer.image_keys:
-                # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
-                in_keys.append(key)
-                # since we use next observations in tdmpc
-                in_keys.append(("next", *key))
+            # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
+            in_keys += img_keys
+            # TODO(rcadene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
+            in_keys += [("next", *key) for key in img_keys]
             in_keys.append(("next", "observation", "state"))

         if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@@ -106,8 +115,11 @@ def make_offline_buffer(
             stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
             stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

-        transform = NormalizeTransform(stats, in_keys, mode="min_max")
-        offline_buffer.set_transform(transform)
+        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
+        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
+        transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
+
+    offline_buffer.set_transform(transforms)

     if not overwrite_sampler:
         index = torch.arange(0, offline_buffer.num_samples, 1)
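The factory change above does three things: it builds the list of image keys (adding `("next", ...)` variants for tdmpc), it scales pixels from [0,255] by `1 / 255`, and it picks `mean_std` or `min_max` normalization per environment. The following is a minimal pure-Python sketch of those steps. The helper names are hypothetical, and the formulas for the two modes are assumptions, since `NormalizeTransform` itself is not part of this diff: `min_max` is assumed to map values into [-1, 1], and `mean_std` is assumed to be a plain standardization.

```python
def image_keys_for(policy_name, image_keys):
    """Same ordering as the diff: 'next' variants first, then the original keys."""
    if policy_name == "tdmpc":
        return [("next", *key) for key in image_keys] + list(image_keys)
    return list(image_keys)


def scale_pixels(pixels, prod=1 / 255):
    """Stand-in for Prod(in_keys=img_keys, prod=1 / 255) on a flat list of pixels."""
    return [p * prod for p in pixels]


def normalization_mode(env_name):
    """Mode selection copied from the diff."""
    return "mean_std" if env_name == "aloha" else "min_max"


def normalize(values, stats, mode):
    """Assumed formulas; the real NormalizeTransform may differ (e.g. an epsilon)."""
    if mode == "mean_std":
        return [(v - m) / s for v, m, s in zip(values, stats["mean"], stats["std"])]
    if mode == "min_max":
        return [2 * (v - lo) / (hi - lo) - 1 for v, lo, hi in zip(values, stats["min"], stats["max"])]
    raise ValueError(f"unknown normalization mode: {mode}")


# Action bounds hard-coded in the diff for diffusion on pusht.
pusht_action_stats = {"min": [12.0, 25.0], "max": [511.0, 511.0]}
action = [12.0, 511.0]
normalized = normalize(action, pusht_action_stats, normalization_mode("pusht"))

tdmpc_keys = image_keys_for("tdmpc", [("observation", "image")])
```

Under these assumptions, an action at the lower bound of the first dimension and the upper bound of the second maps to the two ends of the [-1, 1] range.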
21 changes: 19 additions & 2 deletions lerobot/common/datasets/simxarm.py
@@ -1,4 +1,5 @@
 import pickle
+import zipfile
 from pathlib import Path
 from typing import Callable

@@ -15,6 +16,22 @@
 from lerobot.common.datasets.abstract import AbstractExperienceReplay


+def download():
+    raise NotImplementedError()
+    import gdown
+
+    url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
+    download_path = "data.zip"
+    gdown.download(url, download_path, quiet=False)
+    print("Extracting...")
+    with zipfile.ZipFile(download_path, "r") as zip_f:
+        for member in zip_f.namelist():
+            if member.startswith("data/xarm") and member.endswith(".pkl"):
+                print(member)
+                zip_f.extract(member=member)
+    Path(download_path).unlink()
+
+
 class SimxarmExperienceReplay(AbstractExperienceReplay):
     available_datasets = [
         "xarm_lift_medium",
@@ -48,8 +65,8 @@ def __init__(
         )

     def _download_and_preproc(self):
-        # download
-        # TODO(rcadene): finish download
+        # TODO(rcadene): finish download
+        download()

         dataset_path = self.data_dir / "buffer.pkl"
         print(f"Using offline dataset '{dataset_path}'")
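The work-in-progress `download()` helper above extracts only the archive members that start with `data/xarm` and end with `.pkl`. The sketch below exercises that filtering logic offline against a small archive built in a temporary directory. The function name `extract_matching`, the `out_dir` parameter, and the fake member names are hypothetical; nothing is downloaded.

```python
import tempfile
import zipfile
from pathlib import Path


def extract_matching(zip_path, out_dir, prefix="data/xarm", suffix=".pkl"):
    """Extract only members whose names match the prefix and suffix."""
    extracted = []
    with zipfile.ZipFile(zip_path, "r") as zip_f:
        for member in zip_f.namelist():
            if member.startswith(prefix) and member.endswith(suffix):
                zip_f.extract(member=member, path=out_dir)
                extracted.append(member)
    return extracted


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    archive = tmp / "data.zip"

    # Build a fake archive with one matching member and two that should be skipped.
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("data/xarm_lift_medium/buffer.pkl", b"fake")
        zf.writestr("data/xarm_lift_medium/notes.txt", b"skip me")
        zf.writestr("data/other/buffer.pkl", b"skip me too")

    members = extract_matching(archive, tmp / "out")
    extracted_ok = (tmp / "out" / "data/xarm_lift_medium/buffer.pkl").is_file()
    skipped_ok = not (tmp / "out" / "data/other/buffer.pkl").exists()
```

Passing an explicit output directory, rather than extracting into the current working directory as the WIP helper does, is a design choice of this sketch and not something the commit does.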
80 changes: 80 additions & 0 deletions lerobot/common/envs/abstract.py
@@ -0,0 +1,80 @@
+import abc
+from collections import deque
+from typing import Optional
+
+from tensordict import TensorDict
+from torchrl.envs import EnvBase
+
+
+class AbstractEnv(EnvBase):
+    def __init__(
+        self,
+        task,
+        frame_skip: int = 1,
+        from_pixels: bool = False,
+        pixels_only: bool = False,
+        image_size=None,
+        seed=1337,
+        device="cpu",
+        num_prev_obs=1,
+        num_prev_action=0,
+    ):
+        super().__init__(device=device, batch_size=[])
+        self.task = task
+        self.frame_skip = frame_skip
+        self.from_pixels = from_pixels
+        self.pixels_only = pixels_only
+        self.image_size = image_size
+        self.num_prev_obs = num_prev_obs
+        self.num_prev_action = num_prev_action
+        self._rendering_hooks = []
+
+        if pixels_only:
+            assert from_pixels
+        if from_pixels:
+            assert image_size
+
+        self._make_env()
+        self._make_spec()
+        self._current_seed = self.set_seed(seed)
+
+        if self.num_prev_obs > 0:
+            self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
+            self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
+        if self.num_prev_action > 0:
+            raise NotImplementedError()
+            # self._prev_action_queue = deque(maxlen=self.num_prev_action)
+
+    def register_rendering_hook(self, func):
+        self._rendering_hooks.append(func)
+
+    def call_rendering_hooks(self):
+        for func in self._rendering_hooks:
+            func(self)
+
+    def reset_rendering_hooks(self):
+        self._rendering_hooks = []
+
+    @abc.abstractmethod
+    def render(self, mode="rgb_array", width=640, height=480):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _reset(self, tensordict: Optional[TensorDict] = None):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _step(self, tensordict: TensorDict):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _make_env(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _make_spec(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _set_seed(self, seed: Optional[int]):
+        raise NotImplementedError()
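To illustrate how the rendering hooks and the bounded observation queue in `AbstractEnv` might be used together, here is a minimal sketch with a hypothetical stand-in class `HookedEnv`. It copies only the hook bookkeeping and one `deque` from the class above and does not depend on torchrl. Calling the hooks from inside `step` is an assumption; this diff does not show where the concrete environments call `call_rendering_hooks`.

```python
from collections import deque


class HookedEnv:
    """Hypothetical stand-in that copies the hook bookkeeping from AbstractEnv."""

    def __init__(self, num_prev_obs=1):
        self._rendering_hooks = []
        # Bounded queue: only the last `num_prev_obs` observations are kept.
        self._prev_obs_state_queue = deque(maxlen=num_prev_obs)
        self.t = 0

    def register_rendering_hook(self, func):
        self._rendering_hooks.append(func)

    def call_rendering_hooks(self):
        for func in self._rendering_hooks:
            func(self)

    def reset_rendering_hooks(self):
        self._rendering_hooks = []

    def step(self):
        # Assumption: hooks fire once per step, after the state is updated.
        self.t += 1
        self._prev_obs_state_queue.append(self.t)
        self.call_rendering_hooks()


frames = []
env = HookedEnv(num_prev_obs=2)
env.register_rendering_hook(lambda e: frames.append(e.t))

for _ in range(3):
    env.step()

env.reset_rendering_hooks()  # later steps no longer record frames
env.step()
```

Each hook receives the environment itself, which matches `func(self)` in `call_rendering_hooks`, so a hook can call `render()` or read any attribute it needs.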
@@ -0,0 +1,59 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>

<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>


<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />

<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>

<body name="peg" pos="0.2 0.5 0.05">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>

<body name="socket" pos="-0.2 0.5 0.05">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>

</worldbody>

<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>

<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>

</actuator>

<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>


</mujoco>
@@ -0,0 +1,48 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>

<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>


<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />

<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>

<body name="box" pos="0.2 0.5 0.05">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>

</worldbody>

<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>

<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>

</actuator>

<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>


</mujoco>
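The `<keyframe>` entries in the two scene files above can be sanity-checked against their bodies. A MuJoCo free joint contributes 7 `qpos` values (3 for position, 4 for a quaternion), so the peg/socket scene should carry two such groups and the box scene one. The sketch below splits each `qpos` string accordingly. It assumes that the first 16 values belong to the two arms (8 per arm, matching the repeated 8-value pattern in the keyframes); the helper name `split_qpos` is hypothetical.

```python
import xml.etree.ElementTree as ET

FREE_JOINT_DOFS = 7  # 3 position values + 4 quaternion values per free joint

PEG_SOCKET_KEYFRAME = """<keyframe>
    <key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>"""

BOX_KEYFRAME = """<keyframe>
    <key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>"""


def split_qpos(keyframe_xml, num_arm_dofs=16):
    """Split a keyframe qpos into arm values and one 7-value group per free body."""
    key = ET.fromstring(keyframe_xml).find("key")
    qpos = [float(v) for v in key.get("qpos").split()]
    arm, rest = qpos[:num_arm_dofs], qpos[num_arm_dofs:]
    if len(rest) % FREE_JOINT_DOFS != 0:
        raise ValueError("leftover qpos values do not form whole free joints")
    bodies = [rest[i:i + FREE_JOINT_DOFS] for i in range(0, len(rest), FREE_JOINT_DOFS)]
    return arm, bodies


peg_arm, peg_bodies = split_qpos(PEG_SOCKET_KEYFRAME)
box_arm, box_bodies = split_qpos(BOX_KEYFRAME)
```

The first three values of each group line up with the `pos` attributes of the `peg`, `socket`, and `box` bodies, and the quaternion `1 0 0 0` is the identity rotation.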