Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User/rcadene/2024 09 10 train aloha debug #465

Draft
wants to merge 70 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
96cc243
Mock OpenCVCamera
Cadene Sep 1, 2024
2469c99
fix unit tests
Cadene Sep 9, 2024
44b8394
add dynamic import for cv2 and pyrealsense2
Cadene Sep 9, 2024
3bd5ea4
WIP
Cadene Sep 10, 2024
e1763aa
Clean + Add act_aloha_real.yaml + Add act_real.yaml
Cadene Sep 10, 2024
bc0e691
force push aloha_real.yaml
Cadene Sep 10, 2024
4151630
Mock dynamixel_sdk
Cadene Sep 11, 2024
53ebf9c
Mock robots (WIP segmentation fault)
Cadene Sep 11, 2024
cd4d225
Fix unit test
Cadene Sep 12, 2024
3f993d5
fix typo
Cadene Sep 12, 2024
e47856a
Fix unit test test_policies, backward, Remove no_state from test
Cadene Sep 15, 2024
783b78a
Fix unit test test_policies, backward, Remove no_state from test
Cadene Sep 15, 2024
bab19d9
Merge branch 'main' into user/rcadene/2024_09_10_train_aloha
Cadene Sep 15, 2024
ccc0586
Apply suggestions from code review
Cadene Sep 16, 2024
6636db5
Address comments
Cadene Sep 16, 2024
624551b
Address comments
Cadene Sep 16, 2024
adc8dc9
Address comments
Cadene Sep 16, 2024
886923a
Fix opencv segmentation fault (#442)
aliberts Sep 25, 2024
1bf2845
pre-commit run --all-files
Cadene Sep 25, 2024
f0452c2
Merge branch 'main' into user/rcadene/2024_09_01_mock_robot_devices
Cadene Sep 25, 2024
bcf27b8
Skip mocking tests with minimal pytest
Cadene Sep 25, 2024
5584201
mock=False
Cadene Sep 25, 2024
6377d2a
mock)
Cadene Sep 25, 2024
bded8cb
Fix unit tests
Cadene Sep 25, 2024
2c01716
fix aloha mock
Cadene Sep 25, 2024
500d505
Add support for video=False in record (no tested yet)
Cadene Sep 26, 2024
f2b1842
fix unit test
Cadene Sep 26, 2024
3cb85bc
Fix unit test
Cadene Sep 26, 2024
a236382
fix unit tests
Cadene Sep 26, 2024
8b36223
fix unit tests
Cadene Sep 26, 2024
b6b7fda
custom pytest speedup (TOREMOVE)
Cadene Sep 26, 2024
8a7b5c4
Remove @require_x
Cadene Sep 26, 2024
395720a
Revert "Remove @require_x"
Cadene Sep 26, 2024
48be576
fix unit tests
Cadene Sep 26, 2024
89b2b73
fix unit tests
Cadene Sep 26, 2024
e66900e
mock_motor instead of require_mock_motor
Cadene Sep 26, 2024
7450adc
no more require_mock_motor
Cadene Sep 26, 2024
8da0893
move mock_motor in test_motors.py
Cadene Sep 26, 2024
a7350d9
add mock=False
Cadene Sep 27, 2024
bf7e906
add +COLOR_RGB2BGR
Cadene Sep 27, 2024
81f17d5
if not '~cameras' in overrides
Cadene Sep 27, 2024
e499d60
fix unit test
Cadene Sep 27, 2024
0352c61
Add more exception except
Cadene Sep 27, 2024
c704eb9
improve except
Cadene Sep 27, 2024
3f9f3dd
Add pyserial
Cadene Sep 27, 2024
da1888a
revert to all tests
Cadene Sep 27, 2024
675d428
add
Cadene Sep 27, 2024
76cc479
add
Cadene Sep 27, 2024
50a979d
Check if file exists
Cadene Sep 27, 2024
9dea00e
retest
Cadene Sep 27, 2024
2e694fc
test
Cadene Sep 27, 2024
88c2ed4
fix unit tests
Cadene Sep 27, 2024
cc5c623
test
Cadene Sep 27, 2024
2c9defa
test
Cadene Sep 27, 2024
bc479cb
test
Cadene Sep 27, 2024
0e63f7c
test
Cadene Sep 27, 2024
83cfe60
tests
Cadene Sep 27, 2024
1de04e4
Merge branch 'main' into user/rcadene/2024_09_01_mock_robot_devices
Cadene Sep 27, 2024
5c73bec
Address Jess comments
Cadene Sep 28, 2024
48911e0
Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_…
Cadene Sep 28, 2024
9b76ee9
Merge remote-tracking branch 'origin/user/rcadene/2024_09_01_mock_rob…
Cadene Sep 28, 2024
77ba43d
WIP: add multiprocess
Cadene Sep 28, 2024
8b89d03
Merge remote-tracking branch 'origin/user/rcadene/2024_09_10_train_al…
Cadene Sep 28, 2024
3369d35
Fix slow fps
Cadene Sep 28, 2024
e58e594
Add num_workers >=1 capabilities (default to 1)
Cadene Sep 28, 2024
433e950
Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_…
Cadene Oct 3, 2024
68fff56
Merge remote-tracking branch 'origin/main' into user/rcadene/2024_09_…
Cadene Oct 4, 2024
dc08c3b
small
Cadene Oct 7, 2024
2a8a9dc
TOREMOVE: remove aloha from __init__ to test if this creates the bug
Cadene Oct 7, 2024
82df3fe
TOREMOVE: isolate aloha on __init__ to see if it creates the bug
Cadene Oct 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lerobot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@

# lists all available robots from `lerobot/common/robot_devices/robots`
available_robots = [
"koch",
"koch_bimanual",
# "koch",
# "koch_bimanual",
"aloha",
]

Expand All @@ -216,7 +216,9 @@
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
"dora_aloha_real": ["act_aloha_real"],
}

env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
Expand Down
10 changes: 10 additions & 0 deletions lerobot/configs/env/aloha_real.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# @package _global_

fps: 30

env:
name: real_world
task: null
state_dim: 14
action_dim: 14
fps: ${fps}
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
# @package _global_

# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
# Use `act_aloha_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high, cam_low) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
#
# Example of usage for training:
# Example of usage for training and inference with `control_robot.py`:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# policy=act_aloha_real \
# env=aloha_real
# ```
#
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
# ```bash
# python lerobot/scripts/train.py \
# policy=act_aloha_real \
# env=dora_aloha_real
# ```

Expand All @@ -36,10 +42,11 @@ override_dataset_stats:
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

training:
offline_steps: 100000
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 20000
save_freq: 10000
log_freq: 100
save_checkpoint: true

batch_size: 8
Expand All @@ -62,7 +69,7 @@ policy:

# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
chunk_size: 100
n_action_steps: 100

input_shapes:
Expand Down Expand Up @@ -107,7 +114,7 @@ policy:
n_vae_encoder_layers: 4

# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null

# Training and loss computation.
dropout: 0.1
Expand Down
110 changes: 0 additions & 110 deletions lerobot/configs/policy/act_real_no_state.yaml

This file was deleted.

66 changes: 54 additions & 12 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import concurrent.futures
import json
import logging
import multiprocessing
import os
import platform
import shutil
Expand Down Expand Up @@ -239,6 +240,48 @@ def is_headless():
return True


def loop_to_save_frame_in_threads(frame_queue, num_image_writers):
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = frame_queue.get()

# Exit if we send None to stop the worker
if frame_data is None:
# Wait for all submitted futures to complete before exiting
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break

frame, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))


def start_frame_workers(frame_queue, num_image_writers, num_workers=1):
workers = []
for _ in range(num_workers):
worker = multiprocessing.Process(
target=loop_to_save_frame_in_threads,
args=(frame_queue, num_image_writers),
)
worker.start()
workers.append(worker)
return workers


def stop_workers(workers, frame_queue):
# Send None to each process to signal it to stop
for _ in workers:
frame_queue.put(None)

# Wait for all processes to terminate
for process in workers:
process.join()


def has_method(_object: object, method_name: str):
return hasattr(_object, method_name) and callable(getattr(_object, method_name))

Expand Down Expand Up @@ -465,10 +508,13 @@ def on_press(key):

# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
futures = []
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
frame_queue = multiprocessing.Queue()
frame_workers = start_frame_workers(frame_queue, num_image_writers)

# Using `try` to exist smoothly if an exception is raised
try:
# Start recording all episodes
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
Expand All @@ -489,11 +535,7 @@ def on_press(key):
not_image_keys = [key for key in observation if "image" not in key]

for key in image_keys:
futures += [
executor.submit(
save_image, observation[key], key, frame_index, episode_index, videos_dir
)
]
frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir))

if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
Expand Down Expand Up @@ -640,11 +682,11 @@ def on_press(key):
listener.stop()

logging.info("Waiting for threads writing the images on disk to terminate...")
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break
stop_workers(frame_workers, frame_queue)

except Exception:
traceback.print_exc()
stop_workers(frame_workers, frame_queue)

robot.disconnect()
if display_cameras and not is_headless():
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

25 changes: 20 additions & 5 deletions tests/test_control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):

replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)

# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
env_name = "aloha_real"
policy_name = "act_aloha_real"
elif robot_type in ["koch", "koch_bimanual"]:
env_name = "koch_real"
policy_name = "act_koch_real"
else:
raise NotImplementedError(robot_type)

overrides = [
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]

if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"]

cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
],
overrides=overrides,
)

policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,11 @@ def test_flatten_unflatten_dict():
# "lerobot/cmu_stretch",
],
)
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""

dataset = LeRobotDataset(
repo_id,
)
dataset = LeRobotDataset(repo_id)

test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

Expand Down
3 changes: 1 addition & 2 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,7 @@ def test_normalize(insert_temporal_dim):
),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
Expand Down
Loading