From 18ffa4248b03287009a00e3f21fbc9754a753929 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 24 Oct 2024 11:49:53 +0200 Subject: [PATCH] Add json/jsonl io functions --- lerobot/common/datasets/lerobot_dataset.py | 6 ++-- lerobot/common/datasets/utils.py | 33 ++++++++++++------- .../datasets/v2/convert_dataset_v1_to_v2.py | 21 ++---------- 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 513a931bd..d4e6d2263 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -36,7 +36,7 @@ STATS_PATH, TASKS_PATH, _get_info_from_robot, - append_jsonl, + append_jsonlines, check_delta_timestamps, check_timestamps_sync, check_version_compatibility, @@ -648,7 +648,7 @@ def _save_episode_to_metadata( "task_index": task_index, "task": task, } - append_jsonl(task_dict, self.root / TASKS_PATH) + append_jsonlines(task_dict, self.root / TASKS_PATH) chunk = self.get_episode_chunk(episode_index) if chunk >= self.total_chunks: @@ -664,7 +664,7 @@ def _save_episode_to_metadata( "length": episode_length, } self.episode_dicts.append(episode_dict) - append_jsonl(episode_dict, self.root / EPISODES_PATH) + append_jsonlines(episode_dict, self.root / EPISODES_PATH) def clear_episode_buffer(self) -> None: episode_index = self.episode_buffer["episode_index"] diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index f2ce9b55a..008d7843e 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -18,7 +18,7 @@ from itertools import accumulate from pathlib import Path from pprint import pformat -from typing import Dict +from typing import Any, Dict import datasets import jsonlines @@ -80,13 +80,29 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: return outdict +def load_json(fpath: Path) -> Any: + with open(fpath) as f: + return json.load(f) + + def write_json(data: dict, fpath: Path) -> None: fpath.parent.mkdir(exist_ok=True, parents=True) with open(fpath, "w") as f: json.dump(data, f, indent=4, ensure_ascii=False) -def append_jsonl(data: dict, fpath: Path) -> None: +def load_jsonlines(fpath: Path) -> list[Any]: + with jsonlines.open(fpath, "r") as reader: + return list(reader) + + +def write_jsonlines(data: dict, fpath: Path) -> None: + fpath.parent.mkdir(exist_ok=True, parents=True) + with jsonlines.open(fpath, "w") as writer: + writer.write_all(data) + + +def append_jsonlines(data: dict, fpath: Path) -> None: fpath.parent.mkdir(exist_ok=True, parents=True) with jsonlines.open(fpath, "a") as writer: writer.write(data) @@ -170,27 +186,22 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> def load_info(local_dir: Path) -> dict: - with open(local_dir / INFO_PATH) as f: - return json.load(f) + return load_json(local_dir / INFO_PATH) def load_stats(local_dir: Path) -> dict: - with open(local_dir / STATS_PATH) as f: - stats = json.load(f) + stats = load_json(local_dir / STATS_PATH) stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) def load_tasks(local_dir: Path) -> dict: - with jsonlines.open(local_dir / TASKS_PATH, "r") as reader: - tasks = list(reader) - + tasks = load_jsonlines(local_dir / TASKS_PATH) return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} def load_episode_dicts(local_dir: Path) -> dict: - with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader: - return list(reader) + return load_jsonlines(local_dir / EPISODES_PATH) def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]: diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 7ab5ae14a..120076b92 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -110,7 +110,6 @@ from pathlib import Path import datasets -import jsonlines import pyarrow.compute as pc import pyarrow.parquet as pq import torch @@ -132,7 +131,10 @@ create_lerobot_dataset_card, flatten_dict, get_hub_safe_version, + load_json, unflatten_dict, + write_json, + write_jsonlines, ) from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401 from lerobot.common.utils.utils import init_hydra_config @@ -175,23 +177,6 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N } -def load_json(fpath: Path) -> dict: - with open(fpath) as f: - return json.load(f) - - -def write_json(data: dict, fpath: Path) -> None: - fpath.parent.mkdir(exist_ok=True, parents=True) - with open(fpath, "w") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - -def write_jsonlines(data: dict, fpath: Path) -> None: - fpath.parent.mkdir(exist_ok=True, parents=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(data) - - def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None: safetensor_path = v1_dir / V1_STATS_PATH stats = load_file(safetensor_path)