diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 0900d9105..6801bc5dd 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -54,7 +54,7 @@ class ImageWriter: """ def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): - self.dir = write_dir / "images" + self.dir = write_dir self.dir.mkdir(parents=True, exist_ok=True) self.image_path = DEFAULT_IMAGE_PATH self.num_processes = num_processes diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 6a1d3719c..e95f53c97 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -35,6 +35,7 @@ INFO_PATH, STATS_PATH, TASKS_PATH, + _get_info_from_robot, append_jsonl, check_delta_timestamps, check_timestamps_sync, @@ -683,7 +684,7 @@ def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> N ) self.image_writer = ImageWriter( - write_dir=self.root, + write_dir=self.root / "images", num_processes=num_processes, num_threads=num_threads, ) @@ -734,6 +735,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F ) # TODO(aliberts) + # - [ ] add video info in info.json # Sanity checks: # - [ ] shapes # - [ ] ep_lenghts @@ -744,8 +746,14 @@ def create( cls, repo_id: str, fps: int, - robot: Robot, root: Path | None = None, + robot: Robot | None = None, + robot_type: str | None = None, + keys: list[str] | None = None, + image_keys: list[str] | None = None, + video_keys: list[str] | None = None, + shapes: dict | None = None, + names: dict | None = None, tolerance_s: float = 1e-4, image_writer_processes: int = 0, image_writer_threads_per_camera: int = 0, @@ -757,26 +765,41 @@ def create( obj.repo_id = repo_id obj.root = root if root is not None else LEROBOT_HOME / repo_id obj.tolerance_s = tolerance_s + obj.image_writer = None - if not all(cam.fps == fps for cam in 
robot.cameras.values()): - logging.warning( - f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." - "In this case, frames from lower fps cameras will be repeated to fill in the blanks" - ) + if robot is not None: + robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) + if not all(cam.fps == fps for cam in robot.cameras.values()): + logging.warning( + f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset. " + "In this case, frames from lower fps cameras will be repeated to fill in the blanks" + ) + if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): + obj.start_image_writter( + image_writer_processes, image_writer_threads_per_camera * robot.num_cameras + ) + elif ( + robot_type is None + or keys is None + or image_keys is None + or video_keys is None + or shapes is None + or names is None + ): + raise ValueError( + "Dataset info (robot_type, keys, image_keys, video_keys, shapes, names) must be provided if a robot object is not provided." + ) + + if len(video_keys) > 0 and not use_videos: + raise ValueError("Video keys are provided but use_videos is set to False.") obj.tasks, obj.stats, obj.episode_dicts = {}, {}, [] - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names + ) write_json(obj.info, obj.root / INFO_PATH) # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj._create_episode_buffer() - obj.image_writer = None - if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): - obj.start_image_writter( - image_writer_processes, image_writer_threads_per_camera * robot.num_cameras - ) - # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It # is used to know when certain operations are need (for instance, computing dataset statistics). 
In # order to be able to push the dataset to the hub, it needs to be consolidated first by calling diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index ccb57197d..f2ce9b55a 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -193,7 +193,7 @@ def load_episode_dicts(local_dir: Path) -> dict: return list(reader) -def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict: +def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[str, list[str], list[str], list[str], dict, dict]: shapes = {key: len(names) for key, names in robot.names.items()} camera_shapes = {} for key, cam in robot.cameras.items(): @@ -203,10 +203,30 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use "height": cam.height, "channels": cam.channels, } + keys = list(robot.names) + image_keys = [] if use_videos else list(camera_shapes) + video_keys = list(camera_shapes) if use_videos else [] + shapes = {**shapes, **camera_shapes} + names = robot.names + robot_type = robot.robot_type + + return robot_type, keys, image_keys, video_keys, shapes, names + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + robot_type: str, + keys: list[str], + image_keys: list[str], + video_keys: list[str], + shapes: dict, + names: dict, +) -> dict: return { "codebase_version": codebase_version, "data_path": DEFAULT_PARQUET_PATH, - "robot_type": robot.robot_type, + "robot_type": robot_type, "total_episodes": 0, "total_frames": 0, "total_tasks": 0, @@ -215,12 +235,12 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use "chunks_size": DEFAULT_CHUNK_SIZE, "fps": fps, "splits": {}, - "keys": list(robot.names), - "video_keys": list(camera_shapes) if use_videos else [], - "image_keys": [] if use_videos else list(camera_shapes), - "shapes": {**shapes, **camera_shapes}, - "names": robot.names, - "videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else 
None, + "keys": keys, + "video_keys": video_keys, + "image_keys": image_keys, + "shapes": shapes, + "names": names, + "videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None, }