Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into update_push_to_hub
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadene committed Jun 12, 2024
2 parents 3d79727 + ff8f6aa commit af84c00
Show file tree
Hide file tree
Showing 23 changed files with 997 additions and 78 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/trufflehog.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
on:
push:

name: Secret Leaks

permissions:
contents: read

jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
2 changes: 1 addition & 1 deletion examples/4_train_policy_with_script.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ defaults:
- policy: diffusion
```
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.

Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.

Expand Down
52 changes: 52 additions & 0 deletions examples/6_add_image_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
"""

from pathlib import Path

from torchvision.transforms import ToPILImage, v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset_repo_id = "lerobot/aloha_static_tape"

# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id)
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`

# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()

# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]


# Define the transformations
transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)

# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)

# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)

# Save the original frame
to_pil = ToPILImage()
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")

# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
24 changes: 22 additions & 2 deletions lerobot/common/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from omegaconf import ListConfig, OmegaConf

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms


def resolve_delta_timestamps(cfg):
Expand Down Expand Up @@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

resolve_delta_timestamps(cfg)

# TODO(rcadene): add data augmentations
image_transforms = None
if cfg.training.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)

if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)

if cfg.get("override_dataset_stats"):
Expand Down
22 changes: 12 additions & 10 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def __init__(
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.repo_id = repo_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
Expand Down Expand Up @@ -151,8 +151,9 @@ def __getitem__(self, idx):
self.tolerance_s,
)

if self.transform is not None:
item = self.transform(item)
if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.image_transforms(item[cam])

return item

Expand All @@ -168,7 +169,7 @@ def __repr__(self):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)

Expand Down Expand Up @@ -202,7 +203,7 @@ def from_preloaded(
obj.version = version
obj.root = root
obj.split = split
obj.transform = transform
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
Expand All @@ -225,7 +226,7 @@ def __init__(
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
Expand All @@ -239,7 +240,7 @@ def __init__(
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
Expand Down Expand Up @@ -274,7 +275,7 @@ def __init__(
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)

Expand Down Expand Up @@ -380,6 +381,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]

return item

def __repr__(self):
Expand All @@ -394,6 +396,6 @@ def __repr__(self):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
Loading

0 comments on commit af84c00

Please sign in to comment.