From 15dd6827140889e24aa00bfd61156ccd4311452c Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Mon, 17 Jun 2024 08:11:20 +0100
Subject: [PATCH] Add multi-image support to diffusion policy (#218)

---
 .../diffusion/configuration_diffusion.py     | 38 ++++++-----
 .../policies/diffusion/modeling_diffusion.py | 67 ++++++++++---------
 2 files changed, 56 insertions(+), 49 deletions(-)

diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 59ed16567..2b7923ada 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -28,7 +28,9 @@ class DiffusionConfig:
 
     Notes on the inputs and outputs:
         - "observation.state" is required as an input key.
-        - A key starting with "observation.image is required as an input.
+        - At least one key starting with "observation.image is required as an input.
+        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+          views. Right now we only support all images having the same shape.
        - "action" is required as an output key.
 
     Args:
@@ -153,22 +155,26 @@ def __post_init__(self):
             raise ValueError(
                 f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
             )
-        # There should only be one image key.
         image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
-        if len(image_keys) != 1:
-            raise ValueError(
-                f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
-            )
-        image_key = next(iter(image_keys))
-        if self.crop_shape is not None and (
-            self.crop_shape[0] > self.input_shapes[image_key][1]
-            or self.crop_shape[1] > self.input_shapes[image_key][2]
-        ):
-            raise ValueError(
-                f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
-                f"for `crop_shape` and {self.input_shapes[image_key]} for "
-                "`input_shapes[{image_key}]`."
-            )
+        if self.crop_shape is not None:
+            for image_key in image_keys:
+                if (
+                    self.crop_shape[0] > self.input_shapes[image_key][1]
+                    or self.crop_shape[1] > self.input_shapes[image_key][2]
+                ):
+                    raise ValueError(
+                        f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                        f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                        "`input_shapes[{image_key}]`."
+                    )
+        # Check that all input images have the same shape.
+        first_image_key = next(iter(image_keys))
+        for image_key in image_keys:
+            if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
+                raise ValueError(
+                    f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
+                    "expect all image shapes to match."
+                )
         supported_prediction_types = ["epsilon", "sample"]
         if self.prediction_type not in supported_prediction_types:
             raise ValueError(
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index e0482143d..47378fdf5 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -18,7 +18,6 @@
 
 TODO(alexander-soare):
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
-  - Make compatible with multiple image keys.
""" import math @@ -83,20 +82,14 @@ def __init__( self.diffusion = DiffusionModel(config) - image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] - # Note: This check is covered in the post-init of the config but have a sanity check just in case. - if len(image_keys) != 1: - raise NotImplementedError( - f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." - ) - self.input_image_key = image_keys[0] + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.reset() def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.image": deque(maxlen=self.config.n_obs_steps), + "observation.images": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } @@ -124,8 +117,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - batch["observation.image"] = batch[self.input_image_key] - + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: @@ -144,7 +137,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch["observation.image"] = batch[self.input_image_key] + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -169,9 +162,10 @@ def __init__(self, config: DiffusionConfig): self.config = config self.rgb_encoder = DiffusionRgbEncoder(config) + num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) self.unet = DiffusionConditionalUnet1d( config, - global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim) + global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images) * config.n_obs_steps, ) @@ -220,23 +214,34 @@ def conditional_sample( return sample + def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: + """Encode image features and concatenate them all together along with the state vector.""" + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + # Extract image feature (first combine batch, sequence, and camera index dims). + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature + # dim (effectively concatenating the camera features). + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + ) + # Concatenate state and image features then flatten to (B, global_cond_dim). 
+        return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
     def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have:
         {
             "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
         }
         """
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps
 
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
 
         # run sampling
         actions = self.conditional_sample(batch_size, global_cond=global_cond)
@@ -253,28 +258,23 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         This function expects `batch` to have (at least):
         {
             "observation.state": (B, n_obs_steps, state_dim)
-            "observation.image": (B, n_obs_steps, C, H, W)
+            "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
             "action": (B, horizon, action_dim)
             "action_is_pad": (B, horizon)
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
+        n_obs_steps = batch["observation.state"].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps
 
-        # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-        # Separate batch and sequence dims.
-        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-        trajectory = batch["action"]
+        # Encode image features and concatenate them all together along with the state vector.
+        global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)
 
         # Forward diffusion.
+        trajectory = batch["action"]
         # Sample noise to add to the trajectory.
         eps = torch.randn(trajectory.shape, device=trajectory.device)
         # Sample a random noising timestep for each item in the batch.
@@ -305,7 +305,8 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         if self.config.do_mask_loss_for_padding:
             if "action_is_pad" not in batch:
                 raise ValueError(
-                    f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
+                    "You need to provide 'action_is_pad' in the batch when "
+                    f"{self.config.do_mask_loss_for_padding=}."
                 )
             in_episode_bound = ~batch["action_is_pad"]
             loss = loss * in_episode_bound.unsqueeze(-1)
@@ -428,7 +429,7 @@ def __init__(self, config: DiffusionConfig):
         # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
         # height and width from `config.input_shapes`.
         image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-        assert len(image_keys) == 1
+        # Note: we have a check in the config class to make sure all images have the same shape.
         image_key = image_keys[0]
         dummy_input_h_w = (
             config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
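
Note on the config changes: the two validation blocks added to `__post_init__` can be exercised standalone. The sketch below replays that logic on a toy `input_shapes` dict; the camera key names, image shapes, and crop size are hypothetical, not taken from the patch.

# Hypothetical multi-camera config entries (shape = [channels, height, width]).
input_shapes = {
    "observation.image.cam_high": [3, 96, 96],
    "observation.image.cam_wrist": [3, 96, 96],
    "observation.state": [2],
}
crop_shape = (84, 84)

image_keys = {k for k in input_shapes if k.startswith("observation.image")}

# 1) Every camera image must be large enough for the crop.
if crop_shape is not None:
    for image_key in image_keys:
        if crop_shape[0] > input_shapes[image_key][1] or crop_shape[1] > input_shapes[image_key][2]:
            raise ValueError(f"`crop_shape` should fit within `input_shapes[{image_key}]`.")

# 2) All cameras must currently share the same image shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
    if input_shapes[image_key] != input_shapes[first_image_key]:
        raise ValueError(f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`.")

print("config OK")  # both checks pass for the shapes above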
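Note on the policy-level changes: every `observation.image*` entry is stacked into a single `observation.images` tensor, and the UNet's conditioning vector grows by one encoder feature vector per camera per observation step. A minimal sketch of that shape bookkeeping, with hypothetical sizes (batch 8, 2 obs steps, 2 cameras, 64-dim image features, 2-dim actions):

import torch

batch_size, n_obs_steps, num_cameras = 8, 2, 2
expected_image_keys = ["observation.image.cam_high", "observation.image.cam_wrist"]  # hypothetical keys
batch = {k: torch.randn(batch_size, n_obs_steps, 3, 84, 84) for k in expected_image_keys}

# dim=-4 inserts the camera axis just before (C, H, W), giving (B, n_obs_steps, num_cameras, C, H, W).
batch["observation.images"] = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
print(batch["observation.images"].shape)  # torch.Size([8, 2, 2, 3, 84, 84])

# The UNet conditioning width now scales with the number of cameras.
action_dim, feature_dim = 2, 64  # hypothetical
global_cond_dim = (action_dim + feature_dim * num_cameras) * n_obs_steps
print(global_cond_dim)  # (2 + 64 * 2) * 2 = 260

Stacking happens right after `normalize_inputs` and before `populate_queues`, so the queues only ever hold the combined `observation.images` tensor.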
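Note on `_prepare_global_conditioning`: the helper folds the camera axis into the encoder's batch dimension, then absorbs it into the feature dimension on the way back out. The sketch below traces the tensor shapes through the same two `einops.rearrange` patterns, using a random tensor as a stand-in for `self.rgb_encoder`; all sizes are hypothetical and match the previous sketch.

import einops
import torch

b, s, n, feature_dim, state_dim = 8, 2, 2, 64, 2
images = torch.randn(b, s, n, 3, 84, 84)  # "observation.images"
state = torch.randn(b, s, state_dim)      # "observation.state"

# Fold batch, obs-step, and camera dims together so the encoder sees a plain image batch.
flat = einops.rearrange(images, "b s n ... -> (b s n) ...")  # (32, 3, 84, 84)

# Stand-in for self.rgb_encoder: one feature vector per image.
img_features = torch.randn(flat.shape[0], feature_dim)  # (32, 64)

# Unfold batch and obs-step dims; the camera dim is absorbed into the feature dim.
img_features = einops.rearrange(img_features, "(b s n) ... -> b s (n ...)", b=b, s=s)  # (8, 2, 128)

# Concatenate with the state and flatten to the (B, global_cond_dim) vector fed to the UNet.
global_cond = torch.cat([state, img_features], dim=-1).flatten(start_dim=1)
print(global_cond.shape)  # torch.Size([8, 260])

The final width (260 here) matches the `global_cond_dim` computed in `DiffusionModel.__init__`, which is why both `generate_actions` and `compute_loss` can share this helper.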