Skip to content

Commit

Permalink
Add cam_png_format
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jun 13, 2024
1 parent 4eb18c0 commit bb95a8e
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
93 changes: 93 additions & 0 deletions lerobot/common/datasets/push_dataset_to_hub/cam_png_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import torch
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame


def check_format(raw_dir: Path) -> bool:
image_paths = list(raw_dir.glob("frame_*.png"))
if len(image_paths) == 0:
raise ValueError


def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
ep_dict = {}
ep_idx = 0

image_paths = sorted(raw_dir.glob("frame_*.png"))
num_frames = len(image_paths)

ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps

ep_dicts = [ep_dict]
data_dict = concatenate_episodes(ep_dicts)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
if video:
features["observation.image"] = VideoFrame()
else:
features["observation.image"] = Image()

features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["index"] = Value(dtype="int64", id=None)

hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset


def from_raw_to_lerobot_format(
raw_dir: Path,
videos_dir: Path,
fps: int | None = None,
video: bool = True,
episodes: list[int] | None = None,
):
if video or episodes is not None:
# TODO(aliberts): support this
raise NotImplementedError

# sanity check
check_format(raw_dir, video, episodes)

if fps is None:
fps = 30

data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"fps": fps,
"video": video,
}
return hf_dataset, episode_data_index, info
7 changes: 2 additions & 5 deletions lerobot/scripts/push_dataset_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.utils import flatten_dict


Expand All @@ -70,6 +69,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
elif raw_format == "cam_png":
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
else:
raise ValueError(
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
Expand Down Expand Up @@ -182,10 +183,6 @@ def push_dataset_to_hub(
meta_data_dir = Path(cache_dir) / "meta_data"
videos_dir = Path(cache_dir) / "videos"

# Download the raw dataset if available
if not raw_dir.exists():
download_raw(raw_dir, dataset_id)

if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError()
Expand Down

0 comments on commit bb95a8e

Please sign in to comment.