Skip to content

Commit

Permalink
Add types
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jun 22, 2024
1 parent a4b3e18 commit 7550918
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions benchmark/video/run_video_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
)


def parse_int_or_none(value):
def parse_int_or_none(value) -> int | None:
if value.lower() == "none":
return None
try:
Expand Down Expand Up @@ -115,10 +115,10 @@ def save_decoded_frames(
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")


def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> Path:
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
ep_num_images = dataset.episode_data_index["to"][0].item()
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
return imgs_dir
return

imgs_dir.mkdir(parents=True, exist_ok=True)
hf_dataset = dataset.hf_dataset.with_format(None)
Expand All @@ -136,10 +136,8 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> Path:
if i >= ep_num_images - 1:
break

return imgs_dir


def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int):
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
Expand Down Expand Up @@ -245,7 +243,7 @@ def benchmark_encoding_decoding(
save_frames: bool,
overwrite: bool = False,
seed: int = 1337,
):
) -> list[dict]:
fps = dataset.fps

if overwrite or not video_path.is_file():
Expand Down

0 comments on commit 7550918

Please sign in to comment.