Commit 0d77be9: Move ImageWriter creation inside the dataset

aliberts committed Oct 23, 2024
1 parent 0098bd2 commit 0d77be9

Showing 2 changed files with 43 additions and 27 deletions.
lerobot/common/datasets/lerobot_dataset.py (36 additions, 16 deletions)
@@ -177,11 +177,13 @@ def __init__(
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend is not None else "pyav"
-        self.image_writer = image_writer
         self.delta_indices = None
-        self.local_files_only = local_files_only
         self.consolidated = True
 
+        # Unused attributes
+        self.image_writer = None
+        self.episode_buffer = {}
+        self.local_files_only = local_files_only
 
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
@@ -626,8 +628,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
         self.consolidated = False
 
     def _save_episode_table(self, episode_index: int) -> None:
-        features = self.features
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
         ep_table = ep_dataset._data.table
         ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -675,10 +676,25 @@ def clear_episode_buffer(self) -> None:
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def read_mode(self) -> None:
-        """Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
-        # TODO(aliberts, rcadene): find better api/interface for this.
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
+        if isinstance(self.image_writer, ImageWriter):
+            logging.warning(
+                "You are starting a new ImageWriter that is replacing an already existing one in the dataset."
+            )
+
+        self.image_writer = ImageWriter(
+            write_dir=self.root,
+            num_processes=num_processes,
+            num_threads=num_threads,
+        )
+
+    def stop_image_writer(self) -> None:
+        """
+        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
+        remove the image_writer so that the LeRobotDataset object is picklable and can be parallelized.
+        """
         if self.image_writer is not None:
             self.image_writer.stop()
             self.image_writer = None
 
     def encode_videos(self) -> None:
@@ -708,20 +724,20 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False):
                 shutil.rmtree(self.image_writer.dir)
 
         if run_compute_stats:
-            self.read_mode()
+            self.stop_image_writer()
             self.stats = compute_stats(self)
             write_stats(self.stats, self.root / STATS_PATH)
             self.consolidated = True
         else:
-            logging.warning("Skipping computation of the dataset statistics.")
+            logging.warning(
+                "Skipping computation of the dataset statistics, dataset is not fully consolidated."
+            )
 
         # TODO(aliberts)
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lengths
         # - [ ] number of files
         # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
         # - [ ] no remaining self.image_writer.dir
 
     @classmethod
     def create(
@@ -731,7 +747,8 @@ def create(
         robot: Robot,
         root: Path | None = None,
         tolerance_s: float = 1e-4,
-        image_writer: ImageWriter | None = None,
+        image_writer_processes: int = 0,
+        image_writer_threads_per_camera: int = 0,
         use_videos: bool = True,
         video_backend: str | None = None,
     ) -> "LeRobotDataset":
@@ -740,7 +757,6 @@ def create(
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         obj.tolerance_s = tolerance_s
-        obj.image_writer = image_writer
 
         if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warning(
@@ -755,20 +771,24 @@
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()
 
+        obj.image_writer = None
+        if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
+            obj.start_image_writer(
+                image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
+            )
+
         # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
         # is used to know when certain operations are needed (for instance, computing dataset statistics). In
         # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
         # self.consolidate().
         obj.consolidated = True
 
+        obj.local_files_only = True
+        obj.download_videos = False
+
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
         obj.delta_indices = None
-        obj.local_files_only = True
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
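Taken together, the lerobot_dataset.py changes move the ImageWriter lifecycle behind the dataset API: create() starts the writer when cameras are present and writer processes or threads are requested, and consolidate() stops it before computing stats. Below is a minimal sketch of the resulting caller-side flow; the repo_id is hypothetical, and the robot construction and recording loop are elided since neither appears in this diff.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

robot = ...  # a connected Robot instance; construction is outside this diff

# create() now builds the ImageWriter internally via start_image_writer()
# when cameras are present and writer processes/threads are requested.
dataset = LeRobotDataset.create(
    "lerobot/example",  # hypothetical repo_id
    fps=30,
    robot=robot,
    image_writer_processes=0,         # threads only, no extra processes
    image_writer_threads_per_camera=4,
    use_videos=True,
)

# ... record frames, then call dataset.add_episode(task) per episode ...

# consolidate() calls stop_image_writer() before compute_stats(), because a
# running ImageWriter would make the dataset unpicklable.
dataset.consolidate(run_compute_stats=True)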
lerobot/scripts/control_robot.py (7 additions, 11 deletions)
@@ -105,7 +105,6 @@
 from typing import List
 
 # from safetensors.torch import load_file, save_file
-from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
@@ -232,17 +231,14 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    if len(robot.cameras) > 0:
-        image_writer = ImageWriter(
-            write_dir=root,
-            num_processes=num_image_writer_processes,
-            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
-        )
-    else:
-        image_writer = None
-
     dataset = LeRobotDataset.create(
-        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
+        repo_id,
+        fps,
+        robot,
+        root=root,
+        image_writer_processes=num_image_writer_processes,
+        image_writer_threads_per_camera=num_image_writer_threads_per_camera,
+        use_videos=video,
     )
 
     if not robot.is_connected:
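The stop_image_writer docstring explains the constraint behind these changes: a live ImageWriter holds process and thread handles, so the dataset must drop it before it can be pickled into parallel DataLoader workers. A sketch of that consumer path follows, assuming the usual torch DataLoader, which is not part of this diff.

from torch.utils.data import DataLoader

# Assumes `dataset` is a recorded LeRobotDataset as created above.
dataset.stop_image_writer()  # drop the unpicklable writer handle first

loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,  # parallel workers pickle the dataset, hence the call above
    shuffle=True,
)

for batch in loader:
    pass  # training step goes here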