Extend v1 compatibility
aliberts committed Oct 14, 2024
1 parent cf63334 commit cbc51e1
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions convert_dataset_v1_to_v2.py
@@ -12,7 +12,7 @@
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
-'--single-task' option (see examples below).
+'--single-task' option.
Examples:
@@ -67,7 +67,15 @@
# 3. Multi task episodes
If you have multiple tasks per episode, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.
-TODO
+Example:
+```bash
+python convert_dataset_v1_to_v2.py \
+    --repo-id lerobot/stanford_kuka_multimodal_dataset \
+    --tasks-col "language_instruction" \
+    --local-dir data
+```
"""

import argparse
@@ -87,12 +95,12 @@
from PIL import Image
from safetensors.torch import load_file

-from lerobot.common.datasets.utils import create_branch, flatten_dict, unflatten_dict
+from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub

V1_6 = "v1.6"
V2_0 = "v2.0"
V16 = "v1.6"
V20 = "v2.0"

PARQUET_PATH = "data/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEO_PATH = "videos/{video_key}_episode_{episode_index:06d}.mp4"
@@ -385,18 +393,19 @@ def convert_dataset(
tasks_col: Path | None = None,
robot_config: dict | None = None,
):
-v1_6_dir = local_dir / V1_6 / repo_id
-v2_0_dir = local_dir / V2_0 / repo_id
-v1_6_dir.mkdir(parents=True, exist_ok=True)
-v2_0_dir.mkdir(parents=True, exist_ok=True)
+v1 = get_hub_safe_version(repo_id, V16)
+v1x_dir = local_dir / v1 / repo_id
+v20_dir = local_dir / V20 / repo_id
+v1x_dir.mkdir(parents=True, exist_ok=True)
+v20_dir.mkdir(parents=True, exist_ok=True)

hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=V1_6, local_dir=v1_6_dir, ignore_patterns="videos/"
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos/"
)

-metadata_v1_6 = load_json(v1_6_dir / "meta_data" / "info.json")
-dataset = datasets.load_dataset("parquet", data_dir=v1_6_dir / "data", split="train")
+metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
+dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
keys = get_keys(dataset)

# Episodes
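The get_hub_safe_version helper imported above lives in lerobot.common.datasets.utils, and its body is not part of this diff. The sketch below is only an illustration of the kind of fallback it enables (the function name, fallback policy, and error message are assumptions, not the actual implementation): query the Hub for the dataset's tags and settle on an existing v1.x revision instead of hard-coding v1.6.

```python
# Sketch only, not the lerobot implementation: resolve a usable v1.x revision
# for a dataset repo, falling back to another v1.x tag if "v1.6" is absent.
from huggingface_hub import HfApi


def resolve_v1_revision(repo_id: str, requested: str = "v1.6") -> str:
    refs = HfApi().list_repo_refs(repo_id, repo_type="dataset")
    tags = {ref.name for ref in refs.tags}
    if requested in tags:
        return requested
    v1_tags = sorted(tag for tag in tags if tag.startswith("v1."))
    if v1_tags:
        return v1_tags[-1]  # newest v1.x tag the repo actually has
    raise ValueError(f"{repo_id} has no v1.x revision to convert from")
```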
@@ -422,21 +431,22 @@

assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
-write_json(task_json, v2_0_dir / "meta" / "tasks.json")
+write_json(task_json, v20_dir / "meta" / "tasks.json")

# Split data into one parquet file per episode
-episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v2_0_dir)
+episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, episode_indices, v20_dir)

# Shapes
sequence_shapes = {key: len(dataset[key][0]) for key in keys["sequence"]}
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
if len(keys["video"]) > 0:
-assert metadata_v1_6.get("video", False)
-videos_info = get_videos_info(repo_id, v1_6_dir, video_keys=keys["video"])
+assert metadata_v1.get("video", False)
+videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"])
video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1_6["encoding"]["pix_fmt"]
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1_6["fps"], rel_tol=1e-3)
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert len(keys["video"]) == 0
videos_info = None
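For a concrete picture of the task_json structure written to meta/tasks.json above, here it is with two hypothetical task strings:

```python
# Illustration only: the structure that write_json stores in meta/tasks.json,
# shown with hypothetical task strings.
tasks = ["pick up the cube", "place the cube in the bin"]
task_json = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
# task_json == [
#     {"task_index": 0, "task": "pick up the cube"},
#     {"task_index": 1, "task": "place the cube in the bin"},
# ]
```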
@@ -461,16 +471,16 @@
{"episode_index": ep_idx, "tasks": [tasks_by_episodes[ep_idx]], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
-write_json(episodes, v2_0_dir / "meta" / "episodes.json")
+write_json(episodes, v20_dir / "meta" / "episodes.json")

# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V2_0,
"codebase_version": V20,
"data_path": PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_tasks": len(tasks),
"fps": metadata_v1_6["fps"],
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"],
"video_keys": keys["video"],
@@ -479,14 +489,14 @@
"names": names,
"videos": videos_info,
}
-write_json(metadata_v2_0, v2_0_dir / "meta" / "info.json")
-convert_stats_to_json(v1_6_dir / "meta_data", v2_0_dir / "meta")
+write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
+convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")

#### TODO: delete
repo_id = f"aliberts/{repo_id.split('/')[1]}"
# repo_id = f"aliberts/{repo_id.split('/')[1]}"
# if hub_api.repo_exists(repo_id=repo_id, repo_type="dataset"):
# hub_api.delete_repo(repo_id=repo_id, repo_type="dataset")
-hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
+# hub_api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True)
####

with contextlib.suppress(EntryNotFoundError):
@@ -498,28 +508,28 @@
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v2_0_dir / "data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision="main",
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="videos",
folder_path=v1_6_dir / "videos",
folder_path=v1x_dir / "videos",
repo_type="dataset",
revision="main",
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v2_0_dir / "meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision="main",
)

card_text = f"[meta/info.json](meta/info.json)\n```json\n{json.dumps(metadata_v2_0, indent=4)}\n```"
push_dataset_card_to_hub(repo_id=repo_id, revision="main", tags=repo_tags, text=card_text)
-create_branch(repo_id=repo_id, branch=V2_0, repo_type="dataset")
+create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

# TODO:
# - [X] Add shapes
