Add json/jsonl io functions
aliberts committed Oct 24, 2024
1 parent 8bcf81f commit 18ffa42
Showing 3 changed files with 28 additions and 32 deletions.
lerobot/common/datasets/lerobot_dataset.py (6 changes: 3 additions & 3 deletions)
@@ -36,7 +36,7 @@
     STATS_PATH,
     TASKS_PATH,
     _get_info_from_robot,
-    append_jsonl,
+    append_jsonlines,
     check_delta_timestamps,
     check_timestamps_sync,
     check_version_compatibility,
@@ -648,7 +648,7 @@ def _save_episode_to_metadata(
                 "task_index": task_index,
                 "task": task,
             }
-            append_jsonl(task_dict, self.root / TASKS_PATH)
+            append_jsonlines(task_dict, self.root / TASKS_PATH)

         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
@@ -664,7 +664,7 @@
             "length": episode_length,
         }
         self.episode_dicts.append(episode_dict)
-        append_jsonl(episode_dict, self.root / EPISODES_PATH)
+        append_jsonlines(episode_dict, self.root / EPISODES_PATH)

     def clear_episode_buffer(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
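For context, the renamed append_jsonlines helper writes exactly one JSON object per line, so every saved task or episode adds a single record to its metadata file. A minimal sketch of that append pattern, using the jsonlines package directly; the path and record contents are made up for illustration and only stand in for self.root / TASKS_PATH, whose actual value is not shown in this diff:

from pathlib import Path

import jsonlines

tasks_file = Path("/tmp/demo_dataset/meta/tasks.jsonl")  # illustrative stand-in for self.root / TASKS_PATH
tasks_file.parent.mkdir(parents=True, exist_ok=True)

# Each append adds one line, mirroring append_jsonlines(task_dict, ...) above.
for task_dict in [
    {"task_index": 0, "task": "pick up the cube"},  # made-up records
    {"task_index": 1, "task": "stack the blocks"},
]:
    with jsonlines.open(tasks_file, "a") as writer:
        writer.write(task_dict)

print(tasks_file.read_text())  # two lines, one JSON object each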
lerobot/common/datasets/utils.py (33 changes: 22 additions & 11 deletions)
@@ -18,7 +18,7 @@
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
-from typing import Dict
+from typing import Any, Dict

 import datasets
 import jsonlines
@@ -80,13 +80,29 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
     return outdict


+def load_json(fpath: Path) -> Any:
+    with open(fpath) as f:
+        return json.load(f)
+
+
 def write_json(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with open(fpath, "w") as f:
         json.dump(data, f, indent=4, ensure_ascii=False)


-def append_jsonl(data: dict, fpath: Path) -> None:
+def load_jsonlines(fpath: Path) -> list[Any]:
+    with jsonlines.open(fpath, "r") as reader:
+        return list(reader)
+
+
+def write_jsonlines(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
+    with jsonlines.open(fpath, "w") as writer:
+        writer.write_all(data)
+
+
+def append_jsonlines(data: dict, fpath: Path) -> None:
     fpath.parent.mkdir(exist_ok=True, parents=True)
     with jsonlines.open(fpath, "a") as writer:
         writer.write(data)
@@ -170,27 +186,22 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->


 def load_info(local_dir: Path) -> dict:
-    with open(local_dir / INFO_PATH) as f:
-        return json.load(f)
+    return load_json(local_dir / INFO_PATH)


 def load_stats(local_dir: Path) -> dict:
-    with open(local_dir / STATS_PATH) as f:
-        stats = json.load(f)
+    stats = load_json(local_dir / STATS_PATH)
     stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)


 def load_tasks(local_dir: Path) -> dict:
-    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
-        tasks = list(reader)
-
+    tasks = load_jsonlines(local_dir / TASKS_PATH)
     return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}


 def load_episode_dicts(local_dir: Path) -> dict:
-    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
-        return list(reader)
+    return load_jsonlines(local_dir / EPISODES_PATH)


 def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
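A short usage sketch of the helpers added above; the imports match the module path shown in this diff, but the paths and record contents are made up. One detail worth noting: write_jsonlines passes its argument straight to writer.write_all, so it expects an iterable of records despite the dict annotation, while append_jsonlines writes a single record.

from pathlib import Path

from lerobot.common.datasets.utils import (
    append_jsonlines,
    load_json,
    load_jsonlines,
    write_json,
    write_jsonlines,
)

meta_dir = Path("/tmp/demo_dataset/meta")  # illustrative directory

# JSON Lines round trip: write a list of records, append one more, read them all back.
write_jsonlines([{"task_index": 0, "task": "pick up the cube"}], meta_dir / "tasks.jsonl")
append_jsonlines({"task_index": 1, "task": "stack the blocks"}, meta_dir / "tasks.jsonl")
assert load_jsonlines(meta_dir / "tasks.jsonl") == [
    {"task_index": 0, "task": "pick up the cube"},
    {"task_index": 1, "task": "stack the blocks"},
]

# Plain JSON round trip (keys are made up for the example).
write_json({"fps": 30, "total_episodes": 2}, meta_dir / "info.json")
assert load_json(meta_dir / "info.json") == {"fps": 30, "total_episodes": 2}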
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py (21 changes: 3 additions & 18 deletions)
@@ -110,7 +110,6 @@
 from pathlib import Path

 import datasets
-import jsonlines
 import pyarrow.compute as pc
 import pyarrow.parquet as pq
 import torch
@@ -132,7 +131,10 @@
     create_lerobot_dataset_card,
     flatten_dict,
     get_hub_safe_version,
+    load_json,
     unflatten_dict,
+    write_json,
+    write_jsonlines,
 )
 from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
 from lerobot.common.utils.utils import init_hydra_config
@@ -175,23 +177,6 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
     }


-def load_json(fpath: Path) -> dict:
-    with open(fpath) as f:
-        return json.load(f)
-
-
-def write_json(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with open(fpath, "w") as f:
-        json.dump(data, f, indent=4, ensure_ascii=False)
-
-
-def write_jsonlines(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with jsonlines.open(fpath, "w") as writer:
-        writer.write_all(data)
-
-
 def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
     safetensor_path = v1_dir / V1_STATS_PATH
     stats = load_file(safetensor_path)
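With the helpers moved into lerobot/common/datasets/utils.py, the converter now imports load_json, write_json, and write_jsonlines instead of defining them locally. The body of convert_stats_to_json is not part of this diff; below is only a rough, hypothetical sketch of that serialization direction, assuming the v1 safetensors file holds a flat mapping of "/"-separated keys to tensors, which mirrors how load_stats above rebuilds tensors from the flattened JSON.

from pathlib import Path

from safetensors.torch import load_file

from lerobot.common.datasets.utils import unflatten_dict, write_json


def convert_stats_to_json_sketch(safetensor_path: Path, json_path: Path) -> None:
    # Hypothetical: tensors -> plain lists -> nested dict -> JSON on disk.
    flat_stats = load_file(safetensor_path)  # assumed layout: {"feature/mean": tensor, ...}
    serializable = {key: tensor.tolist() for key, tensor in flat_stats.items()}
    write_json(unflatten_dict(serializable), json_path)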
