From 0d77be90ee0871b16fbfbbe10f9024aae4ba83a8 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Wed, 23 Oct 2024 23:12:44 +0200 Subject: [PATCH] Move ImageWriter creation inside the dataset --- lerobot/common/datasets/lerobot_dataset.py | 52 +++++++++++++++------- lerobot/scripts/control_robot.py | 18 +++----- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index b32e1008c..6a1d3719c 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -177,11 +177,13 @@ def __init__( self.episodes = episodes self.tolerance_s = tolerance_s self.video_backend = video_backend if video_backend is not None else "pyav" - self.image_writer = image_writer self.delta_indices = None + self.local_files_only = local_files_only self.consolidated = True + + # Unused attributes + self.image_writer = None self.episode_buffer = {} - self.local_files_only = local_files_only # Load metadata self.root.mkdir(exist_ok=True, parents=True) @@ -626,8 +628,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None: self.consolidated = False def _save_episode_table(self, episode_index: int) -> None: - features = self.features - ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train") + ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train") ep_table = ep_dataset._data.table ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) @@ -675,10 +676,25 @@ def clear_episode_buffer(self) -> None: # Reset the buffer self.episode_buffer = self._create_episode_buffer() - def read_mode(self) -> None: - """Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first.""" - # TODO(aliberts, rcadene): find better api/interface for this. 
+ def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None: + if isinstance(self.image_writer, ImageWriter): + logging.warning( + "You are starting a new ImageWriter that is replacing an already existing one in the dataset." + ) + + self.image_writer = ImageWriter( + write_dir=self.root, + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writter(self) -> None: + """ + Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to + remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. + """ if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None def encode_videos(self) -> None: @@ -708,20 +724,20 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F shutil.rmtree(self.image_writer.dir) if run_compute_stats: - self.read_mode() + self.stop_image_writter() self.stats = compute_stats(self) write_stats(self.stats, self.root / STATS_PATH) self.consolidated = True else: - logging.warning("Skipping computation of the dataset statistics.") + logging.warning( + "Skipping computation of the dataset statistics, dataset is not fully consolidated." + ) # TODO(aliberts) # Sanity checks: # - [ ] shapes # - [ ] ep_lengths # - [ ] number of files - # - [ ] names of files (e.g. 
parquet 00000-of-00001 and 00001-of-00002) + - [ ] no remaining self.image_writer.dir @classmethod def create( @@ -731,7 +747,8 @@ def create( robot: Robot, root: Path | None = None, tolerance_s: float = 1e-4, - image_writer: ImageWriter | None = None, + image_writer_processes: int = 0, + image_writer_threads_per_camera: int = 0, use_videos: bool = True, video_backend: str | None = None, ) -> "LeRobotDataset": @@ -740,7 +757,6 @@ def create( obj.repo_id = repo_id obj.root = root if root is not None else LEROBOT_HOME / repo_id obj.tolerance_s = tolerance_s - obj.image_writer = image_writer if not all(cam.fps == fps for cam in robot.cameras.values()): logging.warning( @@ -755,20 +771,24 @@ def create( # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj._create_episode_buffer() + obj.image_writer = None + if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): + obj.start_image_writter( + image_writer_processes, image_writer_threads_per_camera * robot.num_cameras + ) + + # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It # is used to know when certain operations are needed (for instance, computing dataset statistics). In # order to be able to push the dataset to the hub, it needs to be consolidated first by calling # self.consolidate(). 
obj.consolidated = True - obj.local_files_only = True - obj.download_videos = False - obj.episodes = None obj.hf_dataset = None obj.image_transforms = None obj.delta_timestamps = None obj.delta_indices = None + obj.local_files_only = True obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else "pyav" return obj diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 1185db20e..029751481 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -105,7 +105,6 @@ from typing import List # from safetensors.torch import load_file, save_file -from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.robot_devices.control_utils import ( control_loop, @@ -232,17 +231,14 @@ def record( # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) - if len(robot.cameras) > 0: - image_writer = ImageWriter( - write_dir=root, - num_processes=num_image_writer_processes, - num_threads=num_image_writer_threads_per_camera * robot.num_cameras, - ) - else: - image_writer = None - dataset = LeRobotDataset.create( - repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video + repo_id, + fps, + robot, + root=root, + image_writer_processes=num_image_writer_processes, + image_writer_threads_per_camera=num_image_writer_threads_per_camera, + use_videos=video, ) if not robot.is_connected: