Commit 461b673

Merge branch 'main' into user/aliberts/2024_05_28_compile_torchvision

aliberts authored Jun 18, 2024
2 parents 550577e + b72d574
Showing 4 changed files with 143 additions and 59 deletions.
38 changes: 22 additions & 16 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -28,7 +28,9 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image" is required as an input.
- At least one key starting with "observation.image" is required as an input.
- If there are multiple keys beginning with "observation.image", they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
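For illustration only (not part of this commit), a hypothetical input_shapes with two same-shape camera views could look like the following sketch; the camera key names are assumptions:

    # Hypothetical multi-camera configuration; key names are illustrative.
    input_shapes = {
        "observation.image.cam_top": [3, 96, 96],    # (C, H, W)
        "observation.image.cam_wrist": [3, 96, 96],  # must match cam_top's shape
        "observation.state": [2],
    }
    output_shapes = {"action": [2]}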
@@ -153,22 +155,26 @@ def __post_init__(self):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.crop_shape is not None and (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
if self.crop_shape is not None:
for image_key in image_keys:
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
# Check that all input images have the same shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
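A minimal sketch of how the new validation behaves, assuming the DiffusionConfig fields shown above (all values illustrative):

    # Passes: crop_shape (84, 84) fits inside every (96, 96) image, and all
    # image shapes match.
    ok = {
        "observation.image.cam0": [3, 96, 96],
        "observation.image.cam1": [3, 96, 96],
        "observation.state": [2],
    }

    # Fails the crop check: crop_shape (84, 84) exceeds the (64, 64) image.
    bad_crop = {"observation.image.cam0": [3, 64, 64], "observation.state": [2]}

    # Fails the shape-match check: cam1's shape differs from cam0's.
    mismatched = {
        "observation.image.cam0": [3, 96, 96],
        "observation.image.cam1": [3, 84, 84],
        "observation.state": [2],
    }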
69 changes: 36 additions & 33 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -18,7 +18,6 @@
TODO(alexander-soare):
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Make compatible with multiple image keys.
"""

import math
@@ -83,20 +82,14 @@ def __init__(

self.diffusion = DiffusionModel(config)

image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
if len(image_keys) != 1:
raise NotImplementedError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
self.input_image_key = image_keys[0]
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

self.reset()

def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.images": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
@@ -124,8 +117,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]

batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)

if len(self._queues["action"]) == 0:
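The dim=-4 in the stack above inserts a camera axis immediately before the (C, H, W) dimensions, so the same line works with or without extra leading dims. A standalone shape check (sizes are illustrative):

    import torch

    # Two hypothetical camera streams at one step, each (B, C, H, W).
    cam0 = torch.zeros(8, 3, 96, 96)
    cam1 = torch.zeros(8, 3, 96, 96)

    images = torch.stack([cam0, cam1], dim=-4)
    print(images.shape)  # torch.Size([8, 2, 3, 96, 96]), i.e. (B, num_cameras, C, H, W)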
@@ -144,7 +137,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@@ -169,9 +162,12 @@ def __init__(self, config: DiffusionConfig):
self.config = config

self.rgb_encoder = DiffusionRgbEncoder(config)
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d(
config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
global_cond_dim=(
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
)
* config.n_obs_steps,
)
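A worked example of the new global_cond_dim arithmetic, with illustrative numbers: each observation step contributes the state vector plus one encoder feature per camera, and that per-step size is multiplied by n_obs_steps.

    # Illustrative numbers only.
    state_dim, feature_dim, num_images, n_obs_steps = 2, 64, 2, 2
    global_cond_dim = (state_dim + feature_dim * num_images) * n_obs_steps
    assert global_cond_dim == 260  # (2 + 64 * 2) * 2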

@@ -220,23 +216,34 @@ def conditional_sample(

return sample

def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
# Extract image feature (first combine batch, sequence, and camera index dims).
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
# dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
# Concatenate state and image features then flatten to (B, global_cond_dim).
return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps

# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)

# Run sampling.
actions = self.conditional_sample(batch_size, global_cond=global_cond)
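The two rearranges in _prepare_global_conditioning can be sanity-checked in isolation. This sketch uses illustrative sizes and a stand-in for the encoder output:

    import einops
    import torch

    b, s, n = 2, 3, 2                         # batch, n_obs_steps, num_cameras
    images = torch.zeros(b, s, n, 3, 96, 96)  # (B, s, n, C, H, W)

    # Fold batch, step, and camera dims for a single encoder pass.
    flat = einops.rearrange(images, "b s n ... -> (b s n) ...")
    print(flat.shape)  # torch.Size([12, 3, 96, 96])

    # Stand-in for the rgb_encoder output: one 64-dim feature per frame.
    feats = torch.zeros(b * s * n, 64)
    # Unfold; the camera dim is absorbed into the feature dim.
    feats = einops.rearrange(feats, "(b s n) ... -> b s (n ...)", b=b, s=s)
    print(feats.shape)  # torch.Size([2, 3, 128])

    # Concatenate the state and flatten to (B, global_cond_dim).
    state = torch.zeros(b, s, 2)
    global_cond = torch.cat([state, feats], dim=-1).flatten(start_dim=1)
    print(global_cond.shape)  # torch.Size([2, 390])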
@@ -253,28 +260,23 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps

# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

trajectory = batch["action"]
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)

# Forward diffusion.
trajectory = batch["action"]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
@@ -305,7 +307,8 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
"You need to provide 'action_is_pad' in the batch when "
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
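The padding mask zeroes the per-step loss wherever the action was padded beyond the episode boundary. A small illustrative example:

    import torch

    loss = torch.ones(1, 4, 2)  # (B, horizon, action_dim)
    action_is_pad = torch.tensor([[False, False, False, True]])  # last step is padding

    in_episode_bound = ~action_is_pad
    loss = loss * in_episode_bound.unsqueeze(-1)
    print(loss[0, :, 0])  # tensor([1., 1., 1., 0.])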
@@ -428,7 +431,7 @@ def __init__(self, config: DiffusionConfig):
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
53 changes: 43 additions & 10 deletions lerobot/scripts/visualize_image_transforms.py
@@ -65,11 +65,10 @@
from lerobot.common.datasets.transforms import get_image_transforms

OUTPUT_DIR = Path("outputs/image_transforms")
N_EXAMPLES = 5
to_pil = ToPILImage()


def save_config_all_transforms(cfg, original_frame, output_dir):
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
tf = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
@@ -88,15 +87,15 @@ def save_config_all_transforms(cfg, original_frame, output_dir):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)

print("Combined transforms examples saved to:")
print(f" {output_dir_all}")


def save_config_single_transforms(cfg, original_frame, output_dir):
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
transforms = [
"brightness",
"contrast",
@@ -106,6 +105,7 @@ def save_config_single_transforms(cfg, original_frame, output_dir):
]
print("Individual transforms examples saved to:")
for transform in transforms:
# Apply one transform with a random value sampled from the min_max range
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": cfg[f"{transform}"].min_max,
@@ -114,18 +114,46 @@ def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
output_dir_single = output_dir / f"{transform}"
output_dir_single.mkdir(parents=True, exist_ok=True)

for i in range(1, N_EXAMPLES + 1):
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)

# Apply min transformation
min_value, max_value = cfg[f"{transform}"].min_max
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (min_value, min_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)

# Apply max transformation
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (max_value, max_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)

# Apply mean transformation
mean_value = (min_value + max_value) / 2
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (mean_value, mean_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)

print(f" {output_dir_single}")


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms(cfg):
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
dataset = LeRobotDataset(cfg.dataset_repo_id)

output_dir = Path(OUTPUT_DIR) / cfg.dataset_repo_id.split("/")[-1]
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)

# Get 1st frame from 1st camera of 1st episode
@@ -134,8 +162,13 @@ def visualize_transforms(cfg):
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")

save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir)
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg):
visualize_transforms(cfg, output_dir=OUTPUT_DIR)


if __name__ == "__main__":
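With the Hydra decorator moved onto visualize_transforms_cli, the underlying function can now also be called programmatically, which is how the new test below drives it. A sketch (the config path and repo id are assumptions):

    from pathlib import Path

    from lerobot.common.utils.utils import init_hydra_config
    from lerobot.scripts.visualize_image_transforms import visualize_transforms

    cfg = init_hydra_config(
        "lerobot/configs/default.yaml",  # assumed default config location
        overrides=["dataset_repo_id=lerobot/pusht"],
    )
    visualize_transforms(cfg, output_dir=Path("outputs/image_transforms"), n_examples=3)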
42 changes: 42 additions & 0 deletions tests/test_image_transforms.py
@@ -26,6 +26,7 @@
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
@@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative():
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))


@pytest.mark.parametrize(
"repo_id, n_examples",
[
("lerobot/aloha_sim_transfer_cube_human", 3),
],
)
def test_visualize_image_transforms(repo_id, n_examples):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
output_dir = output_dir / repo_id.split("/")[-1]

# Check if the original frame image exists
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."

# Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
for transform in transforms:
transform_dir = output_dir / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."

# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
for file_name in expected_files:
assert (
transform_dir / file_name
).exists(), f"{file_name} was not found in {transform} directory."

# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = output_dir / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert any(
combined_transforms_dir.iterdir()
), "No transformed images found in combined transforms directory."
for i in range(1, n_examples + 1):
assert (
combined_transforms_dir / f"{i}.png"
).exists(), f"Combined transform image {i}.png was not found."
