diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index 852046018..7974ab8ef 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -15,67 +15,118 @@ # limitations under the License. """ This file contains download scripts for raw datasets. + +Example of usage: +``` +python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \ +--raw-dir data/cadene/pusht_raw \ +--repo-id cadene/pusht_raw +``` """ +import argparse import logging +import warnings from pathlib import Path from huggingface_hub import snapshot_download -def download_raw(raw_dir: Path, dataset_id: str): +def download_raw(raw_dir: Path, repo_id: str): + # Check repo_id is well formated + if len(repo_id.split("/")) != 2: + raise ValueError( + f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'." + ) + user_id, dataset_id = repo_id.split("/") + + if not dataset_id.endswith("_raw"): + warnings.warn( + f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.", + stacklevel=1, + ) + raw_dir = Path(raw_dir) + # Send warning if raw_dir isn't well formated + if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: + warnings.warn( + f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.", + stacklevel=1, + ) raw_dir.mkdir(parents=True, exist_ok=True) - logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}") - snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir) - logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}") + logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") + snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) + logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") -if __name__ == "__main__": +def download_all_raw_datasets(): data_dir = Path("data") - dataset_ids = [ - "pusht_image", - "xarm_lift_medium_image", - "xarm_lift_medium_replay_image", - "xarm_push_medium_image", - "xarm_push_medium_replay_image", - "aloha_sim_insertion_human_image", - "aloha_sim_insertion_scripted_image", - "aloha_sim_transfer_cube_human_image", - "aloha_sim_transfer_cube_scripted_image", - "pusht", - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - "aloha_mobile_cabinet", - "aloha_mobile_chair", - "aloha_mobile_elevator", - "aloha_mobile_shrimp", - "aloha_mobile_wash_pan", - "aloha_mobile_wipe_wine", - "aloha_static_battery", - "aloha_static_candy", - "aloha_static_coffee", - "aloha_static_coffee_new", - "aloha_static_cups_open", - "aloha_static_fork_pick_up", - "aloha_static_pingpong_test", - "aloha_static_pro_pencil", - "aloha_static_screw_driver", - "aloha_static_tape", - "aloha_static_thread_velcro", - "aloha_static_towel", - "aloha_static_vinh_cup", - "aloha_static_vinh_cup_left", - "aloha_static_ziploc_slide", - "umi_cup_in_the_wild", + repo_ids = [ + "cadene/pusht_image_raw", + "cadene/xarm_lift_medium_image_raw", + "cadene/xarm_lift_medium_replay_image_raw", + "cadene/xarm_push_medium_image_raw", + "cadene/xarm_push_medium_replay_image_raw", + "cadene/aloha_sim_insertion_human_image_raw", + "cadene/aloha_sim_insertion_scripted_image_raw", + "cadene/aloha_sim_transfer_cube_human_image_raw", + "cadene/aloha_sim_transfer_cube_scripted_image_raw", + "cadene/pusht_raw", + "cadene/xarm_lift_medium_raw", + "cadene/xarm_lift_medium_replay_raw", + "cadene/xarm_push_medium_raw", + "cadene/xarm_push_medium_replay_raw", + "cadene/aloha_sim_insertion_human_raw", + "cadene/aloha_sim_insertion_scripted_raw", + "cadene/aloha_sim_transfer_cube_human_raw", + "cadene/aloha_sim_transfer_cube_scripted_raw", + "cadene/aloha_mobile_cabinet_raw", + "cadene/aloha_mobile_chair_raw", + "cadene/aloha_mobile_elevator_raw", + "cadene/aloha_mobile_shrimp_raw", + "cadene/aloha_mobile_wash_pan_raw", + "cadene/aloha_mobile_wipe_wine_raw", + "cadene/aloha_static_battery_raw", + "cadene/aloha_static_candy_raw", + "cadene/aloha_static_coffee_raw", + "cadene/aloha_static_coffee_new_raw", + "cadene/aloha_static_cups_open_raw", + "cadene/aloha_static_fork_pick_up_raw", + "cadene/aloha_static_pingpong_test_raw", + "cadene/aloha_static_pro_pencil_raw", + "cadene/aloha_static_screw_driver_raw", + "cadene/aloha_static_tape_raw", + "cadene/aloha_static_thread_velcro_raw", + "cadene/aloha_static_towel_raw", + "cadene/aloha_static_vinh_cup_raw", + "cadene/aloha_static_vinh_cup_left_raw", + "cadene/aloha_static_ziploc_slide_raw", + "cadene/umi_cup_in_the_wild_raw", ] - for dataset_id in dataset_ids: - raw_dir = data_dir / f"{dataset_id}_raw" - download_raw(raw_dir, dataset_id) + for repo_id in repo_ids: + raw_dir = data_dir / repo_id + download_raw(raw_dir, repo_id) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--raw-dir", + type=Path, + required=True, + help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).", + ) + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).", + ) + args = parser.parse_args() + download_raw(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 5379f3b4b..e43e5f274 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -142,31 +142,46 @@ def push_dataset_to_hub( cache_dir: Path = Path("/tmp"), tests_data_dir: Path | None = None, ): + # Robustify when `raw_dir` is str instead of Path + raw_dir = Path(raw_dir) + + if not raw_dir.exists(): + raise NotADirectoryError( + f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:" + f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw" + ) + # Check repo_id is well formated if len(repo_id.split("/")) != 2: raise ValueError( f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'." ) + user_id, dataset_id = repo_id.split("/") - # Robustify when `local_dir` is str instead of Path - if local_dir is not None: + if local_dir: + # Robustify when `local_dir` is str instead of Path local_dir = Path(local_dir) - # Send warning if local_dir isn't well formated - if local_dir is not None: - user_id, dataset_id = repo_id.split("/") + # Send warning if local_dir isn't well formated if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id: warnings.warn( - f"`local_dir` is expected to contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'), but is {local_dir}.", + f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.", stacklevel=1, ) - # Check we don't override an existing `local_dir` by mistake - if local_dir is not None and local_dir.exists(): - if force_override: - shutil.rmtree(local_dir) - else: - raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.") + # Check we don't override an existing `local_dir` by mistake + if local_dir.exists(): + if force_override: + shutil.rmtree(local_dir) + else: + raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.") + + meta_data_dir = local_dir / "meta_data" + videos_dir = local_dir / "videos" + else: + # Temporary directory used to store images, videos, meta_data + meta_data_dir = Path(cache_dir) / "meta_data" + videos_dir = Path(cache_dir) / "videos" # Download the raw dataset if available if not raw_dir.exists(): @@ -177,14 +192,6 @@ def push_dataset_to_hub( raise NotImplementedError() # raw_format = auto_find_raw_format(raw_dir) - if local_dir: - meta_data_dir = local_dir / "meta_data" - videos_dir = local_dir / "videos" - else: - # Temporary directory used to store images, videos, meta_data - meta_data_dir = Path(cache_dir) / "meta_data" - videos_dir = Path(cache_dir) / "videos" - # convert dataset from original raw format to LeRobot format from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py index b32a49408..2822e6ebe 100644 --- a/tests/test_push_dataset_to_hub.py +++ b/tests/test_push_dataset_to_hub.py @@ -10,7 +10,6 @@ """ from pathlib import Path -from unittest.mock import patch import numpy as np import pytest @@ -215,19 +214,19 @@ def write_parquet(key, timestamps, values): encode_video_frames(tmp_imgs_dir, video_path, fps) -def _mock_download_raw(raw_dir, dataset_id): - if "wrist_gripper" in dataset_id: +def _mock_download_raw(raw_dir, repo_id): + if "wrist_gripper" in repo_id: _mock_download_raw_dora(raw_dir) - elif "aloha" in dataset_id: + elif "aloha" in repo_id: _mock_download_raw_aloha(raw_dir) - elif "pusht" in dataset_id: + elif "pusht" in repo_id: _mock_download_raw_pusht(raw_dir) - elif "xarm" in dataset_id: + elif "xarm" in repo_id: _mock_download_raw_xarm(raw_dir) - elif "umi" in dataset_id: + elif "umi" in repo_id: _mock_download_raw_umi(raw_dir) else: - raise ValueError(dataset_id) + raise ValueError(repo_id) def test_push_dataset_to_hub_invalid_repo_id(): @@ -267,9 +266,9 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_ num_episodes = 3 tmpdir = Path(tmpdir) - _, dataset_id = repo_id.split("/") + raw_dir = tmpdir / f"{repo_id}_raw" + _mock_download_raw(raw_dir, repo_id) - raw_dir = tmpdir / f"{dataset_id}_raw" local_dir = tmpdir / repo_id lerobot_dataset = push_dataset_to_hub( @@ -314,7 +313,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_ ], ) @pytest.mark.skip( - "Not compatible with our CI since it downloads raw datasets. Uncomment to test backward compatibility locally." + "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`" ) def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id): _, dataset_id = repo_id.split("/")