Skip to content

Commit

Permalink
Add suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Oct 11, 2024
1 parent 3ea5312 commit 8bd406e
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ def __init__(
repo_id: str,
root: Path | None = None,
episodes: list[int] | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_data: bool = True,
download_videos: bool = True,
video_backend: str | None = None,
):
"""LeRobotDataset encapsulates 3 main things:
Expand All @@ -64,7 +63,7 @@ def __init__(
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
- (optional) videos from which frames are loaded to be synchronous with data from parquet files.
3 use modes are available for this class, depending on 3 different use cases:
3 modes are available for this class, depending on 3 different use cases:
1. Your dataset already exists on the Hugging Face Hub at the address
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
Expand Down Expand Up @@ -119,7 +118,6 @@ def __init__(
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
split (str, optional): _description_. Defaults to "train".
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
from videos or images). Defaults to None.
Expand All @@ -129,19 +127,18 @@ def __init__(
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
download_data (bool, optional): Flag to download actual data. Defaults to True.
download_videos (bool, optional): Flag to download the videos. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
self.root = root if root is not None else LEROBOT_HOME / repo_id
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_data = download_data
self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None

Expand All @@ -152,13 +149,6 @@ def __init__(
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)

if not self.download_data:
# TODO(aliberts): Add actual support for this
# maybe use local_files_only=True or HF_HUB_OFFLINE=True
# see thread https://huggingface.slack.com/archives/C06ME3E7JUD/p1728637455476019
self.hf_dataset, self.episode_data_index = None, None
return

# Load actual data
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
Expand Down Expand Up @@ -192,12 +182,13 @@ def download_episodes(self) -> None:
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if self.download_videos else "videos/"
if self.episodes is not None:
files = [
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
for ep_idx in self.episodes
]
if len(self.video_keys) > 0:
if len(self.video_keys) > 0 and self.download_videos:
video_files = [
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in self.video_keys
Expand All @@ -211,6 +202,7 @@ def download_episodes(self) -> None:
revision=self._version,
local_dir=self.root,
allow_patterns=files,
ignore_patterns=ignore_patterns,
)

@property
Expand Down Expand Up @@ -371,7 +363,8 @@ def __getitem__(self, idx) -> dict:
item = {**video_frames, **item}

if self.image_transforms is not None:
for cam in self.camera_keys:
image_keys = self.camera_keys if self.download_videos else self.image_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])

return item
Expand All @@ -380,7 +373,6 @@ def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
Expand Down

0 comments on commit 8bd406e

Please sign in to comment.