Skip to content

Commit

Permalink
Allow dataset creation without robot
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Oct 23, 2024
1 parent 0d77be9 commit 60865e8
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 23 deletions.
2 changes: 1 addition & 1 deletion lerobot/common/datasets/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ImageWriter:
"""

def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.dir = write_dir / "images"
self.dir = write_dir
self.dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes
Expand Down
51 changes: 37 additions & 14 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
INFO_PATH,
STATS_PATH,
TASKS_PATH,
_get_info_from_robot,
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
Expand Down Expand Up @@ -683,7 +684,7 @@ def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> N
)

self.image_writer = ImageWriter(
write_dir=self.root,
write_dir=self.root / "images",
num_processes=num_processes,
num_threads=num_threads,
)
Expand Down Expand Up @@ -734,6 +735,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F
)

# TODO(aliberts)
# - [ ] add video info in info.json
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lenghts
Expand All @@ -744,8 +746,14 @@ def create(
cls,
repo_id: str,
fps: int,
robot: Robot,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
image_writer_threads_per_camera: int = 0,
Expand All @@ -757,26 +765,41 @@ def create(
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.tolerance_s = tolerance_s
obj.image_writer = None

if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
raise ValueError()

if len(video_keys) > 0 and not use_videos:
raise ValueError

obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
write_json(obj.info, obj.root / INFO_PATH)

# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()

obj.image_writer = None
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)

# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
Expand Down
36 changes: 28 additions & 8 deletions lerobot/common/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def load_episode_dicts(local_dir: Path) -> dict:
return list(reader)


def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
shapes = {key: len(names) for key, names in robot.names.items()}
camera_shapes = {}
for key, cam in robot.cameras.items():
Expand All @@ -203,10 +203,30 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use
"height": cam.height,
"channels": cam.channels,
}
keys = list(robot.names)
image_keys = [] if use_videos else list(camera_shapes)
video_keys = list(camera_shapes) if use_videos else []
shapes = {**shapes, **camera_shapes}
names = robot.names
robot_type = robot.robot_type

return robot_type, keys, image_keys, video_keys, shapes, names


def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
keys: list[str],
image_keys: list[str],
video_keys: list[str],
shapes: dict,
names: dict,
) -> dict:
return {
"codebase_version": codebase_version,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot.robot_type,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
Expand All @@ -215,12 +235,12 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"keys": list(robot.names),
"video_keys": list(camera_shapes) if use_videos else [],
"image_keys": [] if use_videos else list(camera_shapes),
"shapes": {**shapes, **camera_shapes},
"names": robot.names,
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None,
"keys": keys,
"video_keys": video_keys,
"image_keys": image_keys,
"shapes": shapes,
"names": names,
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
}


Expand Down

0 comments on commit 60865e8

Please sign in to comment.