From ff8f6aa6cde2957f08547eb081aac12ca4669b6a Mon Sep 17 00:00:00 2001 From: Marina Barannikov Date: Tue, 11 Jun 2024 19:20:55 +0200 Subject: [PATCH] Add data augmentation in LeRobotDataset (#234) Co-authored-by: Simon Alibert Co-authored-by: Remi Cadene --- examples/6_add_image_transforms.py | 52 ++++ lerobot/common/datasets/factory.py | 24 +- lerobot/common/datasets/lerobot_dataset.py | 22 +- lerobot/common/datasets/transforms.py | 197 +++++++++++++ lerobot/configs/default.yaml | 34 +++ lerobot/scripts/visualize_image_transforms.py | 142 ++++++++++ .../default_transforms.safetensors | 3 + .../single_transforms.safetensors | 3 + .../save_image_transforms_to_safetensors.py | 86 ++++++ tests/test_image_transforms.py | 260 ++++++++++++++++++ 10 files changed, 811 insertions(+), 12 deletions(-) create mode 100644 examples/6_add_image_transforms.py create mode 100644 lerobot/common/datasets/transforms.py create mode 100644 lerobot/scripts/visualize_image_transforms.py create mode 100644 tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors create mode 100644 tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors create mode 100644 tests/scripts/save_image_transforms_to_safetensors.py create mode 100644 tests/test_image_transforms.py diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py new file mode 100644 index 000000000..bdcc6d7b9 --- /dev/null +++ b/examples/6_add_image_transforms.py @@ -0,0 +1,52 @@ +""" +This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data +augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and +transforms are applied to the observation images before they are returned in the dataset's __get_item__. 
+""" + +from pathlib import Path + +from torchvision.transforms import ToPILImage, v2 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +dataset_repo_id = "lerobot/aloha_static_tape" + +# Create a LeRobotDataset with no transformations +dataset = LeRobotDataset(dataset_repo_id) +# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)` + +# Get the index of the first observation in the first episode +first_idx = dataset.episode_data_index["from"][0].item() + +# Get the frame corresponding to the first camera +frame = dataset[first_idx][dataset.camera_keys[0]] + + +# Define the transformations +transforms = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 1.5)), + v2.ColorJitter(contrast=(0.5, 1.5)), + v2.RandomAdjustSharpness(sharpness_factor=2, p=1), + ] +) + +# Create another LeRobotDataset with the defined transformations +transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms) + +# Get a frame from the transformed dataset +transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] + +# Create a directory to store output images +output_dir = Path("outputs/image_transforms") +output_dir.mkdir(parents=True, exist_ok=True) + +# Save the original frame +to_pil = ToPILImage() +to_pil(frame).save(output_dir / "original_frame.png", quality=100) +print(f"Original frame saved to {output_dir / 'original_frame.png'}.") + +# Save the transformed frame +to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) +print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.") diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4732f5774..fab8ca575 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -19,6 +19,7 @@ from omegaconf import ListConfig, OmegaConf from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms def resolve_delta_timestamps(cfg): @@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData resolve_delta_timestamps(cfg) - # TODO(rcadene): add data augmentations + image_transforms = None + if cfg.training.image_transforms.enable: + image_transforms = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) if isinstance(cfg.dataset_repo_id, str): dataset = LeRobotDataset( cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) else: dataset = MultiLeRobotDataset( - cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps") + cfg.dataset_repo_id, + split=split, + delta_timestamps=cfg.training.get("delta_timestamps"), + image_transforms=image_transforms, ) if cfg.get("override_dataset_stats"): diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 58ae51b17..d680b9875 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ 
b/lerobot/common/datasets/lerobot_dataset.py @@ -46,7 +46,7 @@ def __init__( version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -54,7 +54,7 @@ def __init__( self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # load data from hub or locally when root is provided # TODO(rcadene, aliberts): implement faster transfer @@ -151,8 +151,9 @@ def __getitem__(self, idx): self.tolerance_s, ) - if self.transform is not None: - item = self.transform(item) + if self.image_transforms is not None: + for cam in self.camera_keys: + item[cam] = self.image_transforms(item[cam]) return item @@ -168,7 +169,7 @@ def __repr__(self): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) @@ -202,7 +203,7 @@ def from_preloaded( obj.version = version obj.root = root obj.split = split - obj.transform = transform + obj.image_transforms = transform obj.delta_timestamps = delta_timestamps obj.hf_dataset = hf_dataset obj.episode_data_index = episode_data_index @@ -225,7 +226,7 @@ def __init__( version: str | None = CODEBASE_VERSION, root: Path | None = DATA_DIR, split: str = "train", - transform: Callable | None = None, + image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() @@ -239,7 +240,7 @@ def __init__( root=root, split=split, delta_timestamps=delta_timestamps, - transform=transform, + image_transforms=image_transforms, ) for repo_id in repo_ids ] @@ -274,7 +275,7 @@ def __init__( self.version = version self.root = root self.split = split - self.transform = transform + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.stats = aggregate_stats(self._datasets) @@ -380,6 +381,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: for data_key in self.disabled_data_keys: if data_key in item: del item[data_key] + return item def __repr__(self): @@ -394,6 +396,6 @@ def __repr__(self): f" Recorded Frames per Second: {self.fps},\n" f" Camera Keys: {self.camera_keys},\n" f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" - f" Transformations: {self.transform},\n" + f" Transformations: {self.image_transforms},\n" f")" ) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py new file mode 100644 index 000000000..899f0d66c --- /dev/null +++ b/lerobot/common/datasets/transforms.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
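Editorial note on the `lerobot_dataset.py` change above: the renamed `image_transforms` argument is now applied per camera key inside `__getitem__`, so non-image keys (state, action, timestamps) are left untouched. A minimal sketch of that behaviour, assuming the `lerobot/aloha_static_tape` dataset used in the example script (not part of the patch):

```python
from torchvision.transforms import v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# The transform is applied only to the entries listed in dataset.camera_keys.
dataset = LeRobotDataset(
    "lerobot/aloha_static_tape",
    image_transforms=v2.ColorJitter(brightness=(0.5, 1.5)),
)
item = dataset[0]
for cam in dataset.camera_keys:
    # Without delta_timestamps, each camera frame stays a single (C, H, W) tensor.
    assert item[cam].ndim == 3
```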
+import collections +from typing import Any, Callable, Dict, Sequence + +import torch +from torchvision.transforms import v2 +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2 import functional as F # noqa: N812 + + +class RandomSubsetApply(Transform): + """Apply a random subset of N transformations from a list of transformations. + + Args: + transforms: list of transformations. + p: represents the multinomial probabilities (with no replacement) used for sampling the transform. + If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms + have the same probability. + n_subset: number of transformations to apply. If ``None``, all transforms are applied. + Must be in [1, len(transforms)]. + random_order: apply transformations in a random order. + """ + + def __init__( + self, + transforms: Sequence[Callable], + p: list[float] | None = None, + n_subset: int | None = None, + random_order: bool = False, + ) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError( + f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" + ) + + if n_subset is None: + n_subset = len(transforms) + elif not isinstance(n_subset, int): + raise TypeError("n_subset should be an int or None") + elif not (1 <= n_subset <= len(transforms)): + raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + + self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] + self.n_subset = n_subset + self.random_order = random_order + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + + selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) + if not self.random_order: + selected_indices = selected_indices.sort().values + + selected_transforms = [self.transforms[i] for i in selected_indices] + + for transform in selected_transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + + return outputs + + def extra_repr(self) -> str: + return ( + f"transforms={self.transforms}, " + f"p={self.p}, " + f"n_subset={self.n_subset}, " + f"random_order={self.random_order}" + ) + + +class SharpnessJitter(Transform): + """Randomly change the sharpness of an image or video. + + Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. + While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image, + SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of + augmentations as a result. + + A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness + by a factor of 2. + + If the input is a :class:`torch.Tensor`, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from + [max(0, 1 - sharpness), 1 + sharpness] or the given + [min, max]. Should be non negative numbers. 
+ """ + + def __init__(self, sharpness: float | Sequence[float]) -> None: + super().__init__() + self.sharpness = self._check_input(sharpness) + + def _check_input(self, sharpness): + if isinstance(sharpness, (int, float)): + if sharpness < 0: + raise ValueError("If sharpness is a single number, it must be non negative.") + sharpness = [1.0 - sharpness, 1.0 + sharpness] + sharpness[0] = max(sharpness[0], 0.0) + elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: + sharpness = [float(v) for v in sharpness] + else: + raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") + + if not 0.0 <= sharpness[0] <= sharpness[1]: + raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + + return float(sharpness[0]), float(sharpness[1]) + + def _generate_value(self, left: float, right: float) -> float: + return torch.empty(1).uniform_(left, right).item() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) + + +def get_image_transforms( + brightness_weight: float = 1.0, + brightness_min_max: tuple[float, float] | None = None, + contrast_weight: float = 1.0, + contrast_min_max: tuple[float, float] | None = None, + saturation_weight: float = 1.0, + saturation_min_max: tuple[float, float] | None = None, + hue_weight: float = 1.0, + hue_min_max: tuple[float, float] | None = None, + sharpness_weight: float = 1.0, + sharpness_min_max: tuple[float, float] | None = None, + max_num_transforms: int | None = None, + random_order: bool = False, +): + def check_value(name, weight, min_max): + if min_max is not None: + if len(min_max) != 2: + raise ValueError( + f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided." + ) + if weight < 0.0: + raise ValueError( + f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})." + ) + + check_value("brightness", brightness_weight, brightness_min_max) + check_value("contrast", contrast_weight, contrast_min_max) + check_value("saturation", saturation_weight, saturation_min_max) + check_value("hue", hue_weight, hue_min_max) + check_value("sharpness", sharpness_weight, sharpness_min_max) + + weights = [] + transforms = [] + if brightness_min_max is not None and brightness_weight > 0.0: + weights.append(brightness_weight) + transforms.append(v2.ColorJitter(brightness=brightness_min_max)) + if contrast_min_max is not None and contrast_weight > 0.0: + weights.append(contrast_weight) + transforms.append(v2.ColorJitter(contrast=contrast_min_max)) + if saturation_min_max is not None and saturation_weight > 0.0: + weights.append(saturation_weight) + transforms.append(v2.ColorJitter(saturation=saturation_min_max)) + if hue_min_max is not None and hue_weight > 0.0: + weights.append(hue_weight) + transforms.append(v2.ColorJitter(hue=hue_min_max)) + if sharpness_min_max is not None and sharpness_weight > 0.0: + weights.append(sharpness_weight) + transforms.append(SharpnessJitter(sharpness=sharpness_min_max)) + + n_subset = len(transforms) + if max_num_transforms is not None: + n_subset = min(n_subset, max_num_transforms) + + if n_subset == 0: + return v2.Identity() + else: + # TODO(rcadene, aliberts): add v2.ToDtype float16? 
+ return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 85b9ceea0..6101df898 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -43,6 +43,40 @@ training: save_checkpoint: true num_workers: 4 batch_size: ??? + image_transforms: + # These transforms are all using standard torchvision.transforms.v2 + # You can find out how these transformations affect images here: + # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html + # We use a custom RandomSubsetApply container to sample them. + # For each transform, the following parameters are available: + # weight: This represents the multinomial probability (with no replacement) + # used for sampling the transform. If the sum of the weights is not 1, + # they will be normalized. + # min_max: Lower & upper bound respectively used for sampling the transform's parameter + # (following uniform distribution) when it's applied. + # Set this flag to `true` to enable transforms during training + enable: false + # This is the maximum number of transforms (sampled from these below) that will be applied to each frame. + # It's an integer in the interval [1, number of available transforms]. + max_num_transforms: 3 + # By default, transforms are applied in Torchvision's suggested order (shown below). + # Set this to True to apply them in a random order. + random_order: false + brightness: + weight: 1 + min_max: [0.8, 1.2] + contrast: + weight: 1 + min_max: [0.8, 1.2] + saturation: + weight: 1 + min_max: [0.5, 1.5] + hue: + weight: 1 + min_max: [-0.05, 0.05] + sharpness: + weight: 1 + min_max: [0.8, 1.2] eval: n_episodes: 1 diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py new file mode 100644 index 000000000..fa3c0ab2a --- /dev/null +++ b/lerobot/scripts/visualize_image_transforms.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Visualize effects of image transforms for a given configuration. + +This script will generate examples of transformed images as they are output by LeRobot dataset. 
+Additionally, each individual transform can be visualized separately as well as examples of combined transforms + + +--- Usage Examples --- + +Increase hue jitter +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.hue.min_max=[-0.25,0.25] +``` + +Increase brightness & brightness weight +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.brightness.weight=10.0 \ + training.image_transforms.brightness.min_max=[1.0,2.0] +``` + +Blur images and disable saturation & hue +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.sharpness.weight=10.0 \ + training.image_transforms.sharpness.min_max=[0.0,1.0] \ + training.image_transforms.saturation.weight=0.0 \ + training.image_transforms.hue.weight=0.0 +``` + +Use all transforms with random order +``` +python lerobot/scripts/visualize_image_transforms.py \ + dataset_repo_id=lerobot/aloha_mobile_shrimp \ + training.image_transforms.max_num_transforms=5 \ + training.image_transforms.random_order=true +``` + +""" + +from pathlib import Path + +import hydra +from torchvision.transforms import ToPILImage + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms + +OUTPUT_DIR = Path("outputs/image_transforms") +N_EXAMPLES = 5 +to_pil = ToPILImage() + + +def save_config_all_transforms(cfg, original_frame, output_dir): + tf = get_image_transforms( + brightness_weight=cfg.brightness.weight, + brightness_min_max=cfg.brightness.min_max, + contrast_weight=cfg.contrast.weight, + contrast_min_max=cfg.contrast.min_max, + saturation_weight=cfg.saturation.weight, + saturation_min_max=cfg.saturation.min_max, + hue_weight=cfg.hue.weight, + hue_min_max=cfg.hue.min_max, + sharpness_weight=cfg.sharpness.weight, + sharpness_min_max=cfg.sharpness.min_max, + max_num_transforms=cfg.max_num_transforms, + random_order=cfg.random_order, + ) + + output_dir_all = output_dir / "all" + output_dir_all.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100) + + print("Combined transforms examples saved to:") + print(f" {output_dir_all}") + + +def save_config_single_transforms(cfg, original_frame, output_dir): + transforms = [ + "brightness", + "contrast", + "saturation", + "hue", + "sharpness", + ] + print("Individual transforms examples saved to:") + for transform in transforms: + kwargs = { + f"{transform}_weight": cfg[f"{transform}"].weight, + f"{transform}_min_max": cfg[f"{transform}"].min_max, + } + tf = get_image_transforms(**kwargs) + output_dir_single = output_dir / f"{transform}" + output_dir_single.mkdir(parents=True, exist_ok=True) + + for i in range(1, N_EXAMPLES + 1): + transformed_frame = tf(original_frame) + to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100) + + print(f" {output_dir_single}") + + +@hydra.main(version_base="1.2", config_name="default", config_path="../configs") +def visualize_transforms(cfg): + dataset = LeRobotDataset(cfg.dataset_repo_id) + + output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1] + output_dir.mkdir(parents=True, exist_ok=True) + + # Get 1st frame from 1st camera of 1st episode + original_frame = dataset[0][dataset.camera_keys[0]] + 
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100) + print("\nOriginal frame saved to:") + print(f" {output_dir / 'original_frame.png'}.") + + save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir) + save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir) + + +if __name__ == "__main__": + visualize_transforms() diff --git a/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors new file mode 100644 index 000000000..77699dab1 --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/default_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a +size 3686488 diff --git a/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors new file mode 100644 index 000000000..13f1033fa --- /dev/null +++ b/tests/data/save_image_transforms_to_safetensors/single_transforms.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f +size 40551392 diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py new file mode 100644 index 000000000..9d024a013 --- /dev/null +++ b/tests/scripts/save_image_transforms_to_safetensors.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
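Editorial note on the artifact round-trip used by the test assets above: the generation script below writes reference tensors once with safetensors, and the test suite reloads them for backward-compatibility checks. A minimal sketch of that save/load cycle, with an illustrative path and tensor rather than the actual artifacts:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

artifact_dir = Path("tests/data/save_image_transforms_to_safetensors")
artifact_dir.mkdir(parents=True, exist_ok=True)

# Illustrative tensor; the real script stores transformed camera frames.
frames = {"original_frame": torch.rand(3, 480, 640)}
save_file(frames, artifact_dir / "example.safetensors")

reloaded = load_file(artifact_dir / "example.safetensors")
torch.testing.assert_close(reloaded["original_frame"], frames["original_frame"])
```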
+from pathlib import Path + +import torch +from safetensors.torch import save_file + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID +from tests.utils import DEFAULT_CONFIG_PATH + + +def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + img_tf = default_tf(original_frame) + + save_file({"default": img_tf}, output_dir / "default_transforms.safetensors") + + +def save_single_transforms(original_frame: torch.Tensor, output_dir: Path): + transforms = { + "brightness": [(0.5, 0.5), (2.0, 2.0)], + "contrast": [(0.5, 0.5), (2.0, 2.0)], + "saturation": [(0.5, 0.5), (2.0, 2.0)], + "hue": [(-0.25, -0.25), (0.25, 0.25)], + "sharpness": [(0.5, 0.5), (2.0, 2.0)], + } + + frames = {"original_frame": original_frame} + for transform, values in transforms.items(): + for min_max in values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + frames[key] = tf(original_frame) + + save_file(frames, output_dir / "single_transforms.safetensors") + + +def main(): + dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None) + output_dir = Path(ARTIFACT_DIR) + output_dir.mkdir(parents=True, exist_ok=True) + original_frame = dataset[0][dataset.camera_keys[0]] + + save_single_transforms(original_frame, output_dir) + save_default_config_transform(original_frame, output_dir) + + +if __name__ == "__main__": + main() diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py new file mode 100644 index 000000000..ba6d972f3 --- /dev/null +++ b/tests/test_image_transforms.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
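Editorial note on two conventions the tests below rely on: a degenerate `(min, max)` range with `min == max` makes a transform deterministic, and each transform can be enabled in isolation by passing only its `*_weight`/`*_min_max` pair to `get_image_transforms`. A minimal sketch of that keyword pattern (values are illustrative):

```python
import torch

from lerobot.common.datasets.transforms import get_image_transforms

img = torch.rand(3, 480, 640)

transform_name = "contrast"  # any of: brightness, contrast, saturation, hue, sharpness
kwargs = {
    f"{transform_name}_weight": 1.0,
    f"{transform_name}_min_max": (0.5, 0.5),  # min == max -> deterministic factor
}
tf = get_image_transforms(**kwargs)  # here equivalent to v2.ColorJitter(contrast=(0.5, 0.5))
out = tf(img)
assert out.shape == img.shape
```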
+from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from safetensors.torch import load_file +from torchvision.transforms import v2 +from torchvision.transforms.v2 import functional as F # noqa: N812 + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms +from lerobot.common.utils.utils import init_hydra_config, seeded_context +from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel + +ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") +DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" + + +def load_png_to_tensor(path: Path): + return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1) + + +@pytest.fixture +def img(): + dataset = LeRobotDataset(DATASET_REPO_ID) + return dataset[0][dataset.camera_keys[0]] + + +@pytest.fixture +def img_random(): + return torch.rand(3, 480, 640) + + +@pytest.fixture +def color_jitters(): + return [ + v2.ColorJitter(brightness=0.5), + v2.ColorJitter(contrast=0.5), + v2.ColorJitter(saturation=0.5), + ] + + +@pytest.fixture +def single_transforms(): + return load_file(ARTIFACT_DIR / "single_transforms.safetensors") + + +@pytest.fixture +def default_transforms(): + return load_file(ARTIFACT_DIR / "default_transforms.safetensors") + + +def test_get_image_transforms_no_transform(img): + tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0) + torch.testing.assert_close(tf_actual(img), img) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_brightness(img, min_max): + tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max) + tf_expected = v2.ColorJitter(brightness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_contrast(img, min_max): + tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max) + tf_expected = v2.ColorJitter(contrast=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_saturation(img, min_max): + tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max) + tf_expected = v2.ColorJitter(saturation=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)]) +def test_get_image_transforms_hue(img, min_max): + tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max) + tf_expected = v2.ColorJitter(hue=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)]) +def test_get_image_transforms_sharpness(img, min_max): + tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max) + tf_expected = SharpnessJitter(sharpness=min_max) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +def test_get_image_transforms_max_num_transforms(img): + tf_actual = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=False, + ) + tf_expected = v2.Compose( + [ + v2.ColorJitter(brightness=(0.5, 0.5)), + 
v2.ColorJitter(contrast=(0.5, 0.5)), + v2.ColorJitter(saturation=(0.5, 0.5)), + v2.ColorJitter(hue=(0.5, 0.5)), + SharpnessJitter(sharpness=(0.5, 0.5)), + ] + ) + torch.testing.assert_close(tf_actual(img), tf_expected(img)) + + +@require_x86_64_kernel +def test_get_image_transforms_random_order(img): + out_imgs = [] + tf = get_image_transforms( + brightness_min_max=(0.5, 0.5), + contrast_min_max=(0.5, 0.5), + saturation_min_max=(0.5, 0.5), + hue_min_max=(0.5, 0.5), + sharpness_min_max=(0.5, 0.5), + random_order=True, + ) + with seeded_context(1337): + for _ in range(10): + out_imgs.append(tf(img)) + + for i in range(1, len(out_imgs)): + with pytest.raises(AssertionError): + torch.testing.assert_close(out_imgs[0], out_imgs[i]) + + +@pytest.mark.parametrize( + "transform, min_max_values", + [ + ("brightness", [(0.5, 0.5), (2.0, 2.0)]), + ("contrast", [(0.5, 0.5), (2.0, 2.0)]), + ("saturation", [(0.5, 0.5), (2.0, 2.0)]), + ("hue", [(-0.25, -0.25), (0.25, 0.25)]), + ("sharpness", [(0.5, 0.5), (2.0, 2.0)]), + ], +) +def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms): + for min_max in min_max_values: + kwargs = { + f"{transform}_weight": 1.0, + f"{transform}_min_max": min_max, + } + tf = get_image_transforms(**kwargs) + actual = tf(img) + key = f"{transform}_{min_max[0]}_{min_max[1]}" + expected = single_transforms[key] + torch.testing.assert_close(actual, expected) + + +@require_x86_64_kernel +def test_backward_compatibility_default_config(img, default_transforms): + cfg = init_hydra_config(DEFAULT_CONFIG_PATH) + cfg_tf = cfg.training.image_transforms + default_tf = get_image_transforms( + brightness_weight=cfg_tf.brightness.weight, + brightness_min_max=cfg_tf.brightness.min_max, + contrast_weight=cfg_tf.contrast.weight, + contrast_min_max=cfg_tf.contrast.min_max, + saturation_weight=cfg_tf.saturation.weight, + saturation_min_max=cfg_tf.saturation.min_max, + hue_weight=cfg_tf.hue.weight, + hue_min_max=cfg_tf.hue.min_max, + sharpness_weight=cfg_tf.sharpness.weight, + sharpness_min_max=cfg_tf.sharpness.min_max, + max_num_transforms=cfg_tf.max_num_transforms, + random_order=cfg_tf.random_order, + ) + + with seeded_context(1337): + actual = default_tf(img) + + expected = default_transforms["default"] + + torch.testing.assert_close(actual, expected) + + +@pytest.mark.parametrize("p", [[0, 1], [1, 0]]) +def test_random_subset_apply_single_choice(p, img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False) + actual = random_choice(img) + + p_horz, _ = p + if p_horz: + torch.testing.assert_close(actual, F.horizontal_flip(img)) + else: + torch.testing.assert_close(actual, F.vertical_flip(img)) + + +def test_random_subset_apply_random_order(img): + flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)] + random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True) + # We can't really check whether the transforms are actually applied in random order. However, + # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform + # applies them in random order, we can use a fixed order to compute the expected value. 
+ actual = random_order(img) + expected = v2.Compose(flips)(img) + torch.testing.assert_close(actual, expected) + + +def test_random_subset_apply_valid_transforms(color_jitters, img): + transform = RandomSubsetApply(color_jitters) + output = transform(img) + assert output.shape == img.shape + + +def test_random_subset_apply_probability_length_mismatch(color_jitters): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, p=[0.5, 0.5]) + + +@pytest.mark.parametrize("n_subset", [0, 5]) +def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset): + with pytest.raises(ValueError): + RandomSubsetApply(color_jitters, n_subset=n_subset) + + +def test_sharpness_jitter_valid_range_tuple(img): + tf = SharpnessJitter((0.1, 2.0)) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_valid_range_float(img): + tf = SharpnessJitter(0.5) + output = tf(img) + assert output.shape == img.shape + + +def test_sharpness_jitter_invalid_range_min_negative(): + with pytest.raises(ValueError): + SharpnessJitter((-0.1, 2.0)) + + +def test_sharpness_jitter_invalid_range_max_smaller(): + with pytest.raises(ValueError): + SharpnessJitter((2.0, 0.1))
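Closing usage note, a minimal sketch of the container exercised by the tests above: `RandomSubsetApply` samples `n_subset` transforms without replacement according to the (normalized) weights `p`, then applies them in the listed order unless `random_order=True`. The transform list and weights here are illustrative, not taken from the default config:

```python
import torch
from torchvision.transforms import v2

from lerobot.common.datasets.transforms import RandomSubsetApply

img = torch.rand(3, 480, 640)
tf = RandomSubsetApply(
    transforms=[
        v2.ColorJitter(brightness=0.5),
        v2.ColorJitter(contrast=0.5),
        v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
    ],
    p=[2.0, 1.0, 1.0],  # normalized internally to [0.5, 0.25, 0.25]
    n_subset=2,
    random_order=False,
)
out = tf(img)
assert out.shape == img.shape
```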