-
Notifications
You must be signed in to change notification settings - Fork 641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add data augmentation in LeRobotDataset #234
Add data augmentation in LeRobotDataset #234
Conversation
@marinabar Wonderful PR :) |
@marinabar Ideally we should show an image with the biggest augmentation for each transform (to better understand what each transform is doing). Also, we should display the frames in the original scale. What you displayed is a bit too small and compressed to get a good feeling. Ideally we could add a script in Finally, we could add backward compatibility tests where we would save the result of specific frames, augmented with each transform. This is to ensure that if torchvision change something, we are aware of it. Here are two pointers for inspirations: Data augmentation is really something we should be careful about since it can fail silently. What do you think? |
…_augmentation' into 2024_05_30_add_data_augmentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API and tests look great! Thanks Marina and Simon :)
One round of minor change and should be ready to merge.
Could you please ping @alexander-soare when it's done? THanks!
class SharpnessJitter(Transform): | ||
"""Randomly change the sharpness of an image or video. | ||
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. | ||
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness | ||
by a factor of 2. | ||
|
||
If the input is a :class:`torch.Tensor`, | ||
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. | ||
|
||
Args: | ||
sharpness (float or tuple of float (min, max)): How much to jitter sharpness. | ||
sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness] | ||
or the given [min, max]. Should be non negative numbers. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment to explain why we dont use RandomAdjustSharpness
and copy past your change in this thread? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 04adbd7
"""Randomly change the sharpness of an image or video.
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
+ While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
+ SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
+ augmentations as a result.
+
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
by a factor of 2.
lerobot/configs/default.yaml
Outdated
@@ -57,3 +57,29 @@ wandb: | |||
disable_artifact: false | |||
project: lerobot | |||
notes: "" | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's have them here ;)
from pathlib import Path | ||
|
||
import hydra | ||
from torchvision.transforms import ToPILImage | ||
|
||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset | ||
from lerobot.common.datasets.transforms import make_image_transforms | ||
|
||
to_pil = ToPILImage() | ||
|
||
|
||
def main(cfg, output_dir=Path("outputs/image_transforms")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you provide a docstring in header of the script + example command?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 1fbb0d9
img = transform(frame) | ||
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you print the output directory?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 540fb9c
@@ -0,0 +1,69 @@ | |||
from pathlib import Path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we rename save_image_transforms_to_safetensors.py
for consistency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 1890637
kwargs = { | ||
f"{transform}_weight": 1.0, | ||
f"{transform}_min_max": (0.5, 0.5), | ||
} | ||
tf = get_image_transforms(**kwargs) | ||
frames[transform] = tf(original_frame) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we save the two extreme values + mean value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done 9a3739d
tests/test_transforms.py
Outdated
@@ -0,0 +1,245 @@ | |||
from pathlib import Path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you rename to test_image_transforms.py
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed 1890637
lerobot/configs/default.yaml
Outdated
@@ -57,3 +57,29 @@ wandb: | |||
disable_artifact: false | |||
project: lerobot | |||
notes: "" | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we at least nest them under the training
key but in this file? To make clear that they are not relevant to eval (for instance).
What this does
Implements data augmentation for images a LeRobot dataset object.
RandomSubsetApply
transform to apply a random subset of N transformations from a list of transformations.SharpnessJitter
transform to randomly change the sharpness of an image or video.6_add_image_transforms.py
showing usage of transforms parameter enabled with LeRobot Dataset.visualize_image_transforms.py
script to produce examples of transformed images that would be produced with a given config.Default transforms are :
The parameters are taken from the
default.yaml
configuration and the transforms are defined intransforms.py
. They are then applied in__getitem__()
method of LeRobotDataset.The transformation is applied on the images of all of the given cameras.
(WIP : support for multi image observations in delta_timestamps)
How you can verify it
To test various types of transforms, you can use the newly added script
This applies the transforms from the configuration file and then saves multiple images corresponding to each applied transform.
You can also run the example script:
How it was tested
We trained a model on the data collected on a grasping task from Reachy2, while incorporating data augmentation. On evaluation, the policy appeared more robust to lighting changes. (Evaluation ran in the dark with multiple sources of light)
Examples of single transformations:
transformation
min_max
None
brightness
(0.5, 0.5)
brightness
(2.0, 2.0)
contrast
(0.5, 0.5)
contrast
(2.0, 2.0)
saturation
(0.5, 0.5)
saturation
(2.0, 2.0)
hue
(-0.25, 0.25)
hue
(0.25, 0.25)
sharpness
(0.5, 0.5)
sharpness
(2.0, 2.0)
This change is