Add data augmentation in LeRobotDataset #234

Merged
merged 63 commits into from
Jun 11, 2024

@marinabar marinabar commented May 31, 2024

What this does

Implements data augmentation for the images of a LeRobotDataset object.

  • Adds a custom RandomSubsetApply transform to apply a random subset of N transformations from a list of transformations.
  • Adds a custom SharpnessJitter transform to randomly change the sharpness of an image or video.
  • Adds an example 6_add_image_transforms.py showing usage of transforms parameter enabled with LeRobot Dataset.
  • Adds a visualize_image_transforms.py script to produce examples of transformed images that would be produced with a given config.
  • Adds tests and artifacts.
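
As a rough illustration of the "random subset of N transformations" idea, here is a minimal stdlib-only sketch. It is not the code from this PR: the class name `RandomSubsetApplySketch`, the `n_subset` and `seed` parameters, and the choice to apply transforms in the sampled order are assumptions made for the example, and it ignores any per-transform weighting.

```python
import random
from typing import Any, Callable, Optional, Sequence


class RandomSubsetApplySketch:
    """Apply n_subset transforms drawn at random (without replacement) from a list."""

    def __init__(
        self,
        transforms: Sequence[Callable[[Any], Any]],
        n_subset: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        if n_subset is None:
            n_subset = len(transforms)  # default: apply every transform
        if not 0 <= n_subset <= len(transforms):
            raise ValueError("n_subset must be between 0 and len(transforms)")
        self.transforms = list(transforms)
        self.n_subset = n_subset
        self.rng = random.Random(seed)  # hypothetical seeding hook for reproducibility

    def __call__(self, x: Any) -> Any:
        # Draw distinct indices, then apply the chosen transforms one after another.
        chosen = self.rng.sample(range(len(self.transforms)), self.n_subset)
        for index in chosen:
            x = self.transforms[index](x)
        return x
```

Any callable works as a transform here, so the same pattern would carry over to image transforms that take and return a tensor.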

Default transforms are:

  • contrast
  • brightness
  • hue
  • saturation
  • sharpness

The parameters are taken from the default.yaml configuration and the transforms are defined in transforms.py. They are then applied in the __getitem__() method of LeRobotDataset.
The transformation is applied to the images of all of the given cameras.
(WIP: support for multi-image observations in delta_timestamps)
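
To make the "applied in `__getitem__()` to every camera" behaviour concrete, here is a small sketch with a toy dataset. Everything in it is hypothetical (`ToyDataset`, `camera_keys`, `image_transforms`, and the `brighten` helper); images are plain nested lists rather than tensors, and the real LeRobotDataset may organize this differently.

```python
from typing import Any, Callable, Dict, List, Optional

Image = List[List[float]]  # toy stand-in for an image tensor, values in [0, 1]


def brighten(image: Image, factor: float) -> Image:
    """Scale every pixel by factor and clamp to 1.0."""
    return [[min(1.0, px * factor) for px in row] for row in image]


class ToyDataset:
    """Applies one image transform to every camera key when an item is fetched."""

    def __init__(
        self,
        items: List[Dict[str, Any]],
        camera_keys: List[str],
        image_transforms: Optional[Callable[[Image], Image]] = None,
    ) -> None:
        self.items = items
        self.camera_keys = camera_keys
        self.image_transforms = image_transforms

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = dict(self.items[idx])  # shallow copy: stored data stays untouched
        if self.image_transforms is not None:
            for key in self.camera_keys:
                item[key] = self.image_transforms(item[key])
        return item
```

Non-image entries such as robot state are passed through unchanged, which is the property worth checking when wiring augmentation into a dataset.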

How you can verify it

To test various types of transforms, you can use the newly added script

python lerobot/scripts/visualize_image_transforms.py

This applies the transforms from the configuration file and then saves multiple images corresponding to each applied transform.

You can also run the example script:

python examples/6_add_image_transforms.py

How it was tested

We trained a model on data collected for a grasping task with Reachy2, with data augmentation enabled. During evaluation, the policy appeared more robust to lighting changes. (The evaluation was run in the dark with multiple light sources.)

Examples of single transformations:

| transformation | min_max | image (alt text) |
| --- | --- | --- |
| original frame | None | original_frame |
| brightness | (0.5, 0.5) | brightness_0 5_0 5 |
| brightness | (2.0, 2.0) | brightness_2 0_2 0 |
| contrast | (0.5, 0.5) | contrast_0 5_0 5 |
| contrast | (2.0, 2.0) | contrast_2 0_2 0 |
| saturation | (0.5, 0.5) | saturation_0 5_0 5 |
| saturation | (2.0, 2.0) | saturation_2 0_2 0 |
| hue | (-0.25, -0.25) | hue_-0 25_-0 25 |
| hue | (0.25, 0.25) | hue_0 25_0 25 |
| sharpness | (0.5, 0.5) | sharpness_0 5_0 5 |
| sharpness | (2.0, 2.0) | sharpness_2 0_2 0 |

@aliberts aliberts added the ✨ Enhancement (New feature or request) and 🗃️ Dataset (Something dataset-related) labels Jun 1, 2024
@aliberts aliberts changed the title from "Implemented data augmentation with LeRobot class" to "Add data augmentation in LeRobotDataset" Jun 1, 2024
Cadene commented Jun 1, 2024

@marinabar Wonderful PR :)
Could you illustrate the augmentations by adding some screenshots to the PR description? Thanks!

@marinabar Ideally we should show an image with the strongest augmentation for each transform (to better understand what each transform is doing).
Then we should show the worst case, where the strongest augmentation of every transform is applied sequentially to an image (this is because your implementation can apply transforms sequentially).

Also, we should display the frames at their original scale. What you displayed is a bit too small and compressed to get a good feel for it.

Ideally we could add a script in lerobot/scripts/show_image_transforms.py that saves these images in outputs/show_image_transforms. That's something we would need each time we add new transforms, so super useful!!!

Finally, we could add backward compatibility tests where we would save the result of specific frames, augmented with each transform. This is to ensure that if torchvision changes something, we are aware of it. Here are two pointers for inspiration:

Data augmentation is really something we should be careful about since it can fail silently.
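
A backward compatibility check of the kind described above could look roughly like the following sketch: save the augmented output of a fixed frame once, then compare later runs against it. The helper names are hypothetical, JSON is used only as a stdlib stand-in for whatever artifact format the real tests use, and `adjust_brightness` is a toy transform rather than torchvision's.

```python
import json
import math
from pathlib import Path
from typing import List

Frame = List[List[float]]  # toy stand-in for an image tensor


def adjust_brightness(frame: Frame, factor: float) -> Frame:
    """Toy transform: scale pixels by factor and clamp to 1.0."""
    return [[min(1.0, px * factor) for px in row] for row in frame]


def save_reference(path: Path, frame: Frame) -> None:
    """Store the expected output once, e.g. when the transform is first added."""
    path.write_text(json.dumps(frame))


def matches_reference(path: Path, frame: Frame, tol: float = 1e-6) -> bool:
    """Return True if frame equals the stored reference within tol."""
    reference = json.loads(path.read_text())
    if len(reference) != len(frame):
        return False
    if any(len(ref_row) != len(row) for ref_row, row in zip(reference, frame)):
        return False
    return all(
        math.isclose(a, b, abs_tol=tol)
        for ref_row, row in zip(reference, frame)
        for a, b in zip(ref_row, row)
    )
```

A test would call `matches_reference` on a freshly augmented frame and fail loudly if an upstream change altered the output, which addresses the silent-failure concern.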

What do you think?

cc @aliberts @alexander-soare

@aliberts aliberts marked this pull request as ready for review June 10, 2024 14:24
@Cadene Cadene left a comment

API and tests look great! Thanks Marina and Simon :)
One round of minor changes and it should be ready to merge.
Could you please ping @alexander-soare when it's done? Thanks!

Comment on lines 78 to 91
class SharpnessJitter(Transform):
    """Randomly change the sharpness of an image or video.
    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
    by a factor of 2.

    If the input is a :class:`torch.Tensor`,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
            sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
            or the given [min, max]. Should be non negative numbers.
    """

Could you add a comment to explain why we don't use RandomAdjustSharpness, and copy-paste your change in this thread? Thanks!

Fixed 04adbd7

"""Randomly change the sharpness of an image or video.
    Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
+    While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
+    SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
+    augmentations as a result.
+    
    A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
    by a factor of 2.
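
The difference described in the docstring can be sketched with the factor sampling alone, leaving out the actual image operation. This is an illustrative stdlib-only sketch with hypothetical function names, not the PR's implementation; the range rule follows the `Args` description quoted earlier in this thread.

```python
import random
from typing import Tuple, Union


def resolve_range(sharpness: Union[float, Tuple[float, float]]) -> Tuple[float, float]:
    """Turn a float or a (min, max) pair into the (low, high) sampling range."""
    if isinstance(sharpness, (int, float)):
        if sharpness < 0:
            raise ValueError("sharpness must be non-negative")
        return (max(0.0, 1.0 - sharpness), 1.0 + sharpness)
    low, high = sharpness
    if not 0.0 <= low <= high:
        raise ValueError("expected 0 <= min <= max")
    return (float(low), float(high))


def sample_jitter_factor(
    sharpness: Union[float, Tuple[float, float]], rng: random.Random
) -> float:
    """SharpnessJitter-style: a new factor is drawn uniformly on every call."""
    low, high = resolve_range(sharpness)
    return rng.uniform(low, high)


def sample_fixed_factor(sharpness_factor: float, p: float, rng: random.Random) -> float:
    """RandomAdjustSharpness-style: a fixed factor with probability p, else 1.0 (no change)."""
    return sharpness_factor if rng.random() < p else 1.0
```

With the fixed variant only two outcomes are possible per image (the chosen factor or no change), whereas the jitter variant covers the whole interval, which is the added diversity the docstring refers to.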

@@ -57,3 +57,29 @@ wandb:
   disable_artifact: false
   project: lerobot
   notes: ""

Let's have them here ;)

Comment on lines 1 to 12
from pathlib import Path

import hydra
from torchvision.transforms import ToPILImage

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_image_transforms

to_pil = ToPILImage()


def main(cfg, output_dir=Path("outputs/image_transforms")):

Could you provide a docstring in the header of the script, plus an example command?

Fixed 1fbb0d9

Comment on lines 34 to 36
img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)

Could you print the output directory?

Fixed 540fb9c

@@ -0,0 +1,69 @@
from pathlib import Path

Could we rename save_image_transforms_to_safetensors.py for consistency?

Fixed 1890637

Comment on lines 48 to 53
kwargs = {
    f"{transform}_weight": 1.0,
    f"{transform}_min_max": (0.5, 0.5),
}
tf = get_image_transforms(**kwargs)
frames[transform] = tf(original_frame)

Should we save the two extreme values + mean value?

Done 9a3739d
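
One way to read the "two extreme values + mean value" suggestion is to build three keyword-argument sets per transform, each pinning the range to a single value. The sketch below only constructs those dictionaries, reusing the `{transform}_weight` and `{transform}_min_max` key pattern from the snippet under review; the helper name is hypothetical, and whether the referenced commit does exactly this is not shown in the thread.

```python
from typing import Dict, Tuple


def extreme_and_mean_kwargs(
    transform: str, min_max: Tuple[float, float]
) -> Dict[str, Dict[str, object]]:
    """Build kwargs that pin a transform to its min, mean, and max value."""
    low, high = min_max
    mean = (low + high) / 2
    return {
        label: {
            f"{transform}_weight": 1.0,
            # min == max makes the sampled value deterministic
            f"{transform}_min_max": (value, value),
        }
        for label, value in (("min", low), ("mean", mean), ("max", high))
    }
```

Each of the three dictionaries could then be passed as `**kwargs` to the transform factory to produce one reference frame per setting.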

@@ -0,0 +1,245 @@
from pathlib import Path

Could you rename to test_image_transforms.py?

Fixed 1890637

@@ -57,3 +57,29 @@ wandb:
   disable_artifact: false
   project: lerobot
   notes: ""

Can we at least nest them under the training key but in this file? To make clear that they are not relevant to eval (for instance).

@aliberts aliberts merged commit ff8f6aa into huggingface:main Jun 11, 2024
5 checks passed
@marinabar marinabar deleted the 2024_05_30_add_data_augmentation branch June 14, 2024 20:54
Labels
🗃️ Dataset Something dataset-related ✨ Enhancement New feature or request
Projects
Status: Done
4 participants