diff --git a/benchmark/video/run_video_benchmark.py b/benchmark/video/run_video_benchmark.py index 683b8d95a..d25c2bc02 100644 --- a/benchmark/video/run_video_benchmark.py +++ b/benchmark/video/run_video_benchmark.py @@ -63,7 +63,7 @@ ) -def parse_int_or_none(value): +def parse_int_or_none(value) -> int | None: if value.lower() == "none": return None try: @@ -115,10 +115,10 @@ def save_decoded_frames( shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png") -def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> Path: +def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: ep_num_images = dataset.episode_data_index["to"][0].item() if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images: - return imgs_dir + return imgs_dir.mkdir(parents=True, exist_ok=True) hf_dataset = dataset.hf_dataset.with_format(None) @@ -136,10 +136,8 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> Path: if i >= ep_num_images - 1: break - return imgs_dir - -def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int): +def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]: # Start at 5 to allow for 2_frames_4_space and 6_frames idx = random.randint(5, ep_num_images - 1) match timestamps_mode: @@ -245,7 +243,7 @@ def benchmark_encoding_decoding( save_frames: bool, overwrite: bool = False, seed: int = 1337, -): +) -> list[dict]: fps = dataset.fps if overwrite or not video_path.is_file():