Skip to content

Commit

Permalink
Cleanup video benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jun 19, 2024
1 parent 5975c00 commit 62509c2
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import random
import shutil
import subprocess
import time
from pathlib import Path

import einops
Expand All @@ -47,6 +46,7 @@
from lerobot.common.datasets.video_utils import (
decode_video_frames_torchvision,
)
from lerobot.common.utils.benchmark import TimeBenchmark

OUTPUT_DIR = Path("tmp/run_video_benchmark")
DRY_RUN = False
Expand Down Expand Up @@ -148,19 +148,6 @@ def run_video_benchmark(

video_size_bytes = video_path.stat().st_size

# Set decoder

decoder = cfg["decoder"]
decoder_kwgs = cfg["decoder_kwgs"]
backend = cfg["backend"]

if decoder == "torchvision":
decode_frames_fn = decode_video_frames_torchvision
else:
raise ValueError(decoder)

# Estimate average loading time

def load_original_frames(imgs_dir, timestamps) -> torch.Tensor:
frames = []
for ts in timestamps:
Expand All @@ -179,10 +166,10 @@ def load_original_frames(imgs_dir, timestamps) -> torch.Tensor:
ssim_values = []
mse_values = []

benchmark = TimeBenchmark()
random.seed(seed)

for t in range(50):
# test loading 2 frames that are 4 frames appart, which might be a common setting
ts = random.randint(fps, ep_num_images - fps) / fps

if timestamps_mode == "1_frame":
Expand All @@ -198,21 +185,18 @@ def load_original_frames(imgs_dir, timestamps) -> torch.Tensor:

num_frames = len(timestamps)

start_time_s = time.monotonic()
frames = decode_frames_fn(
video_path, timestamps=timestamps, tolerance_s=1e-4, backend=backend, **decoder_kwgs
)
avg_load_time = (time.monotonic() - start_time_s) / num_frames
list_avg_load_time.append(avg_load_time)
with benchmark:
frames = decode_video_frames_torchvision(
video_path, timestamps=timestamps, tolerance_s=1e-4, backend=cfg["backend"]
)
list_avg_load_time.append(benchmark.result / num_frames)

start_time_s = time.monotonic()
original_frames = load_original_frames(imgs_dir, timestamps)
avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
list_avg_load_time_from_images.append(avg_load_time_from_images)
with benchmark:
original_frames = load_original_frames(imgs_dir, timestamps)
list_avg_load_time_from_images.append(benchmark.result / num_frames)

# Estimate reconstruction error between original frames and decoded frames with various metrics
for i, ts in enumerate(timestamps):
# are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
num_pixels = original_frames[i].numel()
per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
per_pixel_l2_errors.append(per_pixel_l2_error)
Expand Down Expand Up @@ -306,6 +290,7 @@ def one_variable_study(
var_name,
"compression_factor",
"load_time_factor",
"avg_load_time_ms",
"avg_per_pixel_l2_error",
"avg_psnr",
"avg_ssim",
Expand All @@ -320,8 +305,6 @@ def one_variable_study(
"pix_fmt": "yuv444p",
# video decoding
"backend": "pyav",
"decoder": "torchvision",
"decoder_kwgs": {},
}
for repo_id in repo_ids:
for val in var_values:
Expand All @@ -341,6 +324,7 @@ def one_variable_study(
val,
info["compression_factor"],
info["load_time_factor"],
info["avg_load_time"] * 1e3,
info["avg_per_pixel_l2_error"],
info["avg_psnr"],
info["avg_ssim"],
Expand All @@ -358,6 +342,7 @@ def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: b
"image_size",
"compression_factor",
"load_time_factor",
"avg_load_time_ms",
"avg_per_pixel_l2_error",
"avg_psnr",
"avg_ssim",
Expand All @@ -373,8 +358,6 @@ def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: b
"pix_fmt": "yuv444p",
# video decoding
"backend": "video_reader",
"decoder": "torchvision",
"decoder_kwgs": {},
}
if not dry_run:
run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
Expand All @@ -386,7 +369,11 @@ def best_study(repo_ids: list, bench_dir: Path, timestamps_mode: str, dry_run: b
f"{width} x {height}",
info["compression_factor"],
info["load_time_factor"],
info["avg_load_time"] * 1e3,
info["avg_per_pixel_l2_error"],
info["avg_psnr"],
info["avg_ssim"],
info["avg_mse"],
]
)
display_markdown_table(headers, rows)
Expand Down

0 comments on commit 62509c2

Please sign in to comment.