From c3c0141738d133546022632f4a29ed27dd7c87c2 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Wed, 23 Oct 2024 00:05:31 +0200
Subject: [PATCH] Update & fix conversion script

---
 .../datasets/v2/convert_dataset_v1_to_v2.py   | 30 ++++++++++++-------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index 65a2061ed..7ab5ae14a 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -124,19 +124,26 @@
     DEFAULT_CHUNK_SIZE,
     DEFAULT_PARQUET_PATH,
     DEFAULT_VIDEO_PATH,
+    EPISODES_PATH,
+    INFO_PATH,
+    STATS_PATH,
+    TASKS_PATH,
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
     get_hub_safe_version,
     unflatten_dict,
 )
+from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
 from lerobot.common.utils.utils import init_hydra_config
 
 V16 = "v1.6"
 V20 = "v2.0"
 
 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
-VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+V1_INFO_PATH = "meta_data/info.json"
+V1_STATS_PATH = "meta_data/stats.safetensors"
 
 
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@@ -180,17 +187,18 @@ def write_json(data: dict, fpath: Path) -> None:
 
 
 def write_jsonlines(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
     with jsonlines.open(fpath, "w") as writer:
         writer.write_all(data)
 
 
-def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
-    safetensor_path = input_dir / "stats.safetensors"
+def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
+    safetensor_path = v1_dir / V1_STATS_PATH
     stats = load_file(safetensor_path)
     serialized_stats = {key: value.tolist() for key, value in stats.items()}
     serialized_stats = unflatten_dict(serialized_stats)
 
-    json_path = output_dir / "stats.json"
+    json_path = v2_dir / STATS_PATH
     json_path.parent.mkdir(exist_ok=True, parents=True)
     with open(json_path, "w") as f:
         json.dump(serialized_stats, f, indent=4)
@@ -279,7 +287,7 @@ def split_parquet_by_episodes(
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
             episode_lengths.insert(ep_idx, len(ep_table))
             output_file = output_dir / DEFAULT_PARQUET_PATH.format(
-                episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
+                episode_chunk=ep_chunk, episode_index=ep_idx
             )
             pq.write_table(ep_table, output_file)
 
@@ -336,7 +344,7 @@ def move_videos(
             target_path = DEFAULT_VIDEO_PATH.format(
                 episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
             )
-            video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
+            video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
             if len(video_dirs) == 1:
                 video_path = video_dirs[0] / video_file
             else:
@@ -572,7 +580,7 @@ def convert_dataset(
         branch = test_branch
         create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
 
-    metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+    metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
     keys = get_keys(dataset)
 
@@ -611,7 +619,7 @@ def convert_dataset(
         assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
 
     tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-    write_jsonlines(tasks, v20_dir / "meta" / "tasks.json")
+    write_jsonlines(tasks, v20_dir / TASKS_PATH)
 
     # Shapes
     sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
@@ -667,7 +675,7 @@ def convert_dataset(
         {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
         for ep_idx in episode_indices
     ]
-    write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
+    write_jsonlines(episodes, v20_dir / EPISODES_PATH)
 
     # Assemble metadata v2.0
     metadata_v2_0 = {
@@ -689,8 +697,8 @@ def convert_dataset(
         "names": names,
         "videos": videos_info,
     }
-    write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
-    convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
+    write_json(metadata_v2_0, v20_dir / INFO_PATH)
+    convert_stats_to_json(v1x_dir, v20_dir)
 
     with contextlib.suppress(EntryNotFoundError):
         hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)