Commit

Update & fix conversion script
aliberts committed Oct 22, 2024
1 parent c72dc23 commit c3c0141
Showing 1 changed file with 19 additions and 11 deletions.
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py (30 changes: 19 additions & 11 deletions)
@@ -124,19 +124,26 @@
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
+ EPISODES_PATH,
+ INFO_PATH,
+ STATS_PATH,
+ TASKS_PATH,
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_hub_safe_version,
unflatten_dict,
)
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401
from lerobot.common.utils.utils import init_hydra_config

V16 = "v1.6"
V20 = "v2.0"

GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
- VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+ V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
+ V1_INFO_PATH = "meta_data/info.json"
+ V1_STATS_PATH = "meta_data/stats.safetensors"


def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
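Note: the path constants imported in this hunk (EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH, alongside the DEFAULT_* templates already in use) come from lerobot.common.datasets.utils and replace the hard-coded v1 paths further down. As a rough sketch only (the authoritative values live in that module and may differ), they are assumed to look like:

    # Assumed v2.0 path templates (sketch; see lerobot/common/datasets/utils.py for the real values)
    INFO_PATH = "meta/info.json"
    EPISODES_PATH = "meta/episodes.jsonl"
    STATS_PATH = "meta/stats.json"
    TASKS_PATH = "meta/tasks.jsonl"
    DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
    DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"

Renaming VIDEO_FILE to V1_VIDEO_FILE and adding V1_INFO_PATH / V1_STATS_PATH keeps the v1.x layout (meta_data/info.json, meta_data/stats.safetensors, flat per-episode .mp4 files) clearly separated from these v2.0 templates.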
@@ -180,17 +187,18 @@ def write_json(data: dict, fpath: Path) -> None:


def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)


- def convert_stats_to_json(input_dir: Path, output_dir: Path) -> None:
- safetensor_path = input_dir / "stats.safetensors"
+ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
+ safetensor_path = v1_dir / V1_STATS_PATH
stats = load_file(safetensor_path)
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)

- json_path = output_dir / "stats.json"
+ json_path = v2_dir / STATS_PATH
+ json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4)
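Note: putting the pieces of this hunk together, the updated helper reads roughly as below. This is a sketch assembled from the diff, assuming load_file comes from safetensors.torch, unflatten_dict from lerobot.common.datasets.utils, and the assumed STATS_PATH value listed earlier:

    import json
    from pathlib import Path

    from safetensors.torch import load_file

    from lerobot.common.datasets.utils import unflatten_dict

    V1_STATS_PATH = "meta_data/stats.safetensors"  # v1.x location, defined in this commit
    STATS_PATH = "meta/stats.json"                 # assumed v2.0 location


    def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
        # Load the flat v1 stats tensors (keys like "observation.state/mean").
        stats = load_file(v1_dir / V1_STATS_PATH)
        # Tensors -> lists so the values are JSON-serializable, then re-nest the flat keys.
        serialized_stats = {key: value.tolist() for key, value in stats.items()}
        serialized_stats = unflatten_dict(serialized_stats)
        # Write under the v2.0 meta/ directory, creating it if needed.
        json_path = v2_dir / STATS_PATH
        json_path.parent.mkdir(exist_ok=True, parents=True)
        with open(json_path, "w") as f:
            json.dump(serialized_stats, f, indent=4)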
@@ -279,7 +287,7 @@ def split_parquet_by_episodes(
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
- episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
+ episode_chunk=ep_chunk, episode_index=ep_idx
)
pq.write_table(ep_table, output_file)
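Note: the DEFAULT_PARQUET_PATH template evidently no longer takes a total_episodes field, so each episode's parquet destination depends only on its chunk and index. A minimal sketch of the resulting path, assuming the template and a DEFAULT_CHUNK_SIZE of 1000 (both assumptions, taken from the imports above):

    # Sketch: destination of episode 1234's parquet file under the assumed template.
    DEFAULT_CHUNK_SIZE = 1000
    DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

    ep_idx = 1234
    ep_chunk = ep_idx // DEFAULT_CHUNK_SIZE  # episodes are grouped into fixed-size chunks
    print(DEFAULT_PARQUET_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx))
    # -> data/chunk-001/episode_001234.parquet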

@@ -336,7 +344,7 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
- video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
+ video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
@@ -572,7 +580,7 @@ def convert_dataset(
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

metadata_v1 = load_json(v1x_dir / "meta_data" / "info.json")
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
keys = get_keys(dataset)

@@ -611,7 +619,7 @@

assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / "meta" / "tasks.json")
write_jsonlines(tasks, v20_dir / TASKS_PATH)

# Shapes
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
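Note: tasks are now written through the TASKS_PATH constant rather than the hard-coded "meta/tasks.json" (a .json name for what is actually jsonlines content). A minimal sketch of what write_jsonlines produces, with hypothetical task strings and the assumed TASKS_PATH value:

    from pathlib import Path

    import jsonlines

    fpath = Path("meta/tasks.jsonl")  # assumed TASKS_PATH value
    fpath.parent.mkdir(exist_ok=True, parents=True)
    tasks = [
        {"task_index": 0, "task": "Pick up the cube."},     # hypothetical task text
        {"task_index": 1, "task": "Place it in the bin."},  # hypothetical task text
    ]
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(tasks)  # one JSON object per line, as in write_jsonlines above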
@@ -667,7 +675,7 @@ def convert_dataset(
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / "meta" / "episodes.jsonl")
write_jsonlines(episodes, v20_dir / EPISODES_PATH)

# Assemble metadata v2.0
metadata_v2_0 = {
@@ -689,8 +697,8 @@
"names": names,
"videos": videos_info,
}
write_json(metadata_v2_0, v20_dir / "meta" / "info.json")
convert_stats_to_json(v1x_dir / "meta_data", v20_dir / "meta")
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)

with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
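Note: with this last hunk, every v2.0 metadata artifact is addressed through the shared path constants rather than hard-coded "meta/..." strings. A quick sanity-check sketch of what the conversion is now expected to leave under the output directory (directory name hypothetical, relative paths are the assumed constant values from earlier):

    from pathlib import Path

    v20_dir = Path("/tmp/pusht_v2")  # hypothetical conversion output directory
    expected = {
        "info": "meta/info.json",           # write_json(metadata_v2_0, v20_dir / INFO_PATH)
        "stats": "meta/stats.json",         # convert_stats_to_json(v1x_dir, v20_dir)
        "tasks": "meta/tasks.jsonl",        # write_jsonlines(tasks, v20_dir / TASKS_PATH)
        "episodes": "meta/episodes.jsonl",  # write_jsonlines(episodes, v20_dir / EPISODES_PATH)
    }
    for name, rel in expected.items():
        path = v20_dir / rel
        print(f"{name:9s} {path} exists={path.exists()}")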
