Skip to content

Commit

Permalink
Add multi-image support to diffusion policy (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare authored Jun 17, 2024
1 parent e28fa23 commit 15dd682
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 49 deletions.
38 changes: 22 additions & 16 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image is required as an input.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views. Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
Expand Down Expand Up @@ -153,22 +155,26 @@ def __post_init__(self):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.crop_shape is not None and (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
if self.crop_shape is not None:
for image_key in image_keys:
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
# Check that all input images have the same shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
Expand Down
67 changes: 34 additions & 33 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
TODO(alexander-soare):
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Make compatible with multiple image keys.
"""

import math
Expand Down Expand Up @@ -83,20 +82,14 @@ def __init__(

self.diffusion = DiffusionModel(config)

image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
if len(image_keys) != 1:
raise NotImplementedError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
self.input_image_key = image_keys[0]
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

self.reset()

def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.images": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
Expand Down Expand Up @@ -124,8 +117,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]

batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)

if len(self._queues["action"]) == 0:
Expand All @@ -144,7 +137,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
Expand All @@ -169,9 +162,10 @@ def __init__(self, config: DiffusionConfig):
self.config = config

self.rgb_encoder = DiffusionRgbEncoder(config)
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d(
config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
* config.n_obs_steps,
)

Expand Down Expand Up @@ -220,23 +214,34 @@ def conditional_sample(

return sample

def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
# Extract image feature (first combine batch, sequence, and camera index dims).
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
# dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
# Concatenate state and image features then flatten to (B, global_cond_dim).
return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have:
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.config.n_obs_steps

# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)

# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
Expand All @@ -253,28 +258,23 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps

# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

trajectory = batch["action"]
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)

# Forward diffusion.
trajectory = batch["action"]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
Expand Down Expand Up @@ -305,7 +305,8 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
if self.config.do_mask_loss_for_padding:
if "action_is_pad" not in batch:
raise ValueError(
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
"You need to provide 'action_is_pad' in the batch when "
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
Expand Down Expand Up @@ -428,7 +429,7 @@ def __init__(self, config: DiffusionConfig):
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
Expand Down

0 comments on commit 15dd682

Please sign in to comment.