diff --git a/convert_dataset_v1_to_v2.py b/convert_dataset_v1_to_v2.py
index 9343c898b..79749667a 100644
--- a/convert_dataset_v1_to_v2.py
+++ b/convert_dataset_v1_to_v2.py
@@ -12,7 +12,7 @@

 # 1. Single task dataset
 If your dataset contains a single task, you can simply provide it directly via the CLI with the
-'--single-task' option (see examples below).
+'--single-task' option.

 Examples:

@@ -67,7 +67,15 @@
 # 3. Multi task episodes
 If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
 parquet file, and you must provide this column's name with the '--tasks-col' arg.
-TODO
+
+Example:
+
+```bash
+python convert_dataset_v1_to_v2.py \
+    --repo-id lerobot/stanford_kuka_multimodal_dataset \
+    --tasks-col "language_instruction" \
+    --local-dir data
+```
 """

 import argparse
@@ -87,12 +95,12 @@
 from PIL import Image
 from safetensors.torch import load_file

-from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
+from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub

-V1_6 = "v1.6"
-V2_0 = "v2.0"
+V16 = "v1.6"
+V20 = "v2.0"

 PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
 VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
@@ -385,18 +393,19 @@ def convert_dataset(
     tasks_col: Path | None = None,
     robot_config: dict | None = None,
 ):
-    v1_6_dir = local_dir / V1_6 / repo_id
-    v2_0_dir = local_dir / V2_0 / repo_id
-    v1_6_dir.mkdir(parents=True, exist_ok=True)
-    v2_0_dir.mkdir(parents=True, exist_ok=True)
+    v1 = get_hub_safe_version(repo_id, V16)
+    v1x_dir = local_dir / v1 / repo_id
+    v20_dir = local_dir / V20 / repo_id
+    v1x_dir.mkdir(parents=True, exist_ok=True)
+    v20_dir.mkdir(parents=True, exist_ok=True)

     hub_api = HfApi()
     hub_api.snapshot_download(
-        repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/"
+        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
     )

-    metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
-    dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
+    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+    dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
     keys = get_keys(dataset)

     # Episodes
@@ -422,21 +431,22 @@ def convert_dataset(
     assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}

     task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-    write_json(task_json, v2_0_dir / "meta" / "tasks.json")
+    write_json(task_json, v20_dir / "meta" / "tasks.json")

     # Split data into 1 parquet file by episode
-    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)
+    episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)

     # Shapes
     sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
     image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
     if len(keys["video"]) > 0:
-        assert metadata_v1_6.get("video", False)
-        videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
+        assert metadata_v1.get("video", False)
+        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
         video_shapes = get_video_shapes(videos_info, keys["video"])
         for img_key in keys["video"]:
-            assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"]
-            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3)
+            assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
+            if "encoding" in metadata_v1:
+                assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
     else:
         assert len(keys["video"]) == 0
         videos_info = None
@@ -461,16 +471,16 @@ def convert_dataset(
         {"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
-    write_json(episodes, v2_0_dir / "meta" / "episodes.json")
+    write_json(episodes, v20_dir / "meta" / "episodes.json")

     # Assemble metadata v2.0
     metadata_v2_0 = {
-        "codebase_version": V2_0,
+        "codebase_version": V20,
         "data_path": PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_tasks": len(tasks),
-        "fps": metadata_v1_6["fps"],
+        "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],
         "video_keys": keys["video"],
@@ -479,14 +489,14 @@ def convert_dataset(
         "names": names,
         "videos": videos_info,
     }
-    write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
-    convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
+    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
+    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")

     #### TODO: delete
-    repo_id = f"aliberts/{repo_id.split('/')[1]}"
+    # repo_id = f"aliberts/{repo_id.split('/')[1]}"
     # if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
     #     hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
-    hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
+    # hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
     ####

     with contextlib.suppress(EntryNotFoundError):
@@ -498,28 +508,28 @@ def convert_dataset(
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="data",
-        folder_path=v2_0_dir / "data",
+        folder_path=v20_dir / "data",
         repo_type="dataset",
         revision="main",
     )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="videos",
-        folder_path=v1_6_dir / "videos",
+        folder_path=v1x_dir / "videos",
         repo_type="dataset",
         revision="main",
     )
     hub_api.upload_folder(
         repo_id=repo_id,
         path_in_repo="meta",
-        folder_path=v2_0_dir / "meta",
+        folder_path=v20_dir / "meta",
         repo_type="dataset",
         revision="main",
     )

     card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
     push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
-    create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
+    create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

     # TODO:
     # - [X] Add shapes
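
As a quick illustration of the `PARQUET_PATH` and `VIDEO_PATH` templates kept above, here is a minimal sketch of the per-episode file names the conversion produces. The episode numbers and the camera key `observation.images.cam_high` are illustrative placeholders, not values from a specific dataset.

```python
PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"

# Episode 3 of a 50-episode dataset (illustrative values):
print(PARQUET_PATH.format(episode_index=3, total_episodes=50))
# data/train-00003-of-00050.parquet
print(VIDEO_PATH.format(video_key="observation.images.cam_high", episode_index=3))
# videos/observation.images.cam_high_episode_000003.mp4
```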
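The conversion also writes three metadata files under the v2.0 tree. The sketch below mirrors the dict literals assembled in `convert_dataset` (`meta/tasks.json`, `meta/episodes.json`, and a truncated `meta/info.json`); all values are placeholders, and `info.json` additionally carries the `keys`, `video_keys`, `image_keys`, `shapes`, `names`, and `videos` fields shown in the diff.

```python
# Sketch of the v2.0 metadata layout written to <local-dir>/v2.0/<repo-id>/meta/.
# Values are placeholders; the structure mirrors convert_dataset() above.

# meta/tasks.json: one entry per unique task string
tasks = [
    {"task_index": 0, "task": "Insert the peg into the socket."},
]

# meta/episodes.json: one entry per episode, with its task(s) and frame count
episodes = [
    {"episode_index": 0, "tasks": ["Insert the peg into the socket."], "length": 142},
]

# meta/info.json: dataset-level metadata (truncated; see metadata_v2_0 in the diff)
info = {
    "codebase_version": "v2.0",
    "data_path": "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet",
    "robot_type": "unknown",  # placeholder; taken from the optional robot config
    "total_episodes": 1,
    "total_tasks": 1,
    "fps": 30,
    "splits": {"train": "0:1"},
}
```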