diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index bafac2e1e..870414b5e 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -19,9 +19,6 @@ import einops import torch import tqdm -from datasets import Image - -from lerobot.common.datasets.video_utils import VideoFrame def get_stats_einops_patterns(dataset, num_workers=0): @@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0): batch = next(iter(dataloader)) stats_patterns = {} - for key, feats_type in dataset.features.items(): - # NOTE: skip language_instruction embedding in stats computation - if key == "language_instruction": - continue + for key in dataset.features: # sanity check that tensors are not float64 assert batch[key].dtype != torch.float64 - if isinstance(feats_type, (VideoFrame, Image)): + # if isinstance(feats_type, (VideoFrame, Image)): + if key in dataset.camera_keys: # sanity check that images are channel first _, c, h, w = batch[key].shape assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" @@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0): elif batch[key].ndim == 1: stats_patterns[key] = "b -> 1" else: - raise ValueError(f"{key}, {feats_type}, {batch[key].shape}") + raise ValueError(f"{key}, {batch[key].shape}") return stats_patterns diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 6801bc5dd..8f368ef24 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -53,45 +53,54 @@ class ImageWriter: the number of threads. If it is still not stable, try to use 1 subprocess, or more. """ - def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): + def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10): self.dir = write_dir self.dir.mkdir(parents=True, exist_ok=True) self.image_path = DEFAULT_IMAGE_PATH self.num_processes = num_processes - self.num_threads = self.num_threads_per_process = num_threads + self.num_threads = num_threads + self.timeout = timeout - if self.num_processes <= 0: + if self.num_processes == 0 and self.num_threads == 0: + self.type = "synchronous" + elif self.num_processes == 0 and self.num_threads > 0: self.type = "threads" self.threads = ThreadPoolExecutor(max_workers=self.num_threads) self.futures = [] else: self.type = "processes" - self.num_threads_per_process = self.num_threads + self.main_event = multiprocessing.Event() self.image_queue = multiprocessing.Queue() self.processes: list[multiprocessing.Process] = [] - for _ in range(num_processes): - process = multiprocessing.Process(target=self._loop_to_save_images_in_threads) + self.events: list[multiprocessing.Event] = [] + for _ in range(self.num_processes): + event = multiprocessing.Event() + process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,)) process.start() self.processes.append(process) + self.events.append(event) - def _loop_to_save_images_in_threads(self) -> None: + def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None: with ThreadPoolExecutor(max_workers=self.num_threads) as executor: futures = [] while True: frame_data = self.image_queue.get() if frame_data is None: - break + self._wait_threads(self.futures, 10) + return image, file_path = frame_data futures.append(executor.submit(self._save_image, image, file_path)) - with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: - wait(futures) - progress_bar.update(len(futures)) + if self.main_event.is_set(): + self._wait_threads(self.futures, 10) + event.set() def async_save_image(self, image: torch.Tensor, file_path: Path) -> None: """Save an image asynchronously using threads or processes.""" - if self.type == "threads": + if self.type == "synchronous": + self._save_image(image, file_path) + elif self.type == "threads": self.futures.append(self.threads.submit(self._save_image, image, file_path)) else: self.image_queue.put((image, file_path)) @@ -111,12 +120,33 @@ def get_episode_dir(self, episode_index: int, image_key: str) -> Path: episode_index=episode_index, image_key=image_key, frame_index=0 ).parent - def stop(self, timeout=20) -> None: + def wait(self) -> None: + """Wait for the thread/processes to finish writing.""" + if self.type == "synchronous": + return + elif self.type == "threads": + self._wait_threads(self.futures) + else: + self._wait_processes() + + def _wait_threads(self, futures) -> None: + with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: + wait(futures, timeout=self.timeout) + progress_bar.update(len(futures)) + + def _wait_processes(self) -> None: + self.main_event.set() + for event in self.events: + event.wait() + + self.main_event.clear() + + def shutdown(self, timeout=20) -> None: """Stop the image writer, waiting for all processes or threads to finish.""" - if self.type == "threads": - with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar: - wait(self.futures, timeout=timeout) - progress_bar.update(len(self.futures)) + if self.type == "synchronous": + return + elif self.type == "threads": + self.threads.shutdown(wait=True) else: self._stop_processes(timeout) @@ -127,8 +157,9 @@ def _stop_processes(self, timeout) -> None: for process in self.processes: process.join(timeout=timeout) - if process.is_alive(): - process.terminate() + for process in self.processes: + if process.is_alive(): + process.terminate() self.image_queue.close() self.image_queue.join_thread() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d4e6d2263..f451be281 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -22,10 +22,10 @@ from typing import Callable import datasets -import pyarrow.parquet as pq import torch import torch.utils from datasets import load_dataset +from datasets.table import embed_table_storage from huggingface_hub import snapshot_download, upload_folder from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats @@ -57,6 +57,7 @@ VideoFrame, decode_video_frames_torchvision, encode_video_frames, + get_video_info, ) from lerobot.common.robot_devices.robots.utils import Robot @@ -391,7 +392,11 @@ def shapes(self) -> dict: return self.info["shapes"] @property - def features(self) -> datasets.Features: + def features(self) -> list[str]: + return list(self._features) + self.video_keys + + @property + def _features(self) -> datasets.Features: """Features of the hf_dataset.""" if self.hf_dataset is not None: return self.hf_dataset.features @@ -583,6 +588,7 @@ def add_frame(self, frame: dict) -> None: image=frame[cam_key], file_path=img_path, ) + if cam_key in self.image_keys: self.episode_buffer[cam_key].append(str(img_path)) @@ -592,7 +598,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None: disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to the hub. - Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise, + Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise, you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend time for video encoding. """ @@ -608,7 +614,7 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None: for key in self.episode_buffer: if key in self.image_keys: continue - if key in self.keys: + elif key in self.keys: self.episode_buffer[key] = torch.stack(self.episode_buffer[key]) elif key == "episode_index": self.episode_buffer[key] = torch.full((episode_length,), episode_index) @@ -619,6 +625,8 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None: self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length) self._save_episode_to_metadata(episode_index, episode_length, task, task_index) + + self._wait_image_writer() self._save_episode_table(episode_index) if encode_videos and len(self.video_keys) > 0: @@ -629,11 +637,17 @@ def add_episode(self, task: str, encode_videos: bool = False) -> None: self.consolidated = False def _save_episode_table(self, episode_index: int) -> None: - ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train") - ep_table = ep_dataset._data.table + ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train") ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(ep_table, ep_data_path) + + # Embed image bytes into the table before saving to parquet + format = ep_dataset.format + ep_dataset = ep_dataset.with_format("arrow") + ep_dataset = ep_dataset.map(embed_table_storage, batched=False) + ep_dataset = ep_dataset.with_format(**format) + + ep_dataset.to_parquet(ep_data_path) def _save_episode_to_metadata( self, episode_index: int, episode_length: int, task: str, task_index: int @@ -677,7 +691,7 @@ def clear_episode_buffer(self) -> None: # Reset the buffer self.episode_buffer = self._create_episode_buffer() - def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None: + def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None: if isinstance(self.image_writer, ImageWriter): logging.warning( "You are starting a new ImageWriter that is replacing an already exising one in the dataset." @@ -689,18 +703,23 @@ def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> N num_threads=num_threads, ) - def stop_image_writter(self) -> None: + def stop_image_writer(self) -> None: """ Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized. """ if self.image_writer is not None: - self.image_writer.stop() + self.image_writer.shutdown() self.image_writer = None + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait() + def encode_videos(self) -> None: # Use ffmpeg to convert frames stored as png into mp4 videos - for episode_index in range(self.num_episodes): + for episode_index in range(self.total_episodes): for key in self.video_keys: # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need # to call self.image_writer here @@ -713,6 +732,18 @@ def encode_videos(self) -> None: # since video encoding with ffmpeg is already using multithreading. encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True) + def _write_video_info(self) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + for key in self.video_keys: + if key not in self.info["videos"]: + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + self.info["videos"][key] = get_video_info(video_path) + + write_json(self.info, self.root / INFO_PATH) + def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: self.hf_dataset = self.load_hf_dataset() self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) @@ -720,12 +751,13 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F if len(self.video_keys) > 0: self.encode_videos() + self._write_video_info() if not keep_image_files and self.image_writer is not None: shutil.rmtree(self.image_writer.dir) if run_compute_stats: - self.stop_image_writter() + self.stop_image_writer() self.stats = compute_stats(self) write_stats(self.stats, self.root / STATS_PATH) self.consolidated = True @@ -735,7 +767,7 @@ def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = F ) # TODO(aliberts) - # - [ ] add video info in info.json + # - [X] add video info in info.json # Sanity checks: # - [ ] shapes # - [ ] ep_lenghts @@ -775,7 +807,7 @@ def create( "In this case, frames from lower fps cameras will be repeated to fill in the blanks" ) if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): - obj.start_image_writter( + obj.start_image_writer( image_writer_processes, image_writer_threads_per_camera * robot.num_cameras ) elif ( @@ -791,7 +823,7 @@ def create( ) if len(video_keys) > 0 and not use_videos: - raise ValueError + raise ValueError() obj.tasks, obj.stats, obj.episode_dicts = {}, {}, [] obj.info = create_empty_dataset_info( @@ -918,7 +950,7 @@ def video(self) -> bool: def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys}) + features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys}) return features @property diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 120076b92..10312272b 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -116,7 +116,6 @@ from datasets import Dataset from huggingface_hub import HfApi from huggingface_hub.errors import EntryNotFoundError -from PIL import Image from safetensors.torch import load_file from lerobot.common.datasets.utils import ( @@ -136,7 +135,12 @@ write_json, write_jsonlines, ) -from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401 +from lerobot.common.datasets.video_utils import ( + VideoFrame, # noqa: F401 + get_image_shapes, + get_video_info, + get_video_shapes, +) from lerobot.common.utils.utils import init_hydra_config V16 = "v1.6" @@ -391,83 +395,6 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st return [f for f in video_files if f not in lfs_tracked_files] -def _get_audio_info(video_path: Path | str) -> dict: - ffprobe_audio_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "a:0", - "-show_entries", - "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") - - info = json.loads(result.stdout) - audio_stream_info = info["streams"][0] if info.get("streams") else None - if audio_stream_info is None: - return {"has_audio": False} - - # Return the information, defaulting to None if no audio stream is present - return { - "has_audio": True, - "audio.channels": audio_stream_info.get("channels", None), - "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, - "audio.sample_rate": int(audio_stream_info["sample_rate"]) - if audio_stream_info.get("sample_rate") - else None, - "audio.bit_depth": audio_stream_info.get("bit_depth", None), - "audio.channel_layout": audio_stream_info.get("channel_layout", None), - } - - -def _get_video_info(video_path: Path | str) -> dict: - ffprobe_video_cmd = [ - "ffprobe", - "-v", - "error", - "-select_streams", - "v:0", - "-show_entries", - "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", - "-of", - "json", - str(video_path), - ] - result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if result.returncode != 0: - raise RuntimeError(f"Error running ffprobe: {result.stderr}") - - info = json.loads(result.stdout) - video_stream_info = info["streams"][0] - - # Calculate fps from r_frame_rate - r_frame_rate = video_stream_info["r_frame_rate"] - num, denom = map(int, r_frame_rate.split("/")) - fps = num / denom - - pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"]) - - video_info = { - "video.fps": fps, - "video.width": video_stream_info["width"], - "video.height": video_stream_info["height"], - "video.channels": pixel_channels, - "video.codec": video_stream_info["codec_name"], - "video.pix_fmt": video_stream_info["pix_fmt"], - "video.is_depth_map": False, - **_get_audio_info(video_path), - } - - return video_info - - def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: hub_api = HfApi() videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH} @@ -481,62 +408,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files ) for vid_key, vid_path in zip(video_keys, video_files, strict=True): - videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path) + videos_info_dict[vid_key] = get_video_info(local_dir / vid_path) return videos_info_dict -def get_video_pixel_channels(pix_fmt: str) -> int: - if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: - return 1 - elif "rgba" in pix_fmt or "yuva" in pix_fmt: - return 4 - elif "rgb" in pix_fmt or "yuv" in pix_fmt: - return 3 - else: - raise ValueError("Unknown format") - - -def get_image_pixel_channels(image: Image): - if image.mode == "L": - return 1 # Grayscale - elif image.mode == "LA": - return 2 # Grayscale + Alpha - elif image.mode == "RGB": - return 3 # RGB - elif image.mode == "RGBA": - return 4 # RGBA - else: - raise ValueError("Unknown format") - - -def get_video_shapes(videos_info: dict, video_keys: list) -> dict: - video_shapes = {} - for img_key in video_keys: - channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"]) - video_shapes[img_key] = { - "width": videos_info[img_key]["video.width"], - "height": videos_info[img_key]["video.height"], - "channels": channels, - } - - return video_shapes - - -def get_image_shapes(dataset: Dataset, image_keys: list) -> dict: - image_shapes = {} - for img_key in image_keys: - image = dataset[0][img_key] # Assuming first row - channels = get_image_pixel_channels(image) - image_shapes[img_key] = { - "width": image.width, - "height": image.height, - "channels": channels, - } - - return image_shapes - - def get_generic_motor_names(sequence_shapes: dict) -> dict: return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()} diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index b5d634ba0..48f22435c 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import subprocess import warnings @@ -24,7 +25,9 @@ import pyarrow as pa import torch import torchvision +from datasets import Dataset from datasets.features.features import register_feature +from PIL import Image def decode_video_frames_torchvision( @@ -210,3 +213,131 @@ def __call__(self): ) # to make VideoFrame available in HuggingFace `datasets` register_feature(VideoFrame, "VideoFrame") + + +def get_audio_info(video_path: Path | str) -> dict: + ffprobe_audio_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + audio_stream_info = info["streams"][0] if info.get("streams") else None + if audio_stream_info is None: + return {"has_audio": False} + + # Return the information, defaulting to None if no audio stream is present + return { + "has_audio": True, + "audio.channels": audio_stream_info.get("channels", None), + "audio.codec": audio_stream_info.get("codec_name", None), + "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.sample_rate": int(audio_stream_info["sample_rate"]) + if audio_stream_info.get("sample_rate") + else None, + "audio.bit_depth": audio_stream_info.get("bit_depth", None), + "audio.channel_layout": audio_stream_info.get("channel_layout", None), + } + + +def get_video_info(video_path: Path | str) -> dict: + ffprobe_video_cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt", + "-of", + "json", + str(video_path), + ] + result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + if result.returncode != 0: + raise RuntimeError(f"Error running ffprobe: {result.stderr}") + + info = json.loads(result.stdout) + video_stream_info = info["streams"][0] + + # Calculate fps from r_frame_rate + r_frame_rate = video_stream_info["r_frame_rate"] + num, denom = map(int, r_frame_rate.split("/")) + fps = num / denom + + pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"]) + + video_info = { + "video.fps": fps, + "video.width": video_stream_info["width"], + "video.height": video_stream_info["height"], + "video.channels": pixel_channels, + "video.codec": video_stream_info["codec_name"], + "video.pix_fmt": video_stream_info["pix_fmt"], + "video.is_depth_map": False, + **get_audio_info(video_path), + } + + return video_info + + +def get_video_shapes(videos_info: dict, video_keys: list) -> dict: + video_shapes = {} + for img_key in video_keys: + channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"]) + video_shapes[img_key] = { + "width": videos_info[img_key]["video.width"], + "height": videos_info[img_key]["video.height"], + "channels": channels, + } + + return video_shapes + + +def get_image_shapes(dataset: Dataset, image_keys: list) -> dict: + image_shapes = {} + for img_key in image_keys: + image = dataset[0][img_key] # Assuming first row + channels = get_image_pixel_channels(image) + image_shapes[img_key] = { + "width": image.width, + "height": image.height, + "channels": channels, + } + + return image_shapes + + +def get_video_pixel_channels(pix_fmt: str) -> int: + if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: + return 1 + elif "rgba" in pix_fmt or "yuva" in pix_fmt: + return 4 + elif "rgb" in pix_fmt or "yuv" in pix_fmt: + return 3 + else: + raise ValueError("Unknown format") + + +def get_image_pixel_channels(image: Image): + if image.mode == "L": + return 1 # Grayscale + elif image.mode == "LA": + return 2 # Grayscale + Alpha + elif image.mode == "RGB": + return 3 # RGB + elif image.mode == "RGBA": + return 4 # RGBA + else: + raise ValueError("Unknown format") diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 029751481..f3424e57d 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -234,8 +234,8 @@ def record( dataset = LeRobotDataset.create( repo_id, fps, - robot, root=root, + robot=robot, image_writer_processes=num_image_writer_processes, image_writer_threads_per_camera=num_image_writer_threads_per_camera, use_videos=video, @@ -307,10 +307,6 @@ def record( log_say("Stop recording", play_sounds, blocking=True) stop_recording(robot, listener, display_cameras) - if dataset.image_writer is not None: - logging.info("Waiting for image writer to terminate...") - dataset.image_writer.stop() - if run_compute_stats: logging.info("Computing dataset statistics")