From 125bd93e2975566eb55eebe3e0909aaf544f5ab2 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Thu, 13 Jun 2024 15:18:02 +0200 Subject: [PATCH] Improve `push_dataset_to_hub` API + Add unit tests (#231) Co-authored-by: Remi Co-authored-by: Simon Alibert Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .github/workflows/test.yml | 9 +- README.md | 10 +- .../push_dataset_to_hub/_download_raw.py | 233 +++++------- .../push_dataset_to_hub/aloha_hdf5_format.py | 46 +-- ..._dora_format.py => dora_parquet_format.py} | 36 +- .../push_dataset_to_hub/pusht_zarr_format.py | 63 ++-- .../push_dataset_to_hub/umi_zarr_format.py | 87 ++--- .../push_dataset_to_hub/xarm_pkl_format.py | 77 ++-- lerobot/scripts/push_dataset_to_hub.py | 229 ++++++------ tests/test_push_dataset_to_hub.py | 352 ++++++++++++++++++ tests/utils.py | 35 +- 11 files changed, 754 insertions(+), 423 deletions(-) rename lerobot/common/datasets/push_dataset_to_hub/{aloha_dora_format.py => dora_parquet_format.py} (90%) create mode 100644 tests/test_push_dataset_to_hub.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f10f541e2..038b44582 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,8 +34,8 @@ jobs: with: lfs: true # Ensure LFS files are pulled - - name: Install EGL - run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev + - name: Install apt dependencies + run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg - name: Install poetry run: | @@ -72,6 +72,9 @@ jobs: with: lfs: true # Ensure LFS files are pulled + - name: Install apt dependencies + run: sudo apt-get update && sudo apt-get install -y ffmpeg + - name: Install poetry run: | pipx install poetry && poetry config virtualenvs.in-project true @@ -106,7 +109,7 @@ jobs: with: lfs: true # Ensure LFS files are pulled - - name: Install EGL + - name: Install apt dependencies run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev - name: Install poetry diff --git a/README.md b/README.md index 12ebe8d0c..d76969bc7 100644 --- a/README.md +++ b/README.md @@ -228,13 +228,13 @@ To add a dataset to the hub, you need to login using a write-access token, which huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` -Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with: +Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with: ```bash python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id aloha_static_pingpong_test \ ---raw-format aloha_hdf5 \ ---community-id lerobot +--raw-dir data/aloha_static_pingpong_test_raw \ +--out-dir data \ +--repo-id lerobot/aloha_static_pingpong_test \ +--raw-format aloha_hdf5 ``` See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions. diff --git a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py index 7074bcbaa..7974ab8ef 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_download_raw.py @@ -14,156 +14,119 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This file contains all obsolete download scripts. They are centralized here to not have to load -useless dependencies when using datasets. 
+This file contains download scripts for raw datasets. + +Example of usage: +``` +python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \ +--raw-dir data/cadene/pusht_raw \ +--repo-id cadene/pusht_raw +``` """ -import io +import argparse import logging -import shutil +import warnings from pathlib import Path -import tqdm from huggingface_hub import snapshot_download -def download_raw(raw_dir, dataset_id): - if "aloha" in dataset_id or "image" in dataset_id: - download_hub(raw_dir, dataset_id) - elif "pusht" in dataset_id: - download_pusht(raw_dir) - elif "xarm" in dataset_id: - download_xarm(raw_dir) - elif "umi" in dataset_id: - download_umi(raw_dir) - else: - raise ValueError(dataset_id) +def download_raw(raw_dir: Path, repo_id: str): + # Check repo_id is well formated + if len(repo_id.split("/")) != 2: + raise ValueError( + f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'." + ) + user_id, dataset_id = repo_id.split("/") - -def download_and_extract_zip(url: str, destination_folder: Path) -> bool: - import zipfile - - import requests - - print(f"downloading from {url}") - response = requests.get(url, stream=True) - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True) - - zip_file = io.BytesIO() - for chunk in response.iter_content(chunk_size=1024): - if chunk: - zip_file.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - zip_file.seek(0) - - with zipfile.ZipFile(zip_file, "r") as zip_ref: - zip_ref.extractall(destination_folder) - - -def download_pusht(raw_dir: str): - pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" - - raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(pusht_url, raw_dir) - # file is created inside a useful "pusht" directory, so we move it out and delete the dir - zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" - shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path) - shutil.rmtree(raw_dir / "pusht") - - -def download_xarm(raw_dir: Path): - """Download all xarm datasets at once""" - import zipfile - - import gdown + if not dataset_id.endswith("_raw"): + warnings.warn( + f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.", + stacklevel=1, + ) raw_dir = Path(raw_dir) + # Send warning if raw_dir isn't well formated + if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id: + warnings.warn( + f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). 
Following this naming convention is advised, but not mandatory.", + stacklevel=1, + ) raw_dir.mkdir(parents=True, exist_ok=True) - # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py - url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya" - zip_path = raw_dir / "data.zip" - gdown.download(url, str(zip_path), quiet=False) - print("Extracting...") - with zipfile.ZipFile(str(zip_path), "r") as zip_f: - for pkl_path in zip_f.namelist(): - if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"): - zip_f.extract(member=pkl_path) - # move to corresponding raw directory - extract_dir = pkl_path.replace("/buffer.pkl", "") - raw_pkl_path = raw_dir / "buffer.pkl" - shutil.move(pkl_path, raw_pkl_path) - shutil.rmtree(extract_dir) - zip_path.unlink() - - -def download_hub(raw_dir: Path, dataset_id: str): - raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - - logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}") - snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir) - logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}") + logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}") + snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir) + logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}") -def download_umi(raw_dir: Path): - url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip" - zarr_path = raw_dir / "cup_in_the_wild.zarr" - raw_dir = Path(raw_dir) - raw_dir.mkdir(parents=True, exist_ok=True) - download_and_extract_zip(url_cup_in_the_wild, zarr_path) +def download_all_raw_datasets(): + data_dir = Path("data") + repo_ids = [ + "cadene/pusht_image_raw", + "cadene/xarm_lift_medium_image_raw", + "cadene/xarm_lift_medium_replay_image_raw", + "cadene/xarm_push_medium_image_raw", + "cadene/xarm_push_medium_replay_image_raw", + "cadene/aloha_sim_insertion_human_image_raw", + "cadene/aloha_sim_insertion_scripted_image_raw", + "cadene/aloha_sim_transfer_cube_human_image_raw", + "cadene/aloha_sim_transfer_cube_scripted_image_raw", + "cadene/pusht_raw", + "cadene/xarm_lift_medium_raw", + "cadene/xarm_lift_medium_replay_raw", + "cadene/xarm_push_medium_raw", + "cadene/xarm_push_medium_replay_raw", + "cadene/aloha_sim_insertion_human_raw", + "cadene/aloha_sim_insertion_scripted_raw", + "cadene/aloha_sim_transfer_cube_human_raw", + "cadene/aloha_sim_transfer_cube_scripted_raw", + "cadene/aloha_mobile_cabinet_raw", + "cadene/aloha_mobile_chair_raw", + "cadene/aloha_mobile_elevator_raw", + "cadene/aloha_mobile_shrimp_raw", + "cadene/aloha_mobile_wash_pan_raw", + "cadene/aloha_mobile_wipe_wine_raw", + "cadene/aloha_static_battery_raw", + "cadene/aloha_static_candy_raw", + "cadene/aloha_static_coffee_raw", + "cadene/aloha_static_coffee_new_raw", + "cadene/aloha_static_cups_open_raw", + "cadene/aloha_static_fork_pick_up_raw", + "cadene/aloha_static_pingpong_test_raw", + "cadene/aloha_static_pro_pencil_raw", + "cadene/aloha_static_screw_driver_raw", + "cadene/aloha_static_tape_raw", + "cadene/aloha_static_thread_velcro_raw", + "cadene/aloha_static_towel_raw", + "cadene/aloha_static_vinh_cup_raw", + "cadene/aloha_static_vinh_cup_left_raw", + "cadene/aloha_static_ziploc_slide_raw", + "cadene/umi_cup_in_the_wild_raw", + ] + for repo_id in repo_ids: + raw_dir = data_dir / repo_id + download_raw(raw_dir, repo_id) + + +def main(): + parser = 
argparse.ArgumentParser() + + parser.add_argument( + "--raw-dir", + type=Path, + required=True, + help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).", + ) + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).", + ) + args = parser.parse_args() + download_raw(**vars(args)) if __name__ == "__main__": - data_dir = Path("data") - dataset_ids = [ - "pusht_image", - "xarm_lift_medium_image", - "xarm_lift_medium_replay_image", - "xarm_push_medium_image", - "xarm_push_medium_replay_image", - "aloha_sim_insertion_human_image", - "aloha_sim_insertion_scripted_image", - "aloha_sim_transfer_cube_human_image", - "aloha_sim_transfer_cube_scripted_image", - "pusht", - "xarm_lift_medium", - "xarm_lift_medium_replay", - "xarm_push_medium", - "xarm_push_medium_replay", - "aloha_sim_insertion_human", - "aloha_sim_insertion_scripted", - "aloha_sim_transfer_cube_human", - "aloha_sim_transfer_cube_scripted", - "aloha_mobile_cabinet", - "aloha_mobile_chair", - "aloha_mobile_elevator", - "aloha_mobile_shrimp", - "aloha_mobile_wash_pan", - "aloha_mobile_wipe_wine", - "aloha_static_battery", - "aloha_static_candy", - "aloha_static_coffee", - "aloha_static_coffee_new", - "aloha_static_cups_open", - "aloha_static_fork_pick_up", - "aloha_static_pingpong_test", - "aloha_static_pro_pencil", - "aloha_static_screw_driver", - "aloha_static_tape", - "aloha_static_thread_velcro", - "aloha_static_towel", - "aloha_static_vinh_cup", - "aloha_static_vinh_cup_left", - "aloha_static_ziploc_slide", - "umi_cup_in_the_wild", - ] - for dataset_id in dataset_ids: - raw_dir = data_dir / f"{dataset_id}_raw" - download_raw(raw_dir, dataset_id) + main() diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index 1c2f066ed..024045a03 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -30,6 +30,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -70,16 +71,17 @@ def check_format(raw_dir) -> bool: assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." 
-def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): # only frames from simulation are uncompressed compressed_images = "sim" not in raw_dir.name - hdf5_files = list(raw_dir.glob("*.hdf5")) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} + hdf5_files = sorted(raw_dir.glob("episode_*.hdf5")) + num_episodes = len(hdf5_files) - id_from = 0 - for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)): + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx in tqdm.tqdm(ep_ids): + ep_path = hdf5_files[ep_idx] with h5py.File(ep_path, "r") as ep: num_frames = ep["/action"].shape[0] @@ -114,12 +116,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -147,19 +149,13 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): assert isinstance(ep_idx, int) ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - gc.collect() - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video) -> Dataset: @@ -197,16 +193,22 @@ def to_hf_dataset(data_dict, video) -> Dataset: return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 50 - data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) - hf_dataset = to_hf_dataset(data_dir, video) - + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) + hf_dataset = to_hf_dataset(data_dict, video) + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py similarity index 90% rename from lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py rename to lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py index 4a21bc2df..1dc2e67e1 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py @@ -17,7 +17,6 @@ Contains utilities to process raw data format from dora-record """ -import logging import re from pathlib import Path @@ -26,10 +25,10 @@ from datasets import Dataset, Features, Image, Sequence, Value from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame -from lerobot.common.utils.utils import init_logging def 
check_format(raw_dir) -> bool: @@ -41,7 +40,7 @@ def check_format(raw_dir) -> bool: return True -def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): # Load data stream that will be used as reference for the timestamps synchronization reference_files = list(raw_dir.glob("observation.images.cam_*.parquet")) if len(reference_files) == 0: @@ -122,8 +121,7 @@ def get_episode_index(row): raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}") # Create symlink to raw videos directory (that needs to be absolute not relative) - out_dir.mkdir(parents=True, exist_ok=True) - videos_dir = out_dir / "videos" + videos_dir.parent.mkdir(parents=True, exist_ok=True) videos_dir.symlink_to((raw_dir / "videos").absolute()) # sanity check the video paths are well formated @@ -156,16 +154,7 @@ def get_episode_index(row): else: raise ValueError(key) - # Get the episode index containing for each unique episode index - first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index() - from_ = first_ep_index_df["start_index"].tolist() - to_ = from_[1:] + [len(df)] - episode_data_index = { - "from": from_, - "to": to_, - } - - return data_dict, episode_data_index + return data_dict def to_hf_dataset(data_dict, video) -> Dataset: @@ -203,12 +192,13 @@ def to_hf_dataset(data_dict, video) -> Dataset: return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): - init_logging() - - if debug: - logging.warning("debug=True not implemented. Falling back to debug=False.") - +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) @@ -220,9 +210,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru if not video: raise NotImplementedError() - data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps) + data_df = load_from_raw(raw_dir, videos_dir, fps, episodes) hf_dataset = to_hf_dataset(data_df, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 8133a36af..d9c7eb65e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -27,6 +27,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -53,7 +54,7 @@ def check_format(raw_dir): assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): try: import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely @@ -71,7 +72,6 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) episode_ids = torch.from_numpy(zarr_data.get_episode_idxs()) - num_episodes = 
zarr_data.meta["episode_ends"].shape[0] assert len( {zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118 ), "Some data type dont have the same number of total frames." @@ -84,25 +84,34 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): states = torch.from_numpy(zarr_data["state"]) actions = torch.from_numpy(zarr_data["action"]) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx = 0 + for to_idx in zarr_data.meta["episode_ends"]: + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx + + num_episodes = len(from_ids) - id_from = 0 - for ep_idx in tqdm.tqdm(range(num_episodes)): - id_to = zarr_data.meta["episode_ends"][ep_idx] - num_frames = id_to - id_from + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx # sanity check - assert (episode_ids[id_from:id_to] == ep_idx).all() + assert (episode_ids[from_idx:to_idx] == ep_idx).all() # get image - image = imgs[id_from:id_to] + image = imgs[from_idx:to_idx] assert image.min() >= 0.0 assert image.max() <= 255.0 image = image.type(torch.uint8) # get state - state = states[id_from:id_to] + state = states[from_idx:to_idx] agent_pos = state[:, :2] block_pos = state[:, 2:4] block_angle = state[:, 4] @@ -143,12 +152,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): img_key = "observation.image" if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -160,7 +169,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] ep_dict["observation.state"] = agent_pos - ep_dict["action"] = actions[id_from:id_to] + ep_dict["action"] = actions[from_idx:to_idx] ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps @@ -172,17 +181,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]]) ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from += num_frames - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video): @@ -212,16 +215,22 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 10 - data_dict, episode_data_index = load_from_raw(raw_dir, 
out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index cab2bdc52..6cd80c611 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -19,7 +19,6 @@ import shutil from pathlib import Path -import numpy as np import torch import tqdm import zarr @@ -29,6 +28,7 @@ from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -59,23 +59,7 @@ def check_format(raw_dir) -> bool: assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) -def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray: - # Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374 - from numba import jit - - @jit(nopython=True) - def _get_episode_idxs(episode_ends): - result = np.zeros((episode_ends[-1],), dtype=np.int64) - start_idx = 0 - for episode_number, end_idx in enumerate(episode_ends): - result[start_idx:end_idx] = episode_number - start_idx = end_idx - return result - - return _get_episode_idxs(episode_ends) - - -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): zarr_path = raw_dir / "cup_in_the_wild.zarr" zarr_data = zarr.open(zarr_path, mode="r") @@ -92,39 +76,41 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): episode_ends = zarr_data["meta/episode_ends"][:] num_episodes = episode_ends.shape[0] - episode_ids = torch.from_numpy(get_episode_idxs(episode_ends)) - # We convert it in torch tensor later because the jit function does not support torch tensors episode_ends = torch.from_numpy(episode_ends) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - for ep_idx in tqdm.tqdm(range(num_episodes)): - id_to = episode_ends[ep_idx] - num_frames = id_to - id_from + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx = 0 + for to_idx in episode_ends: + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx - # sanity heck - assert (episode_ids[id_from:id_to] == ep_idx).all() + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx # TODO(rcadene): save temporary images of the episode? 
- state = states[id_from:id_to] + state = states[from_idx:to_idx] ep_dict = {} # load 57MB of images in RAM (400x224x224x3 uint8) - imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to] + imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx] img_key = "observation.image" if video: # save png images in temporary directory - tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -139,27 +125,18 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps - ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames) - ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames) - ep_dict["end_pose"] = end_pose[id_from:id_to] - ep_dict["start_pos"] = start_pos[id_from:id_to] - ep_dict["gripper_width"] = gripper_width[id_from:id_to] + ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames) + ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames) + ep_dict["end_pose"] = end_pose[from_idx:to_idx] + ep_dict["start_pos"] = start_pos[from_idx:to_idx] + ep_dict["gripper_width"] = gripper_width[from_idx:to_idx] ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - id_from += num_frames - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - total_frames = id_from + total_frames = data_dict["frame_index"].shape[0] data_dict["index"] = torch.arange(0, total_frames, 1) - - return data_dict, episode_data_index + return data_dict def to_hf_dataset(data_dict, video): @@ -199,7 +176,13 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) @@ -212,9 +195,9 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru "Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM." 
) - data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 899ebdde7..57a36dba7 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -27,6 +27,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently from lerobot.common.datasets.utils import ( + calculate_episode_data_index, hf_transform_to_torch, ) from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames @@ -54,37 +55,42 @@ def check_format(raw_dir): assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict) -def load_from_raw(raw_dir, out_dir, fps, video, debug): +def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None): pkl_path = raw_dir / "buffer.pkl" with open(pkl_path, "rb") as f: pkl_data = pickle.load(f) - ep_dicts = [] - episode_data_index = {"from": [], "to": []} - - id_from = 0 - id_to = 0 - ep_idx = 0 - total_frames = pkl_data["actions"].shape[0] - for i in tqdm.tqdm(range(total_frames)): - id_to += 1 - - if not pkl_data["dones"][i]: + # load data indices from which each episode starts and ends + from_ids, to_ids = [], [] + from_idx, to_idx = 0, 0 + for done in pkl_data["dones"]: + to_idx += 1 + if not done: continue + from_ids.append(from_idx) + to_ids.append(to_idx) + from_idx = to_idx - num_frames = id_to - id_from + num_episodes = len(from_ids) + + ep_dicts = [] + ep_ids = episodes if episodes else range(num_episodes) + for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)): + from_idx = from_ids[selected_ep_idx] + to_idx = to_ids[selected_ep_idx] + num_frames = to_idx - from_idx - image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to]) + image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx]) image = einops.rearrange(image, "b c h w -> b h w c") - state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to]) - action = torch.tensor(pkl_data["actions"][id_from:id_to]) + state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx]) + action = torch.tensor(pkl_data["actions"][from_idx:to_idx]) # TODO(rcadene): we have a missing last frame which is the observation when the env is done # it is critical to have this frame for tdmpc to predict a "done observation/state" - # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to]) - # next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to]) - next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to]) - next_done = torch.tensor(pkl_data["dones"][id_from:id_to]) + # next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx]) + # next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx]) + next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx]) + next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx]) ep_dict = {} @@ -92,12 +98,12 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): img_key = "observation.image" if video: # save png images in temporary directory - 
tmp_imgs_dir = out_dir / "tmp_images" + tmp_imgs_dir = videos_dir / "tmp_images" save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video fname = f"{img_key}_episode_{ep_idx:06d}.mp4" - video_path = out_dir / "videos" / fname + video_path = videos_dir / fname encode_video_frames(tmp_imgs_dir, video_path, fps) # clean temporary images directory @@ -119,18 +125,11 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug): ep_dict["next.done"] = next_done ep_dicts.append(ep_dict) - episode_data_index["from"].append(id_from) - episode_data_index["to"].append(id_from + num_frames) - - id_from = id_to - ep_idx += 1 - - # process first episode only - if debug: - break - data_dict = concatenate_episodes(ep_dicts) - return data_dict, episode_data_index + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) + return data_dict def to_hf_dataset(data_dict, video): @@ -161,16 +160,22 @@ def to_hf_dataset(data_dict, video): return hf_dataset -def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False): +def from_raw_to_lerobot_format( + raw_dir: Path, + videos_dir: Path, + fps: int | None = None, + video: bool = True, + episodes: list[int] | None = None, +): # sanity check check_format(raw_dir) if fps is None: fps = 15 - data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug) + data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes) hf_dataset = to_hf_dataset(data_dict, video) - + episode_data_index = calculate_episode_data_index(hf_dataset) info = { "fps": fps, "video": video, diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 9cf72017b..18714a40a 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -18,58 +18,39 @@ or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any installation of neural net specific packages like pytorch, tensorflow, jax. 
-Example: +Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub: ``` python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id pusht \ +--raw-dir data/pusht_raw \ --raw-format pusht_zarr \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/pusht python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id xarm_lift_medium \ +--raw-dir data/xarm_lift_medium_raw \ --raw-format xarm_pkl \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/xarm_lift_medium python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id aloha_sim_insertion_scripted \ +--raw-dir data/aloha_sim_insertion_scripted_raw \ --raw-format aloha_hdf5 \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/aloha_sim_insertion_scripted python lerobot/scripts/push_dataset_to_hub.py \ ---data-dir data \ ---dataset-id umi_cup_in_the_wild \ +--raw-dir data/umi_cup_in_the_wild_raw \ --raw-format umi_zarr \ ---community-id lerobot \ ---dry-run 1 \ ---save-to-disk 1 \ ---save-tests-to-disk 0 \ ---debug 1 +--repo-id lerobot/umi_cup_in_the_wild ``` """ import argparse import json import shutil +import warnings from pathlib import Path from typing import Any import torch -from huggingface_hub import HfApi +from huggingface_hub import HfApi, create_branch from safetensors.torch import save_file from lerobot.common.datasets.compute_stats import compute_stats @@ -85,8 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str): from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format elif raw_format == "aloha_hdf5": from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format - elif raw_format == "aloha_dora": - from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format + elif raw_format == "dora_parquet": + from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format elif raw_format == "xarm_pkl": from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format else: @@ -147,39 +128,61 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non def push_dataset_to_hub( - data_dir: Path, - dataset_id: str, - raw_format: str | None, - community_id: str, - revision: str, - dry_run: bool, - save_to_disk: bool, - tests_data_dir: Path, - save_tests_to_disk: bool, - fps: int | None, - video: bool, - batch_size: int, - num_workers: int, - debug: bool, + raw_dir: Path, + raw_format: str, + repo_id: str, + push_to_hub: bool = True, + local_dir: Path | None = None, + fps: int | None = None, + video: bool = True, + batch_size: int = 32, + num_workers: int = 8, + episodes: list[int] | None = None, + force_override: bool = False, + cache_dir: Path = Path("/tmp"), + tests_data_dir: Path | None = None, ): - repo_id = f"{community_id}/{dataset_id}" - - raw_dir = data_dir / f"{dataset_id}_raw" - - out_dir = data_dir / repo_id - meta_data_dir = out_dir / "meta_data" - videos_dir = out_dir / "videos" - - tests_out_dir = tests_data_dir / repo_id - tests_meta_data_dir = tests_out_dir / "meta_data" - tests_videos_dir = tests_out_dir / "videos" + # Check repo_id is well formated + if len(repo_id.split("/")) != 2: + raise 
ValueError( + f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'." + ) + user_id, dataset_id = repo_id.split("/") - if out_dir.exists(): - shutil.rmtree(out_dir) + # Robustify when `raw_dir` is str instead of Path + raw_dir = Path(raw_dir) + if not raw_dir.exists(): + raise NotADirectoryError( + f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:" + f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw" + ) - if tests_out_dir.exists() and save_tests_to_disk: - shutil.rmtree(tests_out_dir) + if local_dir: + # Robustify when `local_dir` is str instead of Path + local_dir = Path(local_dir) + + # Send warning if local_dir isn't well formated + if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id: + warnings.warn( + f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.", + stacklevel=1, + ) + + # Check we don't override an existing `local_dir` by mistake + if local_dir.exists(): + if force_override: + shutil.rmtree(local_dir) + else: + raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.") + + meta_data_dir = local_dir / "meta_data" + videos_dir = local_dir / "videos" + else: + # Temporary directory used to store images, videos, meta_data + meta_data_dir = Path(cache_dir) / "meta_data" + videos_dir = Path(cache_dir) / "videos" + # Download the raw dataset if available if not raw_dir.exists(): download_raw(raw_dir, dataset_id) @@ -188,14 +191,14 @@ def push_dataset_to_hub( raise NotImplementedError() # raw_format = auto_find_raw_format(raw_dir) - from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) - # convert dataset from original raw format to LeRobot format - hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug) + from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) + hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( + raw_dir, videos_dir, fps, video, episodes + ) lerobot_dataset = LeRobotDataset.from_preloaded( repo_id=repo_id, - version=revision, hf_dataset=hf_dataset, episode_data_index=episode_data_index, info=info, @@ -203,103 +206,80 @@ def push_dataset_to_hub( ) stats = compute_stats(lerobot_dataset, batch_size, num_workers) - if save_to_disk: + if local_dir: hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved - hf_dataset.save_to_disk(str(out_dir / "train")) + hf_dataset.save_to_disk(str(local_dir / "train")) - if not dry_run or save_to_disk: + if push_to_hub or local_dir: # mandatory for upload save_meta_data(info, stats, episode_data_index, meta_data_dir) - if not dry_run: - # TODO(rcadene): token needs to be a str | None - hf_dataset.push_to_hub(repo_id, token=True, revision="main") - hf_dataset.push_to_hub(repo_id, token=True, revision=revision) - + if push_to_hub: + hf_dataset.push_to_hub(repo_id, revision="main") push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") - push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision) - if video: push_videos_to_hub(repo_id, videos_dir, revision="main") - push_videos_to_hub(repo_id, videos_dir, revision=revision) + create_branch(repo_id, repo_type="dataset", 
branch=CODEBASE_VERSION) - if save_tests_to_disk: + if tests_data_dir: # get the first episode num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0] test_hf_dataset = hf_dataset.select(range(num_items_first_ep)) test_hf_dataset = test_hf_dataset.with_format(None) - test_hf_dataset.save_to_disk(str(tests_out_dir / "train")) + test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train")) - save_meta_data(info, stats, episode_data_index, tests_meta_data_dir) + tests_meta_data = tests_data_dir / repo_id / "meta_data" + save_meta_data(info, stats, episode_data_index, tests_meta_data) # copy videos of first episode to tests directory episode_index = 0 + tests_videos_dir = tests_data_dir / repo_id / "videos" tests_videos_dir.mkdir(parents=True, exist_ok=True) for key in lerobot_dataset.video_frame_keys: fname = f"{key}_episode_{episode_index:06d}.mp4" shutil.copy(videos_dir / fname, tests_videos_dir / fname) - if not save_to_disk and out_dir.exists(): - # remove possible temporary files remaining in the output directory - shutil.rmtree(out_dir) + if local_dir is None: + # clear cache + shutil.rmtree(meta_data_dir) + shutil.rmtree(videos_dir) + + return lerobot_dataset def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--data-dir", + "--raw-dir", type=Path, required=True, - help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", - ) - parser.add_argument( - "--dataset-id", - type=str, - required=True, - help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).", + help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).", ) + # TODO(rcadene): add automatic detection of the format parser.add_argument( "--raw-format", type=str, - help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.", - ) - parser.add_argument( - "--community-id", - type=str, - default="lerobot", - help="Community or user ID under which the dataset will be hosted on the Hub.", + required=True, + help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).", ) parser.add_argument( - "--revision", + "--repo-id", type=str, - default=CODEBASE_VERSION, - help="Codebase version used to generate the dataset.", - ) - parser.add_argument( - "--dry-run", - type=int, - default=0, - help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.", - ) - parser.add_argument( - "--save-to-disk", - type=int, - default=1, - help="Save the dataset in the directory specified by `--data-dir`.", + required=True, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", ) parser.add_argument( - "--tests-data-dir", + "--local-dir", type=Path, - default="tests/data", - help="Directory containing tests artifacts datasets.", + help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. 
`data/lerobot/aloha_mobile_chair`).", ) parser.add_argument( - "--save-tests-to-disk", + "--push-to-hub", type=int, default=1, - help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.", + help="Upload to hub.", ) parser.add_argument( "--fps", @@ -325,10 +305,21 @@ def main(): help="Number of processes of Dataloader for computing the dataset statistics.", ) parser.add_argument( - "--debug", + "--episodes", + type=int, + nargs="*", + help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.", + ) + parser.add_argument( + "--force-override", type=int, default=0, - help="Debug mode process the first episode only.", + help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.", + ) + parser.add_argument( + "--tests-data-dir", + type=Path, + help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).", ) args = parser.parse_args() diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py new file mode 100644 index 000000000..7ddbe7aab --- /dev/null +++ b/tests/test_push_dataset_to_hub.py @@ -0,0 +1,352 @@ +""" +This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API. +Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets, +we skip them for now in our CI. + +Example to run backward compatiblity tests locally: +``` +DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility +``` +""" + +from pathlib import Path + +import numpy as np +import pytest +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently +from lerobot.common.datasets.video_utils import encode_video_frames +from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub +from tests.utils import require_package_arg + + +def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3): + import zarr + + raw_dir.mkdir(parents=True, exist_ok=True) + zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" + store = zarr.DirectoryStore(zarr_path) + zarr_data = zarr.group(store=store) + + zarr_data.create_dataset( + "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/img", + shape=(num_frames, 96, 96, 3), + chunks=(num_frames, 96, 96, 3), + dtype=np.uint8, + overwrite=True, + ) + zarr_data.create_dataset( + "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True + ) + + zarr_data["data/action"][:] = np.random.randn(num_frames, 1) + zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) + zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2) + zarr_data["data/state"][:] = np.random.randn(num_frames, 5) + 
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2) + zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) + + store.close() + + +def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3): + import zarr + + raw_dir.mkdir(parents=True, exist_ok=True) + zarr_path = raw_dir / "cup_in_the_wild.zarr" + store = zarr.DirectoryStore(zarr_path) + zarr_data = zarr.group(store=store) + + zarr_data.create_dataset( + "data/camera0_rgb", + shape=(num_frames, 96, 96, 3), + chunks=(num_frames, 96, 96, 3), + dtype=np.uint8, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_demo_end_pose", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_demo_start_pose", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True + ) + zarr_data.create_dataset( + "data/robot0_eef_rot_axis_angle", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "data/robot0_gripper_width", + shape=(num_frames, 5), + chunks=(num_frames, 5), + dtype=np.float32, + overwrite=True, + ) + zarr_data.create_dataset( + "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True + ) + + zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) + zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5) + zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5) + zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) + + store.close() + + +def _mock_download_raw_xarm(raw_dir, num_frames=4): + import pickle + + dataset_dict = { + "observations": { + "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8), + "state": np.random.randn(num_frames, 4), + }, + "actions": np.random.randn(num_frames, 3), + "rewards": np.random.randn(num_frames), + "masks": np.random.randn(num_frames), + "dones": np.array([False, True, True, True]), + } + + raw_dir.mkdir(parents=True, exist_ok=True) + pkl_path = raw_dir / "buffer.pkl" + with open(pkl_path, "wb") as f: + pickle.dump(dataset_dict, f) + + +def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3): + import h5py + + for ep_idx in range(num_episodes): + raw_dir.mkdir(parents=True, exist_ok=True) + path_h5 = raw_dir / f"episode_{ep_idx}.hdf5" + with h5py.File(str(path_h5), "w") as f: + f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14)) + f.create_dataset( + "observations/images/top", + data=np.random.randint( + 0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8 + ), + ) + + +def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30): + from datetime import datetime, timedelta, timezone + + import pandas + + def write_parquet(key, timestamps, values): + data = { + "timestamp_utc": timestamps, + key: values, + } + df = 
pandas.DataFrame(data) + raw_dir.mkdir(parents=True, exist_ok=True) + df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow") + + episode_indices = [None, None, -1, None, None, -1, None, None, -1] + episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2] + frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1] + + cam_key = "observation.images.cam_high" + timestamps = [] + actions = [] + states = [] + frames = [] + # `+ num_episodes`` for buffer frames associated to episode_index=-1 + for i, frame_idx in enumerate(frame_indices): + t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps) + action = np.random.randn(21).tolist() + state = np.random.randn(21).tolist() + ep_idx = episode_indices_mapping[i] + frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}] + timestamps.append(t_utc) + actions.append(action) + states.append(state) + frames.append(frame) + + write_parquet(cam_key, timestamps, frames) + write_parquet("observation.state", timestamps, states) + write_parquet("action", timestamps, actions) + write_parquet("episode_index", timestamps, episode_indices) + + # write fake mp4 file for each episode + for ep_idx in range(num_episodes): + imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8) + + tmp_imgs_dir = raw_dir / "tmp_images" + save_images_concurrently(imgs_array, tmp_imgs_dir) + + fname = f"{cam_key}_episode_{ep_idx:06d}.mp4" + video_path = raw_dir / "videos" / fname + encode_video_frames(tmp_imgs_dir, video_path, fps) + + +def _mock_download_raw(raw_dir, repo_id): + if "wrist_gripper" in repo_id: + _mock_download_raw_dora(raw_dir) + elif "aloha" in repo_id: + _mock_download_raw_aloha(raw_dir) + elif "pusht" in repo_id: + _mock_download_raw_pusht(raw_dir) + elif "xarm" in repo_id: + _mock_download_raw_xarm(raw_dir) + elif "umi" in repo_id: + _mock_download_raw_umi(raw_dir) + else: + raise ValueError(repo_id) + + +def test_push_dataset_to_hub_invalid_repo_id(tmpdir): + with pytest.raises(ValueError): + push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id") + + +def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): + tmpdir = Path(tmpdir) + out_dir = tmpdir / "out" + raw_dir = tmpdir / "raw" + # mkdir to skip download + raw_dir.mkdir(parents=True, exist_ok=True) + with pytest.raises(ValueError): + push_dataset_to_hub( + raw_dir=raw_dir, + raw_format="some_format", + repo_id="user/dataset", + local_dir=out_dir, + force_override=False, + ) + + +@pytest.mark.parametrize( + "required_packages, raw_format, repo_id", + [ + (["gym-pusht"], "pusht_zarr", "lerobot/pusht"), + (None, "xarm_pkl", "lerobot/xarm_lift_medium"), + (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), + (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"), + (None, "dora_parquet", "cadene/wrist_gripper"), + ], +) +@require_package_arg +def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id): + num_episodes = 3 + tmpdir = Path(tmpdir) + + raw_dir = tmpdir / f"{repo_id}_raw" + _mock_download_raw(raw_dir, repo_id) + + local_dir = tmpdir / repo_id + + lerobot_dataset = push_dataset_to_hub( + raw_dir=raw_dir, + raw_format=raw_format, + repo_id=repo_id, + push_to_hub=False, + local_dir=local_dir, + force_override=False, + cache_dir=tmpdir / "cache", + ) + + # minimal generic tests on the local directory containing LeRobotDataset + assert (local_dir / "meta_data" / "info.json").exists() + assert (local_dir / "meta_data" / "stats.safetensors").exists() + assert 
(local_dir / "meta_data" / "episode_data_index.safetensors").exists() + for i in range(num_episodes): + for cam_key in lerobot_dataset.camera_keys: + assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists() + assert (local_dir / "train" / "dataset_info.json").exists() + assert (local_dir / "train" / "state.json").exists() + assert len(list((local_dir / "train").glob("*.arrow"))) > 0 + + # minimal generic tests on the item + item = lerobot_dataset[0] + assert "index" in item + assert "episode_index" in item + assert "timestamp" in item + for cam_key in lerobot_dataset.camera_keys: + assert cam_key in item + + +@pytest.mark.parametrize( + "raw_format, repo_id", + [ + # TODO(rcadene): add raw dataset test artifacts + ("pusht_zarr", "lerobot/pusht"), + ("xarm_pkl", "lerobot/xarm_lift_medium"), + ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), + ("umi_zarr", "lerobot/umi_cup_in_the_wild"), + ("dora_parquet", "cadene/wrist_gripper"), + ], +) +@pytest.mark.skip( + "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`" +) +def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id): + _, dataset_id = repo_id.split("/") + + tmpdir = Path(tmpdir) + raw_dir = tmpdir / f"{dataset_id}_raw" + local_dir = tmpdir / repo_id + + push_dataset_to_hub( + raw_dir=raw_dir, + raw_format=raw_format, + repo_id=repo_id, + push_to_hub=False, + local_dir=local_dir, + force_override=False, + cache_dir=tmpdir / "cache", + episodes=[0], + ) + + ds_actual = LeRobotDataset(repo_id, root=tmpdir) + ds_reference = LeRobotDataset(repo_id) + + assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset) + + def check_same_items(item1, item2): + assert item1.keys() == item2.keys(), "Keys mismatch" + + for key in item1: + if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor): + assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}" + else: + assert item1[key] == item2[key], f"Mismatch found in key: {key}" + + for i in range(len(ds_reference.hf_dataset)): + item_reference = ds_reference.hf_dataset[i] + item_actual = ds_actual.hf_dataset[i] + check_same_items(item_reference, item_actual) diff --git a/tests/utils.py b/tests/utils.py index ba49ee706..c1575656c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,6 +76,7 @@ def require_env(func): """ Decorator that skips the test if the required environment package is not installed. As it need 'env_name' in args, it also checks whether it is provided as an argument. + If 'env_name' is None, this check is skipped. """ @wraps(func) @@ -91,7 +92,7 @@ def wrapper(*args, **kwargs): # Perform the package check package_name = f"gym_{env_name}" - if not is_package_available(package_name): + if env_name is not None and not is_package_available(package_name): pytest.skip(f"gym-{env_name} not installed") return func(*args, **kwargs) @@ -99,6 +100,38 @@ def wrapper(*args, **kwargs): return wrapper +def require_package_arg(func): + """ + Decorator that skips the test if the required package is not installed. + This is similar to `require_env` but more general in that it can check any package (not just environments). + As it need 'required_packages' in args, it also checks whether it is provided as an argument. + If 'required_packages' is None, this check is skipped. 
+ """ + + @wraps(func) + def wrapper(*args, **kwargs): + # Determine if 'required_packages' is provided and extract its value + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + if "required_packages" in arg_names: + # Get the index of 'required_packages' and retrieve the value from args + index = arg_names.index("required_packages") + required_packages = args[index] if len(args) > index else kwargs.get("required_packages") + else: + raise ValueError("Function does not have 'required_packages' as an argument.") + + if required_packages is None: + return func(*args, **kwargs) + + # Perform the package check + for package in required_packages: + if not is_package_available(package): + pytest.skip(f"{package} not installed") + + return func(*args, **kwargs) + + return wrapper + + def require_package(package_name): """ Decorator that skips the test if the specified package is not installed.
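
To illustrate the revised entry point introduced by this patch, here is a minimal sketch of calling `push_dataset_to_hub` directly from Python with the new signature. The raw directory, repository id, and episode selection below are illustrative placeholders; the raw data is assumed to have been downloaded beforehand (e.g. with `_download_raw.py`), and converting the `pusht_zarr` format additionally needs the optional `gym-pusht` dependency.

```python
from pathlib import Path

from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub

# Convert a raw PushT dataset to LeRobotDataset format without uploading it.
# `push_to_hub=False` skips the upload, `local_dir` keeps the converted dataset
# on disk, and `episodes=[0]` converts only the first episode for a quick check.
lerobot_dataset = push_dataset_to_hub(
    raw_dir=Path("data/lerobot/pusht_raw"),   # placeholder path to the raw dataset
    raw_format="pusht_zarr",
    repo_id="lerobot/pusht",                  # placeholder repository id
    push_to_hub=False,
    local_dir=Path("data/lerobot/pusht"),
    episodes=[0],
)

print(len(lerobot_dataset.hf_dataset))  # number of converted frames
```

The equivalent command-line invocation uses the `--push-to-hub 0`, `--local-dir`, and `--episodes` options added by this patch.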