diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 000000000..b406d43b8 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,18 @@ +on: + push: + +name: Secret Leaks + +permissions: + contents: read + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index 70a5b505c..db9840a79 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -46,7 +46,7 @@ defaults: - policy: diffusion ``` -This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_. +This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order, as any configuration parameters with the same name will be overridden. Thus, `default.yaml` is overridden by `env/pusht.yaml`, which is overridden by `policy/diffusion.yaml`_. Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`. diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py new file mode 100644 index 000000000..bdcc6d7b9 --- /dev/null +++ b/examples/6_add_image_transforms.py @@ -0,0 +1,52 @@ +""" +This script demonstrates how to use torchvision's image transformations with LeRobotDataset for data +augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and +transforms are applied to the observation images before they are returned by the dataset's __getitem__.
+""" + +from pathlib import Path + +from torchvision.transforms import ToPILImage, v2 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +dataset_repo_id = "lerobot/aloha_static_tape" + +# Create a LeRobotDataset with no transformations +dataset = LeRobotDataset(dataset_repo_id) +# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)` + +# Get the index of the first observation in the first episode +first_idx = dataset.episode_data_index["from"][0].item() + +# Get the frame corresponding to the first camera +frame = dataset[first_idx][dataset.camera_keys[0]] + + +# Define the transformations +transforms = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 1.5)), + v2.ColorJitter(contrast=(0.5, 1.5)), + v2.RandomAdjustSharpness(sharpness_factor=2, p=1), + ] +) + +# Create another LeRobotDataset with the defined transformations +transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms) + +# Get a frame from the transformed dataset +transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] + +# Create a directory to store output images +output_dir = Path("outputs/image_transforms") +output_dir.mkdir(parents=True, exist_ok=True) + +# Save the original frame +to_pil = ToPILImage() +to_pil(frame).save(output_dir / "original_frame.png", quality=100) +print(f"Original frame saved to {output_dir / 'original_frame.png'}.") + +# Save the transformed frame +to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) +print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4732f5774..fab8ca575 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -19,6 +19,7 @@ from omegaconf import ListConfig, OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms def resolve_delta_timestamps(cfg): @@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData resolve_delta_timestamps(cfg) - # TODO(rcadene): add data augmentations + image_transforms = None + if cfg.training.image_transforms.enable: + image_transforms = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) if isinstance(cfg.dataset_repo_id, str): dataset = LeRobotDataset( cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) else: dataset = MultiLeRobotDataset( - cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps") + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 58ae51b17..d680b9875 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ 
b/lerobot/common/datasets/lerobot_dataset.py @@ -46,7 +46,7 @@ def __init__( version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -54,7 +54,7 @@ def __init__( self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer @@ -151,8 +151,9 @@ def __getitem__(self, idx): self.tolerance_s, ) - if self.transform is not None: - item = self.transform(item) + if self.image_transforms is not None: + for cam in self.camera_keys: + item[cam] = self.image_transforms(item[cam]) return item @@ -168,7 +169,7 @@ def __repr__(self): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) @@ -202,7 +203,7 @@ def from_preloaded( obj.version = version obj.root = root obj.split = split - obj.transform = transform + obj.image_transforms = transform obj.delta_timestamps = delta_timestamps obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index @@ -225,7 +226,7 @@ def __init__( version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -239,7 +240,7 @@ def __init__( root=root, split=split, delta_timestamps=delta_timestamps, - transform=transform, + image_transforms=image_transforms, ) for repo_id in repo_ids ] @@ -274,7 +275,7 @@ def __init__( self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.stats = aggregate_stats(self._datasets) @@ -380,6 +381,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: for data_key in self.disabled_data_keys: if data_key in item: del item[data_key] + return item def __repr__(self): @@ -394,6 +396,6 @@ def __repr__(self): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py new file mode 100644 index 000000000..899f0d66c --- /dev/null +++ b/lerobot/common/datasets/transforms.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import collections +from typing import Any, Callable, Dict, Sequence + +import torch +from torchvision.transforms import v2 +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2 import functional as F # noqa: N812 + + +class RandomSubsetApply(Transform): + """Apply a random subset of N transformations from a list of transformations. + + Args: + transforms: list of transformations. + p: represents the multinomial probabilities (with no replacement) used for sampling the transform. + If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms + have the same probability. + n_subset: number of transformations to apply. If ``None``, all transforms are applied. + Must be in [1, len(transforms)]. + random_order: apply transformations in a random order. + """ + + def __init__( + self, + transforms: Sequence[Callable], + p: list[float] | None = None, + n_subset: int | None = None, + random_order: bool = False, + ) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError( + f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" + ) + + if n_subset is None: + n_subset = len(transforms) + elif not isinstance(n_subset, int): + raise TypeError("n_subset should be an int or None") + elif not (1 <= n_subset <= len(transforms)): + raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + + self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] + self.n_subset = n_subset + self.random_order = random_order + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + + selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) + if not self.random_order: + selected_indices = selected_indices.sort().values + + selected_transforms = [self.transforms[i] for i in selected_indices] + + for transform in selected_transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + + return outputs + + def extra_repr(self) -> str: + return ( + f"transforms={self.transforms}, " + f"p={self.p}, " + f"n_subset={self.n_subset}, " + f"random_order={self.random_order}" + ) + + +class SharpnessJitter(Transform): + """Randomly change the sharpness of an image or video. + + Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. + While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image, + SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of + augmentations as a result. + + A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness + by a factor of 2. + + If the input is a :class:`torch.Tensor`, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from + [max(0, 1 - sharpness), 1 + sharpness] or the given + [min, max]. Should be non negative numbers. 
+ """ + + def __init__(self, sharpness: float | Sequence[float]) -> None: + super().__init__() + self.sharpness = self._check_input(sharpness) + + def _check_input(self, sharpness): + if isinstance(sharpness, (int, float)): + if sharpness < 0: + raise ValueError("If sharpness is a single number, it must be non negative.") + sharpness = [1.0 - sharpness, 1.0 + sharpness] + sharpness[0] = max(sharpness[0], 0.0) + elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: + sharpness = [float(v) for v in sharpness] + else: + raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") + + if not 0.0 <= sharpness[0] <= sharpness[1]: + raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + + return float(sharpness[0]), float(sharpness[1]) + + def _generate_value(self, left: float, right: float) -> float: + return torch.empty(1).uniform_(left, right).item() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) + + +def get_image_transforms( + brightness_weight: float = 1.0, + brightness_min_max: tuple[float, float] | None = None, + contrast_weight: float = 1.0, + contrast_min_max: tuple[float, float] | None = None, + saturation_weight: float = 1.0, + saturation_min_max: tuple[float, float] | None = None, + hue_weight: float = 1.0, + hue_min_max: tuple[float, float] | None = None, + sharpness_weight: float = 1.0, + sharpness_min_max: tuple[float, float] | None = None, + max_num_transforms: int | None = None, + random_order: bool = False, +): + def check_value(name, weight, min_max): + if min_max is not None: + if len(min_max) != 2: + raise ValueError( + f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided." + ) + if weight < 0.0: + raise ValueError( + f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})." + ) + + check_value("brightness", brightness_weight, brightness_min_max) + check_value("contrast", contrast_weight, contrast_min_max) + check_value("saturation", saturation_weight, saturation_min_max) + check_value("hue", hue_weight, hue_min_max) + check_value("sharpness", sharpness_weight, sharpness_min_max) + + weights = [] + transforms = [] + if brightness_min_max is not None and brightness_weight > 0.0: + weights.append(brightness_weight) + transforms.append(v2.ColorJitter(brightness=brightness_min_max)) + if contrast_min_max is not None and contrast_weight > 0.0: + weights.append(contrast_weight) + transforms.append(v2.ColorJitter(contrast=contrast_min_max)) + if saturation_min_max is not None and saturation_weight > 0.0: + weights.append(saturation_weight) + transforms.append(v2.ColorJitter(saturation=saturation_min_max)) + if hue_min_max is not None and hue_weight > 0.0: + weights.append(hue_weight) + transforms.append(v2.ColorJitter(hue=hue_min_max)) + if sharpness_min_max is not None and sharpness_weight > 0.0: + weights.append(sharpness_weight) + transforms.append(SharpnessJitter(sharpness=sharpness_min_max)) + + n_subset = len(transforms) + if max_num_transforms is not None: + n_subset = min(n_subset, max_num_transforms) + + if n_subset == 0: + return v2.Identity() + else: + # TODO(rcadene, aliberts): add v2.ToDtype float16? 
+ return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 853acbc3b..bf578fcc5 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -241,5 +241,6 @@ def log_dict(self, d, step, mode="train"): def log_video(self, video_path: str, step: int, mode: str = "train"): assert mode in {"train", "eval"} + assert self._wandb is not None wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4") self._wandb.log({f"{mode}/video": wandb_video}, step=step) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 273f4f758..e0482143d 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -239,10 +239,8 @@ def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) # run sampling - sample = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond) - # `horizon` steps worth of actions (from the first observation). - actions = sample[..., : self.config.output_shapes["action"][0]] # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 end = start + self.config.n_action_steps diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index d638c5416..9b055f7e6 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -147,7 +147,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(max).any(), _no_stats_error_str("max") # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min) + batch[key] = (batch[key] - min) / (max - min + 1e-8) # normalize to [-1, 1] batch[key] = batch[key] * 2 - 1 else: diff --git a/lerobot/common/policies/policy_protocol.py b/lerobot/common/policies/policy_protocol.py index 38738a909..4e9e87afd 100644 --- a/lerobot/common/policies/policy_protocol.py +++ b/lerobot/common/policies/policy_protocol.py @@ -57,7 +57,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict: other items should be logging-friendly, native Python types. """ - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). 
When the model uses a history of observations, or outputs a sequence of actions, this method deals diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 7c873bf23..de9658e98 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -134,7 +134,7 @@ def reset(self): self._prev_mean: torch.Tensor | None = None @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]): + def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) batch["observation.image"] = batch[self.input_image_key] diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 85b9ceea0..6101df898 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -43,6 +43,40 @@ training: save_checkpoint: true num_workers: 4 batch_size: ??? + image_transforms: + # These transforms are all using standard torchvision.transforms.v2 + # You can find out how these transformations affect images here: + # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html + # We use a custom RandomSubsetApply container to sample them. + # For each transform, the following parameters are available: + # weight: This represents the multinomial probability (with no replacement) + # used for sampling the transform. If the sum of the weights is not 1, + # they will be normalized. + # min_max: Lower & upper bound respectively used for sampling the transform's parameter + # (following uniform distribution) when it's applied. + # Set this flag to `true` to enable transforms during training + enable: false + # This is the maximum number of transforms (sampled from these below) that will be applied to each frame. + # It's an integer in the interval [1, number of available transforms]. + max_num_transforms: 3 + # By default, transforms are applied in Torchvision's suggested order (shown below). + # Set this to True to apply them in a random order. + random_order: false + brightness: + weight: 1 + min_max: [0.8, 1.2] + contrast: + weight: 1 + min_max: [0.8, 1.2] + saturation: + weight: 1 + min_max: [0.5, 1.5] + hue: + weight: 1 + min_max: [-0.05, 0.05] + sharpness: + weight: 1 + min_max: [0.8, 1.2] eval: n_episodes: 1 diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index 4d8b48504..4d3cc291f 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -13,39 +13,71 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Use this script to get a quick summary of your system config. +It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. 
+""" + import platform -import huggingface_hub +HAS_HF_HUB = True +HAS_HF_DATASETS = True +HAS_NP = True +HAS_TORCH = True +HAS_LEROBOT = True + +try: + import huggingface_hub +except ImportError: + HAS_HF_HUB = False + +try: + import datasets +except ImportError: + HAS_HF_DATASETS = False + +try: + import numpy as np +except ImportError: + HAS_NP = False + +try: + import torch +except ImportError: + HAS_TORCH = False + +try: + import lerobot +except ImportError: + HAS_LEROBOT = False -# import dataset -import numpy as np -import torch -from lerobot import __version__ as version +lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A" +hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A" +hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A" +np_version = np.__version__ if HAS_NP else "N/A" -pt_version = torch.__version__ -pt_cuda_available = torch.cuda.is_available() -pt_cuda_available = torch.cuda.is_available() -cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A" +torch_version = torch.__version__ if HAS_TORCH else "N/A" +torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" +cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" # TODO(aliberts): refactor into an actual command `lerobot env` def display_sys_info() -> dict: """Run this to get basic system info to help for tracking issues & bugs.""" info = { - "`lerobot` version": version, + "`lerobot` version": lerobot_version, "Platform": platform.platform(), "Python version": platform.python_version(), - "Huggingface_hub version": huggingface_hub.__version__, - # TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged - # "Dataset version": dataset.__version__, - "Numpy version": np.__version__, - "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "Huggingface_hub version": hf_hub_version, + "Dataset version": hf_datasets_version, + "Numpy version": np_version, + "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})", "Cuda version": cuda_version, "Using GPU in script?": "", - "Using distributed or parallel set-up in script?": "", + # "Using distributed or parallel set-up in script?": "", } - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") print(format_dict(info)) return info diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 784e9fc66..7bf8bde55 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -61,7 +61,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError from PIL import Image as PILImage -from torch import Tensor +from torch import Tensor, nn from tqdm import trange from lerobot.common.datasets.factory import make_dataset @@ -99,13 +99,13 @@ def rollout( "reward": A (batch, sequence) tensor of rewards received for applying the actions. "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon environment termination/truncation). - "don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, the first True is followed by True's all the way till the end. 
This can be used for masking extraneous elements from the sequences above. Args: env: The batch of environments. - policy: The policy. + policy: The policy. Must be a PyTorch nn module. seeds: The environments are seeded once at the start of the rollout. If provided, this argument specifies the seeds for each of the environments. return_observations: Whether to include all observations in the returned rollout data. Observations @@ -116,6 +116,7 @@ def rollout( Returns: The dictionary described above. """ + assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." device = get_device_from_parameters(policy) # Reset the policy and environments. @@ -209,7 +210,7 @@ def eval_policy( policy: torch.nn.Module, n_episodes: int, max_episodes_rendered: int = 0, - video_dir: Path | None = None, + videos_dir: Path | None = None, return_episode_data: bool = False, start_seed: int | None = None, enable_progbar: bool = False, @@ -221,7 +222,7 @@ def eval_policy( policy: The policy. n_episodes: The number of episodes to evaluate. max_episodes_rendered: Maximum number of episodes to render into videos. - video_dir: Where to save rendered videos. + videos_dir: Where to save rendered videos. return_episode_data: Whether to return episode data for online training. Incorporates the data into the "episodes" key of the returned dictionary. start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the @@ -231,6 +232,10 @@ def eval_policy( Returns: Dictionary with metrics and data regarding the rollouts. """ + if max_episodes_rendered > 0 and not videos_dir: + raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + + assert isinstance(policy, Policy) start = time.time() policy.eval() @@ -271,11 +276,16 @@ def render_frame(env: gym.vector.VectorEnv): if max_episodes_rendered > 0: ep_frames: list[np.ndarray] = [] - seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)) + if start_seed is None: + seeds = None + else: + seeds = range( + start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + ) rollout_data = rollout( env, policy, - seeds=seeds, + seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, enable_progbar=enable_inner_progbar, @@ -285,7 +295,8 @@ def render_frame(env: gym.vector.VectorEnv): # this won't be included). n_steps = rollout_data["done"].shape[1] # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. - done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps) + done_indices = torch.argmax(rollout_data["done"].to(int), dim=1) + # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. 
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() @@ -296,8 +307,12 @@ def render_frame(env: gym.vector.VectorEnv): max_rewards.extend(batch_max_rewards.tolist()) batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") all_successes.extend(batch_successes.tolist()) - all_seeds.extend(seeds) + if seeds: + all_seeds.extend(seeds) + else: + all_seeds.append(None) + # FIXME: episode_data is either None or it doesn't exist if return_episode_data: this_episode_data = _compile_episode_data( rollout_data, @@ -347,8 +362,9 @@ def render_frame(env: gym.vector.VectorEnv): ): if n_episodes_rendered >= max_episodes_rendered: break - video_dir.mkdir(parents=True, exist_ok=True) - video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4" + + videos_dir.mkdir(parents=True, exist_ok=True) + video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4" video_paths.append(str(video_path)) thread = threading.Thread( target=write_video, @@ -503,22 +519,20 @@ def _compile_episode_data( } -def eval( - pretrained_policy_path: str | None = None, +def main( + pretrained_policy_path: Path | None = None, hydra_cfg_path: str | None = None, + out_dir: str | None = None, config_overrides: list[str] | None = None, ): assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) - if hydra_cfg_path is None: - hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides) + if pretrained_policy_path is not None: + hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides) else: hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) - out_dir = ( - f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" - ) if out_dir is None: - raise NotImplementedError() + out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}" # Check device is available device = get_safe_torch_device(hydra_cfg.device, log=True) @@ -534,10 +548,12 @@ def eval( logging.info("Making policy.") if hydra_cfg_path is None: - policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path) + policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path)) else: # Note: We need the dataset stats to pass to the policy's normalization modules. policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats) + + assert isinstance(policy, nn.Module) policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): @@ -546,7 +562,7 @@ def eval( policy, hydra_cfg.eval.n_episodes, max_episodes_rendered=10, - video_dir=Path(out_dir) / "eval", + videos_dir=Path(out_dir) / "videos", start_seed=hydra_cfg.seed, enable_progbar=True, enable_inner_progbar=True, @@ -586,6 +602,13 @@ def eval( ), ) parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--out-dir", + help=( + "Where to save the evaluation outputs. 
If not provided, outputs are saved in " + "outputs/eval/{timestamp}_{env_name}_{policy_name}" + ), + ) parser.add_argument( "overrides", nargs="*", @@ -594,7 +617,7 @@ def eval( args = parser.parse_args() if args.pretrained_policy_name_or_path is None: - eval(hydra_cfg_path=args.config, config_overrides=args.overrides) + main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides) else: try: pretrained_policy_path = Path( @@ -618,4 +641,8 @@ def eval( "repo ID, nor is it an existing local directory." ) - eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides) + main( + pretrained_policy_path=pretrained_policy_path, + out_dir=args.out_dir, + config_overrides=args.overrides, + ) diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 7c708c302..b4769a11a 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -47,6 +47,7 @@ import shutil import warnings from pathlib import Path +from typing import Any import torch from huggingface_hub import HfApi, create_branch @@ -58,7 +59,7 @@ from lerobot.common.datasets.utils import flatten_dict -def get_from_raw_to_lerobot_format_fn(raw_format): +def get_from_raw_to_lerobot_format_fn(raw_format: str): if raw_format == "pusht_zarr": from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format elif raw_format == "umi_zarr": @@ -77,7 +78,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format): return from_raw_to_lerobot_format -def save_meta_data(info, stats, episode_data_index, meta_data_dir): +def save_meta_data( + info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path +): meta_data_dir.mkdir(parents=True, exist_ok=True) # save info @@ -95,7 +98,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir): save_file(episode_data_index, ep_data_idx_path) -def push_meta_data_to_hub(repo_id, meta_data_dir, revision): +def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None): """Expect all meta data files to be all stored in a single "meta_data" directory. On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root. """ @@ -109,7 +112,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision): ) -def push_videos_to_hub(repo_id, videos_dir, revision): +def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None): """Expect mp4 files to be all stored in a single "videos" directory. On the hugging face repositery, they will be uploaded in a "videos" directory at the root. 
""" @@ -210,7 +213,6 @@ def push_dataset_to_hub( push_meta_data_to_hub(repo_id, meta_data_dir, revision="main") if video: push_videos_to_hub(repo_id, videos_dir, revision="main") - create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION) if tests_data_dir: diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 860412bd6..01b2ef4f4 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,6 +24,7 @@ from deepdiff import DeepDiff from omegaconf import DictConfig, OmegaConf from termcolor import colored +from torch import nn from torch.cuda.amp import GradScaler from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps @@ -150,6 +151,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): grad_norm = info["grad_norm"] lr = info["lr"] update_s = info["update_s"] + dataloading_s = info["dataloading_s"] # A sample is an (observation,action) pair, where observation and action # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. @@ -170,6 +172,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline): f"lr:{lr:0.1e}", # in seconds f"updt_s:{update_s:.3f}", + f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io ] logging.info(" ".join(log_items)) @@ -290,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. + eval_env = None if cfg.training.eval_freq > 0: logging.info("make_env") eval_env = make_env(cfg) @@ -300,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No dataset_stats=offline_dataset.stats if not cfg.resume else None, pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, ) - + assert isinstance(policy, nn.Module) # Create optimizer and scheduler # Temporary hack to move optimizer out of policy optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) @@ -325,14 +329,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Note: this helper will be used in offline and online training loops. 
def evaluate_and_checkpoint_if_needed(step): + _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) + step_identifier = f"{step:0{_num_digits}d}" + if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + assert eval_env is not None eval_info = eval_policy( eval_env, policy, cfg.eval.n_episodes, - video_dir=Path(out_dir) / "eval", + videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}", max_episodes_rendered=4, start_seed=cfg.seed, ) @@ -350,9 +358,7 @@ def evaluate_and_checkpoint_if_needed(step): policy, optimizer, lr_scheduler, - identifier=str(step).zfill( - max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) - ), + identifier=step_identifier, ) logging.info("Resume training") @@ -382,7 +388,10 @@ def evaluate_and_checkpoint_if_needed(step): for _ in range(step, cfg.training.offline_steps): if step == 0: logging.info("Start offline training on a fixed dataset") + + start_time = time.perf_counter() batch = next(dl_iter) + dataloading_s = time.perf_counter() - start_time for key in batch: batch[key] = batch[key].to(device, non_blocking=True) @@ -397,6 +406,8 @@ def evaluate_and_checkpoint_if_needed(step): use_amp=cfg.use_amp, ) + train_info["dataloading_s"] = dataloading_s + if step % cfg.training.log_freq == 0: log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True) @@ -406,7 +417,8 @@ def evaluate_and_checkpoint_if_needed(step): step += 1 - eval_env.close() + if eval_env: + eval_env.close() logging.info("End of training") diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 58da6a47e..f947e6101 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -66,28 +66,31 @@ import logging import time from pathlib import Path +from typing import Iterator +import numpy as np import rerun as rr import torch +import torch.utils.data import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): - def __init__(self, dataset, episode_index): + def __init__(self, dataset: LeRobotDataset, episode_index: int): from_idx = dataset.episode_data_index["from"][episode_index].item() to_idx = dataset.episode_data_index["to"][episode_index].item() self.frame_ids = range(from_idx, to_idx) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.frame_ids) - def __len__(self): + def __len__(self) -> int: return len(self.frame_ids) -def to_hwc_uint8_numpy(chw_float32_torch): +def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 c, h, w = chw_float32_torch.shape @@ -106,6 +109,7 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, + root: Path | None = None, ) -> Path | None: if save: assert ( @@ -113,7 +117,7 @@ def visualize_dataset( ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) + dataset = LeRobotDataset(repo_id, root=root) logging.info("Loading dataloader") episode_sampler = EpisodeSampler(dataset, episode_index) @@ -224,7 +228,8 @@ def main(): help=( "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. 
It spawns a viewer to visualize the data locally. " - "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "'distant' creates a server on the distant machine where the data is stored. " + "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." ), ) parser.add_argument( @@ -245,8 +250,8 @@ def main(): default=0, help=( "Save a .rrd file in the directory provided by `--output-dir`. " - "It also deactivates the spawning of a viewer. ", - "Visualize the data by running `rerun path/to/file.rrd` on your local machine.", + "It also deactivates the spawning of a viewer. " + "Visualize the data by running `rerun path/to/file.rrd` on your local machine." ), ) parser.add_argument( @@ -255,6 +260,12 @@ def main(): help="Directory path to write a .rrd file when `--save 1` is set.", ) + parser.add_argument( + "--root", + type=str, + help="Root directory for a dataset stored on a local machine.", + ) + args = parser.parse_args() visualize_dataset(**vars(args)) diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py new file mode 100644 index 000000000..fa3c0ab2a --- /dev/null +++ b/lerobot/scripts/visualize_image_transforms.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize effects of image transforms for a given configuration. + +This script will generate examples of transformed images as they are output by LeRobot dataset. 
+Additionally, each individual transform can be visualized separately as well as examples of combined transforms + + +--- Usage Examples --- + +Increase hue jitter +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.hue.min_max=[-0.25,0.25] +``` + +Increase brightness & brightness weight +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.brightness.weight=10.0 \ + training.image_transforms.brightness.min_max=[1.0,2.0] +``` + +Blur images and disable saturation & hue +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.sharpness.weight=10.0 \ + training.image_transforms.sharpness.min_max=[0.0,1.0] \ + training.image_transforms.saturation.weight=0.0 \ + training.image_transforms.hue.weight=0.0 +``` + +Use all transforms with random order +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.max_num_transforms=5 \ + training.image_transforms.random_order=true +``` + +""" + +from pathlib import Path + +import hydra +from torchvision.transforms import ToPILImage + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms + +OUTPUT_DIR = Path("outputs/image_transforms") +N_EXAMPLES = 5 +to_pil = ToPILImage() + + +def save_config_all_transforms(cfg, original_frame, output_dir): + tf = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) + + output_dir_all = output_dir / "all" + output_dir_all.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100) + + print("Combined transforms examples saved to:") + print(f" {output_dir_all}") + + +def save_config_single_transforms(cfg, original_frame, output_dir): + transforms = [ + "brightness", + "contrast", + "saturation", + "hue", + "sharpness", + ] + print("Individual transforms examples saved to:") + for transform in transforms: + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": cfg[f"{transform}"].min_max, + } + tf = get_image_transforms(**kwargs) + output_dir_single = output_dir / f"{transform}" + output_dir_single.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100) + + print(f" {output_dir_single}") + + +@hydra.main(version_base="1.2", config_name="default", config_path="../configs") +def visualize_transforms(cfg): + dataset = LeRobotDataset(cfg.dataset_repo_id) + + output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1] + output_dir.mkdir(parents=True, exist_ok=True) + + # Get 1st frame from 1st camera of 1st episode + original_frame = dataset[0][dataset.camera_keys[0]] + 
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) + print("\nOriginal frame saved to:") + print(f" {output_dir / 'original_frame.png'}.") + + save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir) + save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir) + + +if __name__ == "__main__": + visualize_transforms() diff --git a/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors new file mode 100644 index 000000000..77699dab1 --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a +size 3686488 diff --git a/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors new file mode 100644 index 000000000..13f1033fa --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f +size 40551392 diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py new file mode 100644 index 000000000..9d024a013 --- /dev/null +++ b/tests/scripts/save_image_transforms_to_safetensors.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pathlib import Path + +import torch +from safetensors.torch import save_file + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID +from tests.utils import DEFAULT_CONFIG_PATH + + +def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + img_tf = default_tf(original_frame) + + save_file({"default": img_tf}, output_dir / "default_transforms.safetensors") + + +def save_single_transforms(original_frame: torch.Tensor, output_dir: Path): + transforms = { + "brightness": [(0.5, 0.5), (2.0, 2.0)], + "contrast": [(0.5, 0.5), (2.0, 2.0)], + "saturation": [(0.5, 0.5), (2.0, 2.0)], + "hue": [(-0.25, -0.25), (0.25, 0.25)], + "sharpness": [(0.5, 0.5), (2.0, 2.0)], + } + + frames = {"original_frame": original_frame} + for transform, values in transforms.items(): + for min_max in values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + frames[key] = tf(original_frame) + + save_file(frames, output_dir / "single_transforms.safetensors") + + +def main(): + dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None) + output_dir = Path(ARTIFACT_DIR) + output_dir.mkdir(parents=True, exist_ok=True) + original_frame = dataset[0][dataset.camera_keys[0]] + + save_single_transforms(original_frame, output_dir) + save_default_config_transform(original_frame, output_dir) + + +if __name__ == "__main__": + main() diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py new file mode 100644 index 000000000..ba6d972f3 --- /dev/null +++ b/tests/test_image_transforms.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from safetensors.torch import load_file +from torchvision.transforms import v2 +from torchvision.transforms.v2 import functional as F # noqa: N812 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel + +ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") +DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" + + +def load_png_to_tensor(path: Path): + return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) + + +@pytest.fixture +def img(): + dataset = LeRobotDataset(DATASET_REPO_ID) + return dataset[0][dataset.camera_keys[0]] + + +@pytest.fixture +def img_random(): + return torch.rand(3, 480, 640) + + +@pytest.fixture +def color_jitters(): + return [ + v2.ColorJitter(brightness=0.5), + v2.ColorJitter(contrast=0.5), + v2.ColorJitter(saturation=0.5), + ] + + +@pytest.fixture +def single_transforms(): + return load_file(ARTIFACT_DIR / "single_transforms.safetensors") + + +@pytest.fixture +def default_transforms(): + return load_file(ARTIFACT_DIR / "default_transforms.safetensors") + + +def test_get_image_transforms_no_transform(img): + tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0) + torch.testing.assert_close(tf_actual(img), img) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_brightness(img, min_max): + tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max) + tf_expected = v2.ColorJitter(brightness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_contrast(img, min_max): + tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max) + tf_expected = v2.ColorJitter(contrast=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_saturation(img, min_max): + tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max) + tf_expected = v2.ColorJitter(saturation=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) +def test_get_image_transforms_hue(img, min_max): + tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max) + tf_expected = v2.ColorJitter(hue=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_sharpness(img, min_max): + tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max) + tf_expected = SharpnessJitter(sharpness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +def test_get_image_transforms_max_num_transforms(img): + tf_actual = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=False, + ) + tf_expected = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 0.5)), + 
v2.ColorJitter(contrast=(0.5, 0.5)), + v2.ColorJitter(saturation=(0.5, 0.5)), + v2.ColorJitter(hue=(0.5, 0.5)), + SharpnessJitter(sharpness=(0.5, 0.5)), + ] + ) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@require_x86_64_kernel +def test_get_image_transforms_random_order(img): + out_imgs = [] + tf = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=True, + ) + with seeded_context(1337): + for _ in range(10): + out_imgs.append(tf(img)) + + for i in range(1, len(out_imgs)): + with pytest.raises(AssertionError): + torch.testing.assert_close(out_imgs[0], out_imgs[i]) + + +@pytest.mark.parametrize( + "transform, min_max_values", + [ + ("brightness", [(0.5, 0.5), (2.0, 2.0)]), + ("contrast", [(0.5, 0.5), (2.0, 2.0)]), + ("saturation", [(0.5, 0.5), (2.0, 2.0)]), + ("hue", [(-0.25, -0.25), (0.25, 0.25)]), + ("sharpness", [(0.5, 0.5), (2.0, 2.0)]), + ], +) +def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms): + for min_max in min_max_values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + actual = tf(img) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + expected = single_transforms[key] + torch.testing.assert_close(actual, expected) + + +@require_x86_64_kernel +def test_backward_compatibility_default_config(img, default_transforms): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + actual = default_tf(img) + + expected = default_transforms["default"] + + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("p", [[0, 1], [1, 0]]) +def test_random_subset_apply_single_choice(p, img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) + actual = random_choice(img) + + p_horz, _ = p + if p_horz: + torch.testing.assert_close(actual, F.horizontal_flip(img)) + else: + torch.testing.assert_close(actual, F.vertical_flip(img)) + + +def test_random_subset_apply_random_order(img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) + # We can't really check whether the transforms are actually applied in random order. However, + # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform + # applies them in random order, we can use a fixed order to compute the expected value. 
+ actual = random_order(img) + expected = v2.Compose(flips)(img) + torch.testing.assert_close(actual, expected) + + +def test_random_subset_apply_valid_transforms(color_jitters, img): + transform = RandomSubsetApply(color_jitters) + output = transform(img) + assert output.shape == img.shape + + +def test_random_subset_apply_probability_length_mismatch(color_jitters): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, p=[0.5, 0.5]) + + +@pytest.mark.parametrize("n_subset", [0, 5]) +def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, n_subset=n_subset) + + +def test_sharpness_jitter_valid_range_tuple(img): + tf = SharpnessJitter((0.1, 2.0)) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_valid_range_float(img): + tf = SharpnessJitter(0.5) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_invalid_range_min_negative(): + with pytest.raises(ValueError): + SharpnessJitter((-0.1, 2.0)) + + +def test_sharpness_jitter_invalid_range_max_smaller(): + with pytest.raises(ValueError): + SharpnessJitter((2.0, 0.1)) diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 999540402..029c59ed1 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import pytest from lerobot.scripts.visualize_dataset import visualize_dataset @@ -31,3 +33,20 @@ def test_visualize_dataset(tmpdir, repo_id): output_dir=tmpdir, ) assert rrd_path.exists() + + +@pytest.mark.parametrize( + "repo_id", + ["lerobot/pusht"], +) +@pytest.mark.parametrize("root", [Path(__file__).parent / "data"]) +def test_visualize_local_dataset(tmpdir, repo_id, root): + rrd_path = visualize_dataset( + repo_id, + episode_index=0, + batch_size=32, + save=True, + output_dir=tmpdir, + root=root, + ) + assert rrd_path.exists()
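
For reference, the pieces above fit together as follows: `make_dataset` builds the transform stack from the `training.image_transforms` section of `lerobot/configs/default.yaml`, and `LeRobotDataset.__getitem__` applies it to every camera key. The sketch below reproduces that wiring directly; the dataset repo id and the parameter values are illustrative (taken from the defaults and tests above), not part of the patch itself.

```
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms

# Same kind of transform stack that make_dataset() builds when
# cfg.training.image_transforms.enable is true: each frame gets up to
# `max_num_transforms` transforms, sampled according to the weights.
image_transforms = get_image_transforms(
    brightness_weight=1.0,
    brightness_min_max=(0.8, 1.2),
    contrast_weight=1.0,
    contrast_min_max=(0.8, 1.2),
    sharpness_weight=1.0,
    sharpness_min_max=(0.8, 1.2),
    max_num_transforms=3,
    random_order=False,
)

# LeRobotDataset applies the transforms to each camera key in __getitem__.
dataset = LeRobotDataset("lerobot/pusht", image_transforms=image_transforms)
frame = dataset[0][dataset.camera_keys[0]]
```

The same stack is enabled end-to-end during training with the Hydra override `training.image_transforms.enable=true`, with per-transform weights and `min_max` ranges overridable as illustrated in the docstring of `lerobot/scripts/visualize_image_transforms.py`.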