Skip to content

Commit

Permalink
Add suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Jun 13, 2024
1 parent a79d332 commit da9baa5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 80 deletions.
149 changes: 100 additions & 49 deletions lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,118 @@
# limitations under the License.
"""
This file contains download scripts for raw datasets.
Example of usage:
```
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
--raw-dir data/cadene/pusht_raw \
--repo-id cadene/pusht_raw
```
"""

import argparse
import logging
import warnings
from pathlib import Path

from huggingface_hub import snapshot_download


def download_raw(raw_dir: Path, dataset_id: str):
def download_raw(raw_dir: Path, repo_id: str):
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")

if not dataset_id.endswith("_raw"):
warnings.warn(
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
stacklevel=1,
)

raw_dir = Path(raw_dir)
# Send warning if raw_dir isn't well formated
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
warnings.warn(
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
raw_dir.mkdir(parents=True, exist_ok=True)

logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")


if __name__ == "__main__":
def download_all_raw_datasets():
data_dir = Path("data")
dataset_ids = [
"pusht_image",
"xarm_lift_medium_image",
"xarm_lift_medium_replay_image",
"xarm_push_medium_image",
"xarm_push_medium_replay_image",
"aloha_sim_insertion_human_image",
"aloha_sim_insertion_scripted_image",
"aloha_sim_transfer_cube_human_image",
"aloha_sim_transfer_cube_scripted_image",
"pusht",
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
"aloha_mobile_cabinet",
"aloha_mobile_chair",
"aloha_mobile_elevator",
"aloha_mobile_shrimp",
"aloha_mobile_wash_pan",
"aloha_mobile_wipe_wine",
"aloha_static_battery",
"aloha_static_candy",
"aloha_static_coffee",
"aloha_static_coffee_new",
"aloha_static_cups_open",
"aloha_static_fork_pick_up",
"aloha_static_pingpong_test",
"aloha_static_pro_pencil",
"aloha_static_screw_driver",
"aloha_static_tape",
"aloha_static_thread_velcro",
"aloha_static_towel",
"aloha_static_vinh_cup",
"aloha_static_vinh_cup_left",
"aloha_static_ziploc_slide",
"umi_cup_in_the_wild",
repo_ids = [
"cadene/pusht_image_raw",
"cadene/xarm_lift_medium_image_raw",
"cadene/xarm_lift_medium_replay_image_raw",
"cadene/xarm_push_medium_image_raw",
"cadene/xarm_push_medium_replay_image_raw",
"cadene/aloha_sim_insertion_human_image_raw",
"cadene/aloha_sim_insertion_scripted_image_raw",
"cadene/aloha_sim_transfer_cube_human_image_raw",
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
"cadene/pusht_raw",
"cadene/xarm_lift_medium_raw",
"cadene/xarm_lift_medium_replay_raw",
"cadene/xarm_push_medium_raw",
"cadene/xarm_push_medium_replay_raw",
"cadene/aloha_sim_insertion_human_raw",
"cadene/aloha_sim_insertion_scripted_raw",
"cadene/aloha_sim_transfer_cube_human_raw",
"cadene/aloha_sim_transfer_cube_scripted_raw",
"cadene/aloha_mobile_cabinet_raw",
"cadene/aloha_mobile_chair_raw",
"cadene/aloha_mobile_elevator_raw",
"cadene/aloha_mobile_shrimp_raw",
"cadene/aloha_mobile_wash_pan_raw",
"cadene/aloha_mobile_wipe_wine_raw",
"cadene/aloha_static_battery_raw",
"cadene/aloha_static_candy_raw",
"cadene/aloha_static_coffee_raw",
"cadene/aloha_static_coffee_new_raw",
"cadene/aloha_static_cups_open_raw",
"cadene/aloha_static_fork_pick_up_raw",
"cadene/aloha_static_pingpong_test_raw",
"cadene/aloha_static_pro_pencil_raw",
"cadene/aloha_static_screw_driver_raw",
"cadene/aloha_static_tape_raw",
"cadene/aloha_static_thread_velcro_raw",
"cadene/aloha_static_towel_raw",
"cadene/aloha_static_vinh_cup_raw",
"cadene/aloha_static_vinh_cup_left_raw",
"cadene/aloha_static_ziploc_slide_raw",
"cadene/umi_cup_in_the_wild_raw",
]
for dataset_id in dataset_ids:
raw_dir = data_dir / f"{dataset_id}_raw"
download_raw(raw_dir, dataset_id)
for repo_id in repo_ids:
raw_dir = data_dir / repo_id
download_raw(raw_dir, repo_id)


def main():
parser = argparse.ArgumentParser()

parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
)
args = parser.parse_args()
download_raw(**vars(args))


if __name__ == "__main__":
main()
47 changes: 27 additions & 20 deletions lerobot/scripts/push_dataset_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,31 +142,46 @@ def push_dataset_to_hub(
cache_dir: Path = Path("/tmp"),
tests_data_dir: Path | None = None,
):
# Robustify when `raw_dir` is str instead of Path
raw_dir = Path(raw_dir)

if not raw_dir.exists():
raise NotADirectoryError(
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
)

# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")

# Robustify when `local_dir` is str instead of Path
if local_dir is not None:
if local_dir:
# Robustify when `local_dir` is str instead of Path
local_dir = Path(local_dir)

# Send warning if local_dir isn't well formated
if local_dir is not None:
user_id, dataset_id = repo_id.split("/")
# Send warning if local_dir isn't well formated
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
warnings.warn(
f"`local_dir` is expected to contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'), but is {local_dir}.",
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)

# Check we don't override an existing `local_dir` by mistake
if local_dir is not None and local_dir.exists():
if force_override:
shutil.rmtree(local_dir)
else:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
# Check we don't override an existing `local_dir` by mistake
if local_dir.exists():
if force_override:
shutil.rmtree(local_dir)
else:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")

meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
else:
# Temporary directory used to store images, videos, meta_data
meta_data_dir = Path(cache_dir) / "meta_data"
videos_dir = Path(cache_dir) / "videos"

# Download the raw dataset if available
if not raw_dir.exists():
Expand All @@ -177,14 +192,6 @@ def push_dataset_to_hub(
raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir)

if local_dir:
meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
else:
# Temporary directory used to store images, videos, meta_data
meta_data_dir = Path(cache_dir) / "meta_data"
videos_dir = Path(cache_dir) / "videos"

# convert dataset from original raw format to LeRobot format
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
Expand Down
21 changes: 10 additions & 11 deletions tests/test_push_dataset_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

from pathlib import Path
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -215,19 +214,19 @@ def write_parquet(key, timestamps, values):
encode_video_frames(tmp_imgs_dir, video_path, fps)


def _mock_download_raw(raw_dir, dataset_id):
if "wrist_gripper" in dataset_id:
def _mock_download_raw(raw_dir, repo_id):
if "wrist_gripper" in repo_id:
_mock_download_raw_dora(raw_dir)
elif "aloha" in dataset_id:
elif "aloha" in repo_id:
_mock_download_raw_aloha(raw_dir)
elif "pusht" in dataset_id:
elif "pusht" in repo_id:
_mock_download_raw_pusht(raw_dir)
elif "xarm" in dataset_id:
elif "xarm" in repo_id:
_mock_download_raw_xarm(raw_dir)
elif "umi" in dataset_id:
elif "umi" in repo_id:
_mock_download_raw_umi(raw_dir)
else:
raise ValueError(dataset_id)
raise ValueError(repo_id)


def test_push_dataset_to_hub_invalid_repo_id():
Expand Down Expand Up @@ -267,9 +266,9 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
num_episodes = 3
tmpdir = Path(tmpdir)

_, dataset_id = repo_id.split("/")
raw_dir = tmpdir / f"{repo_id}_raw"
_mock_download_raw(raw_dir, repo_id)

raw_dir = tmpdir / f"{dataset_id}_raw"
local_dir = tmpdir / repo_id

lerobot_dataset = push_dataset_to_hub(
Expand Down Expand Up @@ -314,7 +313,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Uncomment to test backward compatibility locally."
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")
Expand Down

0 comments on commit da9baa5

Please sign in to comment.