Add video_info, fix image_writer
aliberts committed Oct 25, 2024
1 parent 18ffa42 commit e210d79
Showing 6 changed files with 241 additions and 180 deletions.
13 changes: 4 additions & 9 deletions lerobot/common/datasets/compute_stats.py
@@ -19,9 +19,6 @@
import einops
import torch
import tqdm
from datasets import Image

from lerobot.common.datasets.video_utils import VideoFrame


def get_stats_einops_patterns(dataset, num_workers=0):
@@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
batch = next(iter(dataloader))

stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue

for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64

if isinstance(feats_type, (VideoFrame, Image)):
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
raise ValueError(f"{key}, {batch[key].shape}")

return stats_patterns

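For context, here is a minimal standalone sketch (not lerobot code) of how einops reduction patterns of this kind are typically consumed to aggregate batch statistics. The feature keys and the per-channel image pattern "b c h w -> c 1 1" are illustrative assumptions; only the "b -> 1" pattern appears in this diff.

import einops
import torch

# Illustrative batch: a channel-first image tensor and a per-frame scalar.
batch = {
    "observation.images.cam0": torch.rand(32, 3, 96, 96),  # b c h w
    "next.reward": torch.rand(32),                          # b
}

# Patterns of the kind returned by get_stats_einops_patterns (image pattern assumed).
stats_patterns = {
    "observation.images.cam0": "b c h w -> c 1 1",
    "next.reward": "b -> 1",
}

# Reduce every feature over the batch dimension with the same einops call.
stats = {
    key: {
        "mean": einops.reduce(batch[key], pattern, "mean"),
        "min": einops.reduce(batch[key], pattern, "min"),
        "max": einops.reduce(batch[key], pattern, "max"),
    }
    for key, pattern in stats_patterns.items()
}
# stats["observation.images.cam0"]["mean"] has shape (3, 1, 1): one value per channel.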
69 changes: 50 additions & 19 deletions lerobot/common/datasets/image_writer.py
@@ -53,45 +53,54 @@ class ImageWriter:
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""

def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
self.dir = write_dir
self.dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes
self.num_threads = self.num_threads_per_process = num_threads
self.num_threads = num_threads
self.timeout = timeout

if self.num_processes <= 0:
if self.num_processes == 0 and self.num_threads == 0:
self.type = "synchronous"
elif self.num_processes == 0 and self.num_threads > 0:
self.type = "threads"
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
self.futures = []
else:
self.type = "processes"
self.num_threads_per_process = self.num_threads
self.main_event = multiprocessing.Event()
self.image_queue = multiprocessing.Queue()
self.processes: list[multiprocessing.Process] = []
for _ in range(num_processes):
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
self.events: list[multiprocessing.Event] = []
for _ in range(self.num_processes):
event = multiprocessing.Event()
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
process.start()
self.processes.append(process)
self.events.append(event)

def _loop_to_save_images_in_threads(self) -> None:
def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = []
while True:
frame_data = self.image_queue.get()
if frame_data is None:
break
self._wait_threads(self.futures, 10)
return

image, file_path = frame_data
futures.append(executor.submit(self._save_image, image, file_path))

with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
wait(futures)
progress_bar.update(len(futures))
if self.main_event.is_set():
self._wait_threads(self.futures, 10)
event.set()

def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
"""Save an image asynchronously using threads or processes."""
if self.type == "threads":
if self.type == "synchronous":
self._save_image(image, file_path)
elif self.type == "threads":
self.futures.append(self.threads.submit(self._save_image, image, file_path))
else:
self.image_queue.put((image, file_path))
@@ -111,12 +120,33 @@ def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
episode_index=episode_index, image_key=image_key, frame_index=0
).parent

def stop(self, timeout=20) -> None:
def wait(self) -> None:
"""Wait for the thread/processes to finish writing."""
if self.type == "synchronous":
return
elif self.type == "threads":
self._wait_threads(self.futures)
else:
self._wait_processes()

def _wait_threads(self, futures) -> None:
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
wait(futures, timeout=self.timeout)
progress_bar.update(len(futures))

def _wait_processes(self) -> None:
self.main_event.set()
for event in self.events:
event.wait()

self.main_event.clear()

def shutdown(self, timeout=20) -> None:
"""Stop the image writer, waiting for all processes or threads to finish."""
if self.type == "threads":
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
wait(self.futures, timeout=timeout)
progress_bar.update(len(self.futures))
if self.type == "synchronous":
return
elif self.type == "threads":
self.threads.shutdown(wait=True)
else:
self._stop_processes(timeout)

@@ -127,8 +157,9 @@ def _stop_processes(self, timeout) -> None:
for process in self.processes:
process.join(timeout=timeout)

if process.is_alive():
process.terminate()
for process in self.processes:
if process.is_alive():
process.terminate()

self.image_queue.close()
self.image_queue.join_thread()
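Below is a minimal usage sketch of the reworked ImageWriter API, using only the constructor and methods visible in this diff (async_save_image, wait, shutdown). The tensor shape and target file layout are illustrative assumptions rather than the actual DEFAULT_IMAGE_PATH, and the directory is created explicitly in case the writer does not do so itself.

from pathlib import Path

import torch

from lerobot.common.datasets.image_writer import ImageWriter

# num_processes=0 with num_threads > 0 selects the "threads" mode.
writer = ImageWriter(write_dir=Path("data/images"), num_processes=0, num_threads=4, timeout=10)

target_dir = Path("data/images/cam0/episode_000000")
target_dir.mkdir(parents=True, exist_ok=True)

for i in range(10):
    frame = torch.rand(3, 480, 640)  # assumed channel-first float image
    writer.async_save_image(frame, target_dir / f"frame_{i:06d}.png")

writer.wait()      # block until every queued frame has been written (new in this commit)
writer.shutdown()  # release the thread pool (or worker processes) once recording is done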
64 changes: 48 additions & 16 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -22,10 +22,10 @@
from typing import Callable

import datasets
import pyarrow.parquet as pq
import torch
import torch.utils
from datasets import load_dataset
from datasets.table import embed_table_storage
from huggingface_hub import snapshot_download, upload_folder

from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@@ -57,6 +57,7 @@
VideoFrame,
decode_video_frames_torchvision,
encode_video_frames,
get_video_info,
)
from lerobot.common.robot_devices.robots.utils import Robot

@@ -391,7 +392,11 @@ def shapes(self) -> dict:
return self.info["shapes"]

@property
def features(self) -> datasets.Features:
def features(self) -> list[str]:
return list(self._features) + self.video_keys

@property
def _features(self) -> datasets.Features:
"""Features of the hf_dataset."""
if self.hf_dataset is not None:
return self.hf_dataset.features
@@ -583,6 +588,7 @@ def add_frame(self, frame: dict) -> None:
image=frame[cam_key],
file_path=img_path,
)

if cam_key in self.image_keys:
self.episode_buffer[cam_key].append(str(img_path))

@@ -592,7 +598,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
"""
@@ -608,7 +614,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
for key in self.episode_buffer:
if key in self.image_keys:
continue
if key in self.keys:
elif key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
@@ -619,6 +625,8 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:

self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)

self._wait_image_writer()
self._save_episode_table(episode_index)

if encode_videos and len(self.video_keys) > 0:
@@ -629,11 +637,17 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None:
self.consolidated = False

def _save_episode_table(self, episode_index: int) -> None:
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
ep_table = ep_dataset._data.table
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path)

# Embed image bytes into the table before saving to parquet
format = ep_dataset.format
ep_dataset = ep_dataset.with_format("arrow")
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
ep_dataset = ep_dataset.with_format(**format)

ep_dataset.to_parquet(ep_data_path)

def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
@@ -677,7 +691,7 @@ def clear_episode_buffer(self) -> None:
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()

def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
if isinstance(self.image_writer, ImageWriter):
logging.warning(
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
@@ -689,18 +703,23 @@ def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
num_threads=num_threads,
)

def stop_image_writter(self) -> None:
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()
self.image_writer.shutdown()
self.image_writer = None

def _wait_image_writer(self) -> None:
"""Wait for asynchronous image writer to finish."""
if self.image_writer is not None:
self.image_writer.wait()

def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.num_episodes):
for episode_index in range(self.total_episodes):
for key in self.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
Expand All @@ -713,19 +732,32 @@ def encode_videos(self) -> None:
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)

def _write_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
for key in self.video_keys:
if key not in self.info["videos"]:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["videos"][key] = get_video_info(video_path)

write_json(self.info, self.root / INFO_PATH)

def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

if len(self.video_keys) > 0:
self.encode_videos()
self._write_video_info()

if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.dir)

if run_compute_stats:
self.stop_image_writter()
self.stop_image_writer()
self.stats = compute_stats(self)
write_stats(self.stats, self.root / STATS_PATH)
self.consolidated = True
@@ -735,7 +767,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
)

# TODO(aliberts)
# - [ ] add video info in info.json
# - [X] add video info in info.json
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lengths
@@ -775,7 +807,7 @@ def create(
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
)
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
obj.start_image_writer(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
elif (
@@ -791,7 +823,7 @@
)

if len(video_keys) > 0 and not use_videos:
raise ValueError
raise ValueError()

obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(
@@ -918,7 +950,7 @@ def video(self) -> bool:
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
return features

@property
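As a standalone illustration of the parquet change in _save_episode_table above, here is the same embed-then-write sequence applied to a toy Hugging Face dataset, so that image bytes are stored inside the parquet file instead of as external references. This is a sketch, not lerobot code; the toy features and the output path are assumptions.

from pathlib import Path

import datasets
from datasets.table import embed_table_storage
from PIL import Image

# Toy dataset with an Image feature standing in for the episode buffer.
features = datasets.Features({"frame_index": datasets.Value("int64"), "image": datasets.Image()})
ep_dataset = datasets.Dataset.from_dict(
    {"frame_index": [0, 1], "image": [Image.new("RGB", (8, 8)), Image.new("RGB", (8, 8))]},
    features=features,
)

# Same sequence as _save_episode_table: switch to arrow format, embed the image bytes
# into the table storage, restore the original format, then write a self-contained parquet file.
fmt = ep_dataset.format
ep_dataset = ep_dataset.with_format("arrow")
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
ep_dataset = ep_dataset.with_format(**fmt)
ep_dataset.to_parquet(Path("episode_000000.parquet"))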