From 59397fb44afc19297fcdb57dbed2ddef7a6e34ca Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sat, 9 Mar 2024 18:44:36 +0100
Subject: [PATCH 1/6] Move tdmpc files

---
 lerobot/common/policies/factory.py | 2 +-
 lerobot/common/policies/tdmpc/__init__.py | 0
 lerobot/common/policies/{tdmpc_helper.py => tdmpc/helper.py} | 0
 lerobot/common/policies/{tdmpc.py => tdmpc/policy.py} | 2 +-
 4 files changed, 2 insertions(+), 2 deletions(-)
 create mode 100644 lerobot/common/policies/tdmpc/__init__.py
 rename lerobot/common/policies/{tdmpc_helper.py => tdmpc/helper.py} (100%)
 rename lerobot/common/policies/{tdmpc.py => tdmpc/policy.py} (99%)

diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 9507586cd..a956cb4bc 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,6 +1,6 @@
 def make_policy(cfg):
     if cfg.policy.name == "tdmpc":
-        from lerobot.common.policies.tdmpc import TDMPC
+        from lerobot.common.policies.tdmpc.policy import TDMPC
 
         policy = TDMPC(cfg.policy, cfg.device)
     elif cfg.policy.name == "diffusion":
diff --git a/lerobot/common/policies/tdmpc/__init__.py b/lerobot/common/policies/tdmpc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lerobot/common/policies/tdmpc_helper.py b/lerobot/common/policies/tdmpc/helper.py
similarity index 100%
rename from lerobot/common/policies/tdmpc_helper.py
rename to lerobot/common/policies/tdmpc/helper.py
diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc/policy.py
similarity index 99%
rename from lerobot/common/policies/tdmpc.py
rename to lerobot/common/policies/tdmpc/policy.py
index 42fbb825f..ae9888a50 100644
--- a/lerobot/common/policies/tdmpc.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 
-import lerobot.common.policies.tdmpc_helper as h
+import lerobot.common.policies.tdmpc.helper as h
 
 FIRST_FRAME = 0

From 302b78962c97700f231bc12bcbe229c7ad4874e3 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sun, 10 Mar 2024 15:31:17 +0100
Subject: [PATCH 2/6] Integrate diffusion policy

---
 lerobot/common/datasets/pusht.py | 17 +-
 .../diffusion/diffusion_unet_image_policy.py | 32 +-
 .../diffusion/model/conditional_unet1d.py | 286 ++++
 .../diffusion/model/conv1d_components.py | 47 +
 .../diffusion/model/crop_randomizer.py | 294 ++++++
 .../diffusion/model/dict_of_tensor_mixin.py | 41 +
 .../policies/diffusion/model/lr_scheduler.py | 46 +
 .../diffusion/model/mask_generator.py | 65 ++
 .../diffusion/model/module_attr_mixin.py | 15 +
 .../{ => model}/multi_image_obs_encoder.py | 6 +-
 .../policies/diffusion/model/normalizer.py | 358 +++++++
 .../diffusion/model/positional_embedding.py | 19 +
 .../policies/diffusion/model/tensor_utils.py | 971 ++++++++++++++++++
 lerobot/common/policies/diffusion/policy.py | 6 +-
 .../policies/diffusion/pytorch_utils.py | 46 +
 .../policies/diffusion/replay_buffer.py | 614 +++++++++++
 16 files changed, 2850 insertions(+), 13 deletions(-)
 create mode 100644 lerobot/common/policies/diffusion/model/conditional_unet1d.py
 create mode 100644 lerobot/common/policies/diffusion/model/conv1d_components.py
 create mode 100644 lerobot/common/policies/diffusion/model/crop_randomizer.py
 create mode 100644 lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py
 create mode 100644 lerobot/common/policies/diffusion/model/lr_scheduler.py
 create mode 100644 lerobot/common/policies/diffusion/model/mask_generator.py
 create mode 100644 lerobot/common/policies/diffusion/model/module_attr_mixin.py
 rename lerobot/common/policies/diffusion/{ => model}/multi_image_obs_encoder.py (96%)
 create mode 100644 lerobot/common/policies/diffusion/model/normalizer.py
 create mode 100644 lerobot/common/policies/diffusion/model/positional_embedding.py
 create mode 100644 lerobot/common/policies/diffusion/model/tensor_utils.py
 create mode 100644 lerobot/common/policies/diffusion/pytorch_utils.py
 create mode 100644 lerobot/common/policies/diffusion/replay_buffer.py

diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py
index 33355c702..835f4cb59 100644
--- a/lerobot/common/datasets/pusht.py
+++ b/lerobot/common/datasets/pusht.py
@@ -5,11 +5,10 @@
 import numpy as np
 import pygame
 import pymunk
+import shapely.geometry as sg
 import torch
 import torchrl
 import tqdm
-from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
-from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
 from tensordict import TensorDict
 from torchrl.data.replay_buffers.samplers import SliceSampler
 from torchrl.data.replay_buffers.storages import TensorStorage
@@ -17,6 +16,7 @@
 from lerobot.common.datasets.abstract import AbstractExperienceReplay
 from lerobot.common.datasets.utils import download_and_extract_zip
+from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -26,6 +26,19 @@
 PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
 
 
+def pymunk_to_shapely(body, shapes):
+    geoms = []
+    for shape in shapes:
+        if isinstance(shape, pymunk.shapes.Poly):
+            verts = [body.local_to_world(v) for v in shape.get_vertices()]
+            verts += [verts[0]]
+            geoms.append(sg.Polygon(verts))
+        else:
+            raise RuntimeError(f"Unsupported shape type {type(shape)}")
+    geom = sg.MultiPolygon(geoms)
+    return geom
+
+
 def get_goal_pose_body(pose):
     mass = 1
     inertia = pymunk.moment_for_box(mass, (50, 100))
diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
index 3c12d53a4..b759802ea 100644
--- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
@@ -5,11 +5,33 @@
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from einops import reduce
 
-from diffusion_policy.common.pytorch_util import dict_apply
-from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
-from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
-from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
-from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
+from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
+from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
+from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
+from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
+from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
+
+
+class BaseImagePolicy(ModuleAttrMixin):
+    # init accepts keyword argument shape_meta, see config/task/*_image.yaml
+
+    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        obs_dict:
+            str: B,To,*
+        return: B,Ta,Da
+        """
+        raise NotImplementedError()
+
+    # reset state for stateful policies
+    def reset(self):
+        pass
+
+    # ========== training ===========
+    # no standard training interface except setting normalizer
+    def set_normalizer(self, normalizer: LinearNormalizer):
+        raise NotImplementedError()
 
 
 class DiffusionUnetImagePolicy(BaseImagePolicy):
diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
new file mode 100644
index 000000000..d2971d38e
--- /dev/null
+++ b/lerobot/common/policies/diffusion/model/conditional_unet1d.py
@@ -0,0 +1,286 @@
+import logging
+from typing import Union
+
+import einops
+import torch
+import torch.nn as nn
+from einops.layers.torch import Rearrange
+
+from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
+from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionalResidualBlock1D(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
+    ):
+        super().__init__()
+
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
+                Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
+            ]
+        )
+
+        # FiLM modulation https://arxiv.org/abs/1709.07871
+        # predicts per-channel scale and bias
+        cond_channels = out_channels
+        if cond_predict_scale:
+            cond_channels = out_channels * 2
+        self.cond_predict_scale = cond_predict_scale
+        self.out_channels = out_channels
+        self.cond_encoder = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(cond_dim, cond_channels),
+            Rearrange("batch t -> batch t 1"),
+        )
+
+        # make sure dimensions compatible
+        self.residual_conv = (
+            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x, cond):
+        """
+        x : [ batch_size x in_channels x horizon ]
+        cond : [ batch_size x cond_dim]
+
+        returns:
+        out : [ batch_size x out_channels x horizon ]
+        """
+        out = self.blocks[0](x)
+        embed = self.cond_encoder(cond)
+        if self.cond_predict_scale:
+            embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
+            scale = embed[:, 0, ...]
+            bias = embed[:, 1, ...]
+ out = scale * out + bias + else: + out = out + embed + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out + + +class ConditionalUnet1D(nn.Module): + def __init__( + self, + input_dim, + local_cond_dim=None, + global_cond_dim=None, + diffusion_step_embed_dim=256, + down_dims=None, + kernel_size=3, + n_groups=8, + cond_predict_scale=False, + ): + super().__init__() + if down_dims is None: + down_dims = [256, 512, 1024] + + all_dims = [input_dim] + list(down_dims) + start_dim = down_dims[0] + + dsed = diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + if global_cond_dim is not None: + cond_dim += global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) + + local_cond_encoder = None + if local_cond_dim is not None: + _, dim_out = in_out[0] + dim_in = local_cond_dim + local_cond_encoder = nn.ModuleList( + [ + # down encoder + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + # up encoder + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + ] + ) + + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList( + [ + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + ConditionalResidualBlock1D( + mid_dim, + mid_dim, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + ] + ) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + Downsample1d(dim_out) if not is_last else nn.Identity(), + ] + ) + ) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + up_modules.append( + nn.ModuleList( + [ + ConditionalResidualBlock1D( + dim_out * 2, + dim_in, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + ConditionalResidualBlock1D( + dim_in, + dim_in, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + cond_predict_scale=cond_predict_scale, + ), + Upsample1d(dim_in) if not is_last else nn.Identity(), + ] + ) + ) + + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), + nn.Conv1d(start_dim, input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.local_cond_encoder = local_cond_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + local_cond=None, + global_cond=None, + **kwargs, + ): + """ + x: (B,T,input_dim) + timestep: 
(B,) or int, diffusion step + local_cond: (B,T,local_cond_dim) + global_cond: (B,global_cond_dim) + output: (B,T,input_dim) + """ + sample = einops.rearrange(sample, "b h t -> b t h") + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + + if global_cond is not None: + global_feature = torch.cat([global_feature, global_cond], axis=-1) + + # encode local features + h_local = [] + if local_cond is not None: + local_cond = einops.rearrange(local_cond, "b h t -> b t h") + resnet, resnet2 = self.local_cond_encoder + x = resnet(local_cond, global_feature) + h_local.append(x) + x = resnet2(local_cond, global_feature) + h_local.append(x) + + x = sample + h = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + if idx == 0 and len(h_local) > 0: + x = x + h_local[0] + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + # The correct condition should be: + # if idx == (len(self.up_modules)-1) and len(h_local) > 0: + # However this change will break compatibility with published checkpoints. + # Therefore it is left as a comment. 
+ if idx == len(self.up_modules) and len(h_local) > 0: + x = x + h_local[1] + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, "b t h -> b h t") + return x diff --git a/lerobot/common/policies/diffusion/model/conv1d_components.py b/lerobot/common/policies/diffusion/model/conv1d_components.py new file mode 100644 index 000000000..3c21eaf6f --- /dev/null +++ b/lerobot/common/policies/diffusion/model/conv1d_components.py @@ -0,0 +1,47 @@ +import torch.nn as nn + +# from einops.layers.torch import Rearrange + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + # Rearrange('batch channels horizon -> batch channels 1 horizon'), + nn.GroupNorm(n_groups, out_channels), + # Rearrange('batch channels 1 horizon -> batch channels horizon'), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +# def test(): +# cb = Conv1dBlock(256, 128, kernel_size=3) +# x = torch.zeros((1,256,16)) +# o = cb(x) diff --git a/lerobot/common/policies/diffusion/model/crop_randomizer.py b/lerobot/common/policies/diffusion/model/crop_randomizer.py new file mode 100644 index 000000000..2e60f353b --- /dev/null +++ b/lerobot/common/policies/diffusion/model/crop_randomizer.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn +import torchvision.transforms.functional as ttf + +import lerobot.common.policies.diffusion.model.tensor_utils as tu + + +class CropRandomizer(nn.Module): + """ + Randomly sample crops at input, and then average across crop features at output. + """ + + def __init__( + self, + input_shape, + crop_height, + crop_width, + num_crops=1, + pos_enc=False, + ): + """ + Args: + input_shape (tuple, list): shape of input (not including batch dimension) + crop_height (int): crop height + crop_width (int): crop width + num_crops (int): number of random crops to take + pos_enc (bool): if True, add 2 channels to the output to encode the spatial + location of the cropped pixels in the source image + """ + super().__init__() + + assert len(input_shape) == 3 # (C, H, W) + assert crop_height < input_shape[1] + assert crop_width < input_shape[2] + + self.input_shape = input_shape + self.crop_height = crop_height + self.crop_width = crop_width + self.num_crops = num_crops + self.pos_enc = pos_enc + + def output_shape_in(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_in operation, where raw inputs (usually observation modalities) + are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. 
+ + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because + # the number of crops are reshaped into the batch dimension, increasing the batch + # size from B to B * N + out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] + return [out_c, self.crop_height, self.crop_width] + + def output_shape_out(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. Corresponds to + the @forward_out operation, where processed inputs (usually encoded observation + modalities) are passed in. + + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + + # since the forward_out operation splits [B * N, ...] -> [B, N, ...] + # and then pools to result in [B, ...], only the batch dimension changes, + # and so the other dimensions retain their shape. + return list(input_shape) + + def forward_in(self, inputs): + """ + Samples N random crops for each input in the batch, and then reshapes + inputs to [B * N, ...]. + """ + assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions + if self.training: + # generate random crops + out, _ = sample_random_image_crops( + images=inputs, + crop_height=self.crop_height, + crop_width=self.crop_width, + num_crops=self.num_crops, + pos_enc=self.pos_enc, + ) + # [B, N, ...] -> [B * N, ...] + return tu.join_dimensions(out, 0, 1) + else: + # take center crop during eval + out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width)) + if self.num_crops > 1: + B, C, H, W = out.shape # noqa: N806 + out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W) + # [B * N, ...] + return out + + def forward_out(self, inputs): + """ + Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N + to result in shape [B, ...] to make sure the network output is consistent with + what would have happened if there were no randomization. + """ + if self.num_crops <= 1: + return inputs + else: + batch_size = inputs.shape[0] // self.num_crops + out = tu.reshape_dimensions( + inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops) + ) + return out.mean(dim=1) + + def forward(self, inputs): + return self.forward_in(inputs) + + def __repr__(self): + """Pretty print network.""" + header = "{}".format(str(self.__class__.__name__)) + msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( + self.input_shape, self.crop_height, self.crop_width, self.num_crops + ) + return msg + + +def crop_image_from_indices(images, crop_indices, crop_height, crop_width): + """ + Crops images at the locations specified by @crop_indices. Crops will be + taken across all channels. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where + N is the number of crops to take per image and each entry corresponds + to the pixel height and width of where to take the crop. Note that + the indices can also be of shape [..., 2] if only 1 crop should + be taken per image. Leading dimensions must be consistent with + @images argument. Each index specifies the top left of the crop. 
+ Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where + H and W are the height and width of @images and CH and CW are + @crop_height and @crop_width. + + crop_height (int): height of crop to take + + crop_width (int): width of crop to take + + Returns: + crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width] + """ + + # make sure length of input shapes is consistent + assert crop_indices.shape[-1] == 2 + ndim_im_shape = len(images.shape) + ndim_indices_shape = len(crop_indices.shape) + assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2) + + # maybe pad so that @crop_indices is shape [..., N, 2] + is_padded = False + if ndim_im_shape == ndim_indices_shape + 2: + crop_indices = crop_indices.unsqueeze(-2) + is_padded = True + + # make sure leading dimensions between images and indices are consistent + assert images.shape[:-3] == crop_indices.shape[:-2] + + device = images.device + image_c, image_h, image_w = images.shape[-3:] + num_crops = crop_indices.shape[-2] + + # make sure @crop_indices are in valid range + assert (crop_indices[..., 0] >= 0).all().item() + assert (crop_indices[..., 0] < (image_h - crop_height)).all().item() + assert (crop_indices[..., 1] >= 0).all().item() + assert (crop_indices[..., 1] < (image_w - crop_width)).all().item() + + # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window. + + # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW] + crop_ind_grid_h = torch.arange(crop_height).to(device) + crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1) + # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW] + crop_ind_grid_w = torch.arange(crop_width).to(device) + crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0) + # combine into shape [CH, CW, 2] + crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1) + + # Add above grid with the offset index of each sampled crop to get 2d indices for each crop. + # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2] + # shape array that tells us which pixels from the corresponding source image to grab. + grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2] + all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape) + + # For using @torch.gather, convert to flat indices from 2D indices, and also + # repeat across the channel dimension. 
To get flat index of each pixel to grab for + # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind + all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW] + all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW] + all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW] + + # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds + images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4) + images_to_crop = tu.flatten(images_to_crop, begin_axis=-2) + crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) + # [..., N, C, CH * CW] -> [..., N, C, CH, CW] + reshape_axis = len(crops.shape) - 1 + crops = tu.reshape_dimensions( + crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width) + ) + + if is_padded: + # undo padding -> [..., C, CH, CW] + crops = crops.squeeze(-4) + return crops + + +def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False): + """ + For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from + @images. + + Args: + images (torch.Tensor): batch of images of shape [..., C, H, W] + + crop_height (int): height of crop to take + + crop_width (int): width of crop to take + + num_crops (n): number of crops to sample + + pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial + encoding of the original source pixel locations. This means that the + output crops will contain information about where in the source image + it was sampled from. + + Returns: + crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) + if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) + + crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) + """ + device = images.device + + # maybe add 2 channels of spatial encoding to the source image + source_im = images + if pos_enc: + # spatial encoding [y, x] in [0, 1] + h, w = source_im.shape[-2:] + pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w)) + pos_y = pos_y.float().to(device) / float(h) + pos_x = pos_x.float().to(device) / float(w) + position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] + + # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] + leading_shape = source_im.shape[:-3] + position_enc = position_enc[(None,) * len(leading_shape)] + position_enc = position_enc.expand(*leading_shape, -1, -1, -1) + + # concat across channel dimension with input + source_im = torch.cat((source_im, position_enc), dim=-3) + + # make sure sample boundaries ensure crops are fully within the images + image_c, image_h, image_w = source_im.shape[-3:] + max_sample_h = image_h - crop_height + max_sample_w = image_w - crop_width + + # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. + # Each gets @num_crops samples - typically this will just be the batch dimension (B), so + # we will sample [B, N] indices, but this supports having more than one leading dimension, + # or possibly no leading dimension. 
+ # + # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints + crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() + crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() + crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2] + + crops = crop_image_from_indices( + images=source_im, + crop_indices=crop_inds, + crop_height=crop_height, + crop_width=crop_width, + ) + + return crops, crop_inds diff --git a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py new file mode 100644 index 000000000..d13560060 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + + +class DictOfTensorMixin(nn.Module): + def __init__(self, params_dict=None): + super().__init__() + if params_dict is None: + params_dict = nn.ParameterDict() + self.params_dict = params_dict + + @property + def device(self): + return next(iter(self.parameters())).device + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + def dfs_add(dest, keys, value: torch.Tensor): + if len(keys) == 1: + dest[keys[0]] = value + return + + if keys[0] not in dest: + dest[keys[0]] = nn.ParameterDict() + dfs_add(dest[keys[0]], keys[1:], value) + + def load_dict(state_dict, prefix): + out_dict = nn.ParameterDict() + for key, value in state_dict.items(): + value: torch.Tensor + if key.startswith(prefix): + param_keys = key[len(prefix) :].split(".")[1:] + # if len(param_keys) == 0: + # import pdb; pdb.set_trace() + dfs_add(out_dict, param_keys, value.clone()) + return out_dict + + self.params_dict = load_dict(state_dict, prefix + "params_dict") + self.params_dict.requires_grad_(False) + return diff --git a/lerobot/common/policies/diffusion/model/lr_scheduler.py b/lerobot/common/policies/diffusion/model/lr_scheduler.py new file mode 100644 index 000000000..084b3a366 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/lr_scheduler.py @@ -0,0 +1,46 @@ +from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + **kwargs, +): + """ + Added kwargs vs diffuser's original implementation + + Unified API to get any scheduler from its name. + + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. 
+ """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer, **kwargs) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs + ) diff --git a/lerobot/common/policies/diffusion/model/mask_generator.py b/lerobot/common/policies/diffusion/model/mask_generator.py new file mode 100644 index 000000000..63306dea3 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/mask_generator.py @@ -0,0 +1,65 @@ +import torch + +from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin + + +class LowdimMaskGenerator(ModuleAttrMixin): + def __init__( + self, + action_dim, + obs_dim, + # obs mask setup + max_n_obs_steps=2, + fix_obs_steps=True, + # action mask + action_visible=False, + ): + super().__init__() + self.action_dim = action_dim + self.obs_dim = obs_dim + self.max_n_obs_steps = max_n_obs_steps + self.fix_obs_steps = fix_obs_steps + self.action_visible = action_visible + + @torch.no_grad() + def forward(self, shape, seed=None): + device = self.device + B, T, D = shape # noqa: N806 + assert (self.action_dim + self.obs_dim) == D + + # create all tensors on this device + rng = torch.Generator(device=device) + if seed is not None: + rng = rng.manual_seed(seed) + + # generate dim mask + dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device) + is_action_dim = dim_mask.clone() + is_action_dim[..., : self.action_dim] = True + is_obs_dim = ~is_action_dim + + # generate obs mask + if self.fix_obs_steps: + obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device) + else: + obs_steps = torch.randint( + low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device + ) + + steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T) + obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) + obs_mask = obs_mask & is_obs_dim + + # generate action mask + if self.action_visible: + action_steps = torch.maximum( + obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device) + ) + action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) + action_mask = action_mask & is_action_dim + + mask = obs_mask + if self.action_visible: + mask = mask | action_mask + + return mask diff --git a/lerobot/common/policies/diffusion/model/module_attr_mixin.py b/lerobot/common/policies/diffusion/model/module_attr_mixin.py new file mode 100644 index 000000000..5d2cf4ea9 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/module_attr_mixin.py @@ -0,0 +1,15 @@ +import torch.nn as nn + + +class ModuleAttrMixin(nn.Module): + def __init__(self): + super().__init__() + self._dummy_variable = nn.Parameter() + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype diff --git a/lerobot/common/policies/diffusion/multi_image_obs_encoder.py 
b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py similarity index 96% rename from lerobot/common/policies/diffusion/multi_image_obs_encoder.py rename to lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py index e52f147f4..94dc6f490 100644 --- a/lerobot/common/policies/diffusion/multi_image_obs_encoder.py +++ b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py @@ -5,9 +5,9 @@ import torch.nn as nn import torchvision -from diffusion_policy.common.pytorch_util import replace_submodules -from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin -from diffusion_policy.model.vision.crop_randomizer import CropRandomizer +from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer +from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin +from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules class MultiImageObsEncoder(ModuleAttrMixin): diff --git a/lerobot/common/policies/diffusion/model/normalizer.py b/lerobot/common/policies/diffusion/model/normalizer.py new file mode 100644 index 000000000..0e4d79abd --- /dev/null +++ b/lerobot/common/policies/diffusion/model/normalizer.py @@ -0,0 +1,358 @@ +from typing import Dict, Union + +import numpy as np +import torch +import torch.nn as nn +import zarr + +from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin +from lerobot.common.policies.diffusion.pytorch_utils import dict_apply + + +class LinearNormalizer(DictOfTensorMixin): + avaliable_modes = ["limits", "gaussian"] + + @torch.no_grad() + def fit( + self, + data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array], + last_n_dims=1, + dtype=torch.float32, + mode="limits", + output_max=1.0, + output_min=-1.0, + range_eps=1e-4, + fit_offset=True, + ): + if isinstance(data, dict): + for key, value in data.items(): + self.params_dict[key] = _fit( + value, + last_n_dims=last_n_dims, + dtype=dtype, + mode=mode, + output_max=output_max, + output_min=output_min, + range_eps=range_eps, + fit_offset=fit_offset, + ) + else: + self.params_dict["_default"] = _fit( + data, + last_n_dims=last_n_dims, + dtype=dtype, + mode=mode, + output_max=output_max, + output_min=output_min, + range_eps=range_eps, + fit_offset=fit_offset, + ) + + def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: + return self.normalize(x) + + def __getitem__(self, key: str): + return SingleFieldLinearNormalizer(self.params_dict[key]) + + def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"): + self.params_dict[key] = value.params_dict + + def _normalize_impl(self, x, forward=True): + if isinstance(x, dict): + result = {} + for key, value in x.items(): + params = self.params_dict[key] + result[key] = _normalize(value, params, forward=forward) + return result + else: + if "_default" not in self.params_dict: + raise RuntimeError("Not initialized") + params = self.params_dict["_default"] + return _normalize(x, params, forward=forward) + + def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: + return self._normalize_impl(x, forward=True) + + def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: + return self._normalize_impl(x, forward=False) + + def get_input_stats(self) -> Dict: + if len(self.params_dict) == 0: + raise RuntimeError("Not initialized") + if len(self.params_dict) == 1 and "_default" in self.params_dict: + return self.params_dict["_default"]["input_stats"] + + result 
= {} + for key, value in self.params_dict.items(): + if key != "_default": + result[key] = value["input_stats"] + return result + + def get_output_stats(self, key="_default"): + input_stats = self.get_input_stats() + if "min" in input_stats: + # no dict + return dict_apply(input_stats, self.normalize) + + result = {} + for key, group in input_stats.items(): + this_dict = {} + for name, value in group.items(): + this_dict[name] = self.normalize({key: value})[key] + result[key] = this_dict + return result + + +class SingleFieldLinearNormalizer(DictOfTensorMixin): + avaliable_modes = ["limits", "gaussian"] + + @torch.no_grad() + def fit( + self, + data: Union[torch.Tensor, np.ndarray, zarr.Array], + last_n_dims=1, + dtype=torch.float32, + mode="limits", + output_max=1.0, + output_min=-1.0, + range_eps=1e-4, + fit_offset=True, + ): + self.params_dict = _fit( + data, + last_n_dims=last_n_dims, + dtype=dtype, + mode=mode, + output_max=output_max, + output_min=output_min, + range_eps=range_eps, + fit_offset=fit_offset, + ) + + @classmethod + def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs): + obj = cls() + obj.fit(data, **kwargs) + return obj + + @classmethod + def create_manual( + cls, + scale: Union[torch.Tensor, np.ndarray], + offset: Union[torch.Tensor, np.ndarray], + input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]], + ): + def to_tensor(x): + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) + x = x.flatten() + return x + + # check + for x in [offset] + list(input_stats_dict.values()): + assert x.shape == scale.shape + assert x.dtype == scale.dtype + + params_dict = nn.ParameterDict( + { + "scale": to_tensor(scale), + "offset": to_tensor(offset), + "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)), + } + ) + return cls(params_dict) + + @classmethod + def create_identity(cls, dtype=torch.float32): + scale = torch.tensor([1], dtype=dtype) + offset = torch.tensor([0], dtype=dtype) + input_stats_dict = { + "min": torch.tensor([-1], dtype=dtype), + "max": torch.tensor([1], dtype=dtype), + "mean": torch.tensor([0], dtype=dtype), + "std": torch.tensor([1], dtype=dtype), + } + return cls.create_manual(scale, offset, input_stats_dict) + + def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: + return _normalize(x, self.params_dict, forward=True) + + def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: + return _normalize(x, self.params_dict, forward=False) + + def get_input_stats(self): + return self.params_dict["input_stats"] + + def get_output_stats(self): + return dict_apply(self.params_dict["input_stats"], self.normalize) + + def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: + return self.normalize(x) + + +def _fit( + data: Union[torch.Tensor, np.ndarray, zarr.Array], + last_n_dims=1, + dtype=torch.float32, + mode="limits", + output_max=1.0, + output_min=-1.0, + range_eps=1e-4, + fit_offset=True, +): + assert mode in ["limits", "gaussian"] + assert last_n_dims >= 0 + assert output_max > output_min + + # convert data to torch and type + if isinstance(data, zarr.Array): + data = data[:] + if isinstance(data, np.ndarray): + data = torch.from_numpy(data) + if dtype is not None: + data = data.type(dtype) + + # convert shape + dim = 1 + if last_n_dims > 0: + dim = np.prod(data.shape[-last_n_dims:]) + data = data.reshape(-1, dim) + + # compute input stats min max mean std + input_min, _ = data.min(axis=0) + input_max, _ = data.max(axis=0) + input_mean = 
data.mean(axis=0) + input_std = data.std(axis=0) + + # compute scale and offset + if mode == "limits": + if fit_offset: + # unit scale + input_range = input_max - input_min + ignore_dim = input_range < range_eps + input_range[ignore_dim] = output_max - output_min + scale = (output_max - output_min) / input_range + offset = output_min - scale * input_min + offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] + # ignore dims scaled to mean of output max and min + else: + # use this when data is pre-zero-centered. + assert output_max > 0 + assert output_min < 0 + # unit abs + output_abs = min(abs(output_min), abs(output_max)) + input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max)) + ignore_dim = input_abs < range_eps + input_abs[ignore_dim] = output_abs + # don't scale constant channels + scale = output_abs / input_abs + offset = torch.zeros_like(input_mean) + elif mode == "gaussian": + ignore_dim = input_std < range_eps + scale = input_std.clone() + scale[ignore_dim] = 1 + scale = 1 / scale + + offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean) + + # save + this_params = nn.ParameterDict( + { + "scale": scale, + "offset": offset, + "input_stats": nn.ParameterDict( + {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std} + ), + } + ) + for p in this_params.parameters(): + p.requires_grad_(False) + return this_params + + +def _normalize(x, params, forward=True): + assert "scale" in params + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + scale = params["scale"] + offset = params["offset"] + x = x.to(device=scale.device, dtype=scale.dtype) + src_shape = x.shape + x = x.reshape(-1, scale.shape[0]) + x = x * scale + offset if forward else (x - offset) / scale + x = x.reshape(src_shape) + return x + + +def test(): + data = torch.zeros((100, 10, 9, 2)).uniform_() + data[..., 0, 0] = 0 + + normalizer = SingleFieldLinearNormalizer() + normalizer.fit(data, mode="limits", last_n_dims=2) + datan = normalizer.normalize(data) + assert datan.shape == data.shape + assert np.allclose(datan.max(), 1.0) + assert np.allclose(datan.min(), -1.0) + dataun = normalizer.unnormalize(datan) + assert torch.allclose(data, dataun, atol=1e-7) + + _ = normalizer.get_input_stats() + _ = normalizer.get_output_stats() + + normalizer = SingleFieldLinearNormalizer() + normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False) + datan = normalizer.normalize(data) + assert datan.shape == data.shape + assert np.allclose(datan.max(), 1.0, atol=1e-3) + assert np.allclose(datan.min(), 0.0, atol=1e-3) + dataun = normalizer.unnormalize(datan) + assert torch.allclose(data, dataun, atol=1e-7) + + data = torch.zeros((100, 10, 9, 2)).uniform_() + normalizer = SingleFieldLinearNormalizer() + normalizer.fit(data, mode="gaussian", last_n_dims=0) + datan = normalizer.normalize(data) + assert datan.shape == data.shape + assert np.allclose(datan.mean(), 0.0, atol=1e-3) + assert np.allclose(datan.std(), 1.0, atol=1e-3) + dataun = normalizer.unnormalize(datan) + assert torch.allclose(data, dataun, atol=1e-7) + + # dict + data = torch.zeros((100, 10, 9, 2)).uniform_() + data[..., 0, 0] = 0 + + normalizer = LinearNormalizer() + normalizer.fit(data, mode="limits", last_n_dims=2) + datan = normalizer.normalize(data) + assert datan.shape == data.shape + assert np.allclose(datan.max(), 1.0) + assert np.allclose(datan.min(), -1.0) + dataun = normalizer.unnormalize(datan) + assert torch.allclose(data, dataun, atol=1e-7) + + _ = 
normalizer.get_input_stats() + _ = normalizer.get_output_stats() + + data = { + "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512, + "action": torch.zeros((1000, 128, 2)).uniform_() * 512, + } + normalizer = LinearNormalizer() + normalizer.fit(data) + datan = normalizer.normalize(data) + dataun = normalizer.unnormalize(datan) + for key in data: + assert torch.allclose(data[key], dataun[key], atol=1e-4) + + _ = normalizer.get_input_stats() + _ = normalizer.get_output_stats() + + state_dict = normalizer.state_dict() + n = LinearNormalizer() + n.load_state_dict(state_dict) + datan = n.normalize(data) + dataun = n.unnormalize(datan) + for key in data: + assert torch.allclose(data[key], dataun[key], atol=1e-4) diff --git a/lerobot/common/policies/diffusion/model/positional_embedding.py b/lerobot/common/policies/diffusion/model/positional_embedding.py new file mode 100644 index 000000000..65fc97bdd --- /dev/null +++ b/lerobot/common/policies/diffusion/model/positional_embedding.py @@ -0,0 +1,19 @@ +import math + +import torch +import torch.nn as nn + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb diff --git a/lerobot/common/policies/diffusion/model/tensor_utils.py b/lerobot/common/policies/diffusion/model/tensor_utils.py new file mode 100644 index 000000000..0801e2946 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/tensor_utils.py @@ -0,0 +1,971 @@ +""" +A collection of utilities for working with nested tensor structures consisting +of numpy arrays and torch tensors. +""" +import collections + +import numpy as np +import torch + + +def recursive_dict_list_tuple_apply(x, type_func_dict): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert list not in type_func_dict + assert tuple not in type_func_dict + assert dict not in type_func_dict + + if isinstance(x, (dict, collections.OrderedDict)): + new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {} + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict) + return new_x + elif isinstance(x, (list, tuple)): + ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + raise NotImplementedError("Cannot handle data type %s" % str(type(x))) + + +def map_tensor(x, func): + """ + Apply function @func to torch.Tensor objects in a nested dictionary or + list or tuple. 
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each tensor + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: func, + type(None): lambda x: x, + }, + ) + + +def map_ndarray(x, func): + """ + Apply function @func to np.ndarray objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + np.ndarray: func, + type(None): lambda x: x, + }, + ) + + +def map_tensor_ndarray(x, tensor_func, ndarray_func): + """ + Apply function @tensor_func to torch.Tensor objects and @ndarray_func to + np.ndarray objects in a nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + tensor_func (function): function to apply to each tensor + ndarray_Func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: tensor_func, + np.ndarray: ndarray_func, + type(None): lambda x: x, + }, + ) + + +def clone(x): + """ + Clones all torch tensors and numpy arrays in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.clone(), + np.ndarray: lambda x: x.copy(), + type(None): lambda x: x, + }, + ) + + +def detach(x): + """ + Detaches all torch tensors in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.detach(), + }, + ) + + +def to_batch(x): + """ + Introduces a leading batch dimension of 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[None, ...], + np.ndarray: lambda x: x[None, ...], + type(None): lambda x: x, + }, + ) + + +def to_sequence(x): + """ + Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, None, ...], + np.ndarray: lambda x: x[:, None, ...], + type(None): lambda x: x, + }, + ) + + +def index_at_time(x, ind): + """ + Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in + nested dictionary or list or tuple and returns a new nested structure. 
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ind (int): index + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, ind, ...], + np.ndarray: lambda x: x[:, ind, ...], + type(None): lambda x: x, + }, + ) + + +def unsqueeze(x, dim): + """ + Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.unsqueeze(dim=dim), + np.ndarray: lambda x: np.expand_dims(x, axis=dim), + type(None): lambda x: x, + }, + ) + + +def contiguous(x): + """ + Makes all torch tensors and numpy arrays contiguous in nested dictionary or + list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.contiguous(), + np.ndarray: lambda x: np.ascontiguousarray(x), + type(None): lambda x: x, + }, + ) + + +def to_device(x, device): + """ + Sends all torch tensors in nested dictionary or list or tuple to device + @device, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, d=device: x.to(d), + type(None): lambda x: x, + }, + ) + + +def to_tensor(x): + """ + Converts all numpy arrays in nested dictionary or list or tuple to + torch tensors (and leaves existing torch Tensors as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x, + np.ndarray: lambda x: torch.from_numpy(x), + type(None): lambda x: x, + }, + ) + + +def to_numpy(x): + """ + Converts all torch tensors in nested dictionary or list or tuple to + numpy (and leaves existing numpy arrays as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy() + else: + return tensor.detach().numpy() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x, + type(None): lambda x: x, + }, + ) + + +def to_list(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to a list, and returns a new nested structure. Useful for + json encoding. 
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy().tolist() + else: + return tensor.detach().numpy().tolist() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x.tolist(), + type(None): lambda x: x, + }, + ) + + +def to_float(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to float type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.float(), + np.ndarray: lambda x: x.astype(np.float32), + type(None): lambda x: x, + }, + ) + + +def to_uint8(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to uint8 type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.byte(), + np.ndarray: lambda x: x.astype(np.uint8), + type(None): lambda x: x, + }, + ) + + +def to_torch(x, device): + """ + Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to + torch tensors on device @device and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return to_device(to_float(to_tensor(x)), device) + + +def to_one_hot_single(tensor, num_class): + """ + Convert tensor to one-hot representation, assuming a certain number of total class labels. + + Args: + tensor (torch.Tensor): tensor containing integer labels + num_class (int): number of classes + + Returns: + x (torch.Tensor): tensor containing one-hot representation of labels + """ + x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) + x.scatter_(-1, tensor.unsqueeze(-1), 1) + return x + + +def to_one_hot(tensor, num_class): + """ + Convert all tensors in nested dictionary or list or tuple to one-hot representation, + assuming a certain number of total class labels. + + Args: + tensor (dict or list or tuple): a possibly nested dictionary or list or tuple + num_class (int): number of classes + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) + + +def flatten_single(x, begin_axis=1): + """ + Flatten a tensor in all dimensions from @begin_axis onwards. + + Args: + x (torch.Tensor): tensor to flatten + begin_axis (int): which axis to flatten from + + Returns: + y (torch.Tensor): flattened tensor + """ + fixed_size = x.size()[:begin_axis] + _s = list(fixed_size) + [-1] + return x.reshape(*_s) + + +def flatten(x, begin_axis=1): + """ + Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. 
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): which axis to flatten from + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), + }, + ) + + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert begin_axis <= end_axis + assert begin_axis >= 0 + assert end_axis < len(x.shape) + assert isinstance(target_dims, (tuple, list)) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i > end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def reshape_dimensions(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions for all tensors in nested dictionary or list or tuple + to a target dimension. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + type(None): lambda x: x, + }, + ) + + +def join_dimensions(x, begin_axis, end_axis): + """ + Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for + all tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + type(None): lambda x: x, + }, + ) + + +def expand_at_single(x, size, dim): + """ + Expand a tensor at a single dimension @dim by @size + + Args: + x (torch.Tensor): input tensor + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (torch.Tensor): expanded tensor + """ + assert dim < x.ndimension() + assert x.shape[dim] == 1 + expand_dims = [-1] * x.ndimension() + expand_dims[dim] = size + return x.expand(*expand_dims) + + +def expand_at(x, size, dim): + """ + Expand all tensors in nested dictionary or list or tuple at a single + dimension @dim by @size. 
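Editorial aside: join_dimensions and reshape_dimensions act as inverses over a chosen dimension range, which is how [B, T, ...] batches get flattened and restored. An illustrative sketch:

x = {"img": torch.zeros(8, 10, 3, 96, 96)}                # [B, T, C, H, W]
flat = tu.join_dimensions(x, 0, 1)                        # {"img": [80, 3, 96, 96]}
restored = tu.reshape_dimensions(flat, 0, 0, (8, 10))     # back to [8, 10, 3, 96, 96]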
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) + + +def unsqueeze_expand_at(x, size, dim): + """ + Unsqueeze and expand a tensor at a dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to unsqueeze and expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze(x, dim) + return expand_at(x, size, dim) + + +def repeat_by_expand_at(x, repeats, dim): + """ + Repeat a dimension by combining expand and reshape operations. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + repeats (int): number of times to repeat the target dimension + dim (int): dimension to repeat on + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze_expand_at(x, repeats, dim + 1) + return join_dimensions(x, dim, dim + 1) + + +def named_reduce_single(x, reduction, dim): + """ + Reduce tensor at a dimension by named reduction functions. + + Args: + x (torch.Tensor): tensor to be reduced + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (torch.Tensor): reduced tensor + """ + assert x.ndimension() > dim + assert reduction in ["sum", "max", "mean", "flatten"] + if reduction == "flatten": + x = flatten(x, begin_axis=dim) + elif reduction == "max": + x = torch.max(x, dim=dim)[0] # [B, D] + elif reduction == "sum": + x = torch.sum(x, dim=dim) + else: + x = torch.mean(x, dim=dim) + return x + + +def named_reduce(x, reduction, dim): + """ + Reduces all tensors in nested dictionary or list or tuple at a dimension + using a named reduction function. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) + + +def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): + """ + This function indexes out a target dimension of a tensor in a structured way, + by allowing a different value to be selected for each member of a flat index + tensor (@indices) corresponding to a source dimension. This can be interpreted + as moving along the source dimension, using the corresponding index value + in @indices to select values for all other dimensions outside of the + source and target dimensions. A common use case is to gather values + in target dimension 1 for each batch member (target dimension 0). 
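Editorial aside: that common use case is exactly what the gather_sequence wrapper defined a bit further below provides. A small sketch (illustrative only):

seq = torch.arange(2 * 5 * 3).reshape(2, 5, 3)    # [B=2, T=5, D=3]
idx = torch.tensor([0, 4])                        # one timestep index per batch element
picked = tu.gather_sequence(seq, idx)             # shape [2, 3]: seq[0, 0] and seq[1, 4]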
+ + Args: + x (torch.Tensor): tensor to gather values for + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out + """ + assert len(indices.shape) == 1 + assert x.shape[source_dim] == indices.shape[0] + + # unsqueeze in all dimensions except the source dimension + new_shape = [1] * x.ndimension() + new_shape[source_dim] = -1 + indices = indices.reshape(*new_shape) + + # repeat in all dimensions - but preserve shape of source dimension, + # and make sure target_dimension has singleton dimension + expand_shape = list(x.shape) + expand_shape[source_dim] = -1 + expand_shape[target_dim] = 1 + indices = indices.expand(*expand_shape) + + out = x.gather(dim=target_dim, index=indices) + return out.squeeze(target_dim) + + +def gather_along_dim_with_dim(x, target_dim, source_dim, indices): + """ + Apply @gather_along_dim_with_dim_single to all tensors in a nested + dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i) + ) + + +def gather_sequence_single(seq, indices): + """ + Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in + the batch given an index for each sequence. + + Args: + seq (torch.Tensor): tensor with leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Return: + y (torch.Tensor): indexed tensor of shape [B, ....] + """ + return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices) + + +def gather_sequence(seq, indices): + """ + Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch + for tensors with leading dimensions [B, T, ...]. + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Returns: + y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] + """ + return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) + + +def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None): + """ + Pad input tensor or array @seq in the time dimension (dimension 1). + + Args: + seq (np.ndarray or torch.Tensor): sequence to be padded + padding (tuple): begin and end padding, e.g. 
[1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (np.ndarray or torch.Tensor) + """ + assert isinstance(seq, (np.ndarray, torch.Tensor)) + assert pad_same or pad_values is not None + if pad_values is not None: + assert isinstance(pad_values, float) + repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like + seq_dim = 1 if batched else 0 + + begin_pad = [] + end_pad = [] + + if padding[0] > 0: + pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values + begin_pad.append(repeat_func(pad, padding[0], seq_dim)) + if padding[1] > 0: + pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values + end_pad.append(repeat_func(pad, padding[1], seq_dim)) + + return concat_func(begin_pad + [seq] + end_pad, seq_dim) + + +def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): + """ + Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (dict or list or tuple) + """ + return recursive_dict_list_tuple_apply( + seq, + { + torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + type(None): lambda x: x, + }, + ) + + +def assert_size_at_dim_single(x, size, dim, msg): + """ + Ensure that array or tensor @x has size @size in dim @dim. + + Args: + x (np.ndarray or torch.Tensor): input array or tensor + size (int): size that tensors should have at @dim + dim (int): dimension to check + msg (str): text to display if assertion fails + """ + assert x.shape[dim] == size, msg + + +def assert_size_at_dim(x, size, dim, msg): + """ + Ensure that arrays and tensors in nested dictionary or list or tuple have + size @size in dim @dim. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size that tensors should have at @dim + dim (int): dimension to check + """ + map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) + + +def get_shape(x): + """ + Get all shapes of arrays and tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple that contains each array or + tensor's shape + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.shape, + np.ndarray: lambda x: x.shape, + type(None): lambda x: x, + }, + ) + + +def list_of_flat_dict_to_dict_of_list(list_of_dict): + """ + Helper function to go from a list of flat dictionaries to a dictionary of lists. 
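Editorial aside: an illustrative sketch of the pad_sequence helper defined above, padding an unbatched sequence by repeating its boundary frames:

obs = {"state": torch.zeros(10, 7)}                            # T=10 steps
padded = tu.pad_sequence(obs, padding=(2, 1), batched=False)   # pad_same=True by default
# padded["state"].shape == (13, 7): first frame repeated twice, last frame once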
+ By "flat" we mean that none of the values are dictionaries, but are numpy arrays, + floats, etc. + + Args: + list_of_dict (list): list of flat dictionaries + + Returns: + dict_of_list (dict): dictionary of lists + """ + assert isinstance(list_of_dict, list) + dic = collections.OrderedDict() + for i in range(len(list_of_dict)): + for k in list_of_dict[i]: + if k not in dic: + dic[k] = [] + dic[k].append(list_of_dict[i][k]) + return dic + + +def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): + """ + Flatten a nested dict or list to a list. + + For example, given a dict + { + a: 1 + b: { + c: 2 + } + c: 3 + } + + the function would return [(a, 1), (b_c, 2), (c, 3)] + + Args: + d (dict, list): a nested dict or list to be flattened + parent_key (str): recursion helper + sep (str): separator for nesting keys + item_key (str): recursion helper + Returns: + list: a list of (key, value) tuples + """ + items = [] + if isinstance(d, (tuple, list)): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for i, v in enumerate(d): + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) + return items + elif isinstance(d, dict): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for k, v in d.items(): + assert isinstance(k, str) + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) + return items + else: + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + return [(new_key, d)] + + +def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs): + """ + Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the + batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. + Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping + outputs to [B, T, ...]. + + Args: + inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + op: a layer op that accepts inputs + activation: activation to apply at the output + inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op + inputs_as_args (bool) whether to feed input as a args list to the op + kwargs (dict): other kwargs to supply to the op + + Returns: + outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. 
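Editorial aside: a minimal sketch of how time_distributed flattens the [B, T] axes, applies the op, and restores them (illustrative; the Linear layer is just a stand-in for any per-step module):

encoder = torch.nn.Linear(7, 16)
states = torch.zeros(4, 10, 7)                     # [B, T, 7]
feats = tu.time_distributed(states, encoder)       # reshaped to [40, 7], encoded, back to [4, 10, 16]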
+ """ + batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] + inputs = join_dimensions(inputs, 0, 1) + if inputs_as_kwargs: + outputs = op(**inputs, **kwargs) + elif inputs_as_args: + outputs = op(*inputs, **kwargs) + else: + outputs = op(inputs, **kwargs) + + if activation is not None: + outputs = map_tensor(outputs, activation) + outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len)) + return outputs diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 37bc79a08..3df76aa4a 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -4,10 +4,10 @@ import hydra import torch import torch.nn as nn -from diffusion_policy.model.common.lr_scheduler import get_scheduler -from .diffusion_unet_image_policy import DiffusionUnetImagePolicy -from .multi_image_obs_encoder import MultiImageObsEncoder +from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy +from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler +from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder class DiffusionPolicy(nn.Module): diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py new file mode 100644 index 000000000..c1444b062 --- /dev/null +++ b/lerobot/common/policies/diffusion/pytorch_utils.py @@ -0,0 +1,46 @@ +from typing import Callable, Dict + +import torch +import torch.nn as nn + + +def dict_apply( + x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] +) -> Dict[str, torch.Tensor]: + result = {} + for key, value in x.items(): + if isinstance(value, dict): + result[key] = dict_apply(value, func) + else: + result[key] = func(value) + return result + + +def replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + predicate: Return true if the module is to be replaced. + func: Return new module to use. 
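Editorial aside: a sketch of how replace_submodules might be used, for instance swapping every BatchNorm2d in a ResNet for GroupNorm (a common trick in diffusion-policy style encoders; the group count below is an assumption for illustration, not something this patch sets):

import torchvision

resnet = torchvision.models.resnet18(weights=None)
resnet = replace_submodules(
    root_module=resnet,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)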
+ """ + if predicate(root_module): + return func(root_module) + + bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parent, k in bn_list: + parent_module = root_module + if len(parent) > 0: + parent_module = root_module.get_submodule(".".join(parent)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + assert len(bn_list) == 0 + return root_module diff --git a/lerobot/common/policies/diffusion/replay_buffer.py b/lerobot/common/policies/diffusion/replay_buffer.py new file mode 100644 index 000000000..7fccf74df --- /dev/null +++ b/lerobot/common/policies/diffusion/replay_buffer.py @@ -0,0 +1,614 @@ +from __future__ import annotations + +import math +import numbers +import os +from functools import cached_property + +import numcodecs +import numpy as np +import zarr + + +def check_chunks_compatible(chunks: tuple, shape: tuple): + assert len(shape) == len(chunks) + for c in chunks: + assert isinstance(c, numbers.Integral) + assert c > 0 + + +def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"): + old_arr = group[name] + if chunks is None: + chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks + check_chunks_compatible(chunks, old_arr.shape) + + if compressor is None: + compressor = old_arr.compressor + + if (chunks == old_arr.chunks) and (compressor == old_arr.compressor): + # no change + return old_arr + + # rechunk recompress + group.move(name, tmp_key) + old_arr = group[tmp_key] + n_copied, n_skipped, n_bytes_copied = zarr.copy( + source=old_arr, + dest=group, + name=name, + chunks=chunks, + compressor=compressor, + ) + del group[tmp_key] + arr = group[name] + return arr + + +def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None): + """ + Common shapes + T,D + T,N,D + T,H,W,C + T,N,H,W,C + """ + itemsize = np.dtype(dtype).itemsize + # reversed + rshape = list(shape[::-1]) + if max_chunk_length is not None: + rshape[-1] = int(max_chunk_length) + split_idx = len(shape) - 1 + for i in range(len(shape) - 1): + this_chunk_bytes = itemsize * np.prod(rshape[:i]) + next_chunk_bytes = itemsize * np.prod(rshape[: i + 1]) + if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes: + split_idx = i + + rchunks = rshape[:split_idx] + item_chunk_bytes = itemsize * np.prod(rshape[:split_idx]) + this_max_chunk_length = rshape[split_idx] + next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)) + rchunks.append(next_chunk_length) + len_diff = len(shape) - len(rchunks) + rchunks.extend([1] * len_diff) + chunks = tuple(rchunks[::-1]) + # print(np.prod(chunks) * itemsize / target_chunk_bytes) + return chunks + + +class ReplayBuffer: + """ + Zarr-based temporal datastructure. + Assumes first dimension to be time. Only chunk in time dimension. + """ + + def __init__(self, root: zarr.Group | dict[str, dict]): + """ + Dummy constructor. Use copy_from* and create_from* class methods instead. 
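Editorial aside: get_optimal_chunks targets roughly 2 MB per chunk and only splits along the leading (time) dimension. An illustrative call, not part of the patch:

chunks = get_optimal_chunks(shape=(1000, 96, 96, 3), dtype=np.float32)
# keeps whole 96x96x3 frames and picks a small number of timesteps per chunk (~2 MB each)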
+ """ + assert "data" in root + assert "meta" in root + assert "episode_ends" in root["meta"] + for value in root["data"].values(): + assert value.shape[0] == root["meta"]["episode_ends"][-1] + self.root = root + + # ============= create constructors =============== + @classmethod + def create_empty_zarr(cls, storage=None, root=None): + if root is None: + if storage is None: + storage = zarr.MemoryStore() + root = zarr.group(store=storage) + root.require_group("data", overwrite=False) + meta = root.require_group("meta", overwrite=False) + if "episode_ends" not in meta: + meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False) + return cls(root=root) + + @classmethod + def create_empty_numpy(cls): + root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}} + return cls(root=root) + + @classmethod + def create_from_group(cls, group, **kwargs): + if "data" not in group: + # create from stratch + buffer = cls.create_empty_zarr(root=group, **kwargs) + else: + # already exist + buffer = cls(root=group, **kwargs) + return buffer + + @classmethod + def create_from_path(cls, zarr_path, mode="r", **kwargs): + """ + Open a on-disk zarr directly (for dataset larger than memory). + Slower. + """ + group = zarr.open(os.path.expanduser(zarr_path), mode) + return cls.create_from_group(group, **kwargs) + + # ============= copy constructors =============== + @classmethod + def copy_from_store( + cls, + src_store, + store=None, + keys=None, + chunks: dict[str, tuple] | None = None, + compressors: dict | str | numcodecs.abc.Codec | None = None, + if_exists="replace", + **kwargs, + ): + """ + Load to memory. + """ + src_root = zarr.group(src_store) + if chunks is None: + chunks = {} + if compressors is None: + compressors = {} + root = None + if store is None: + # numpy backend + meta = {} + for key, value in src_root["meta"].items(): + if len(value.shape) == 0: + meta[key] = np.array(value) + else: + meta[key] = value[:] + + if keys is None: + keys = src_root["data"].keys() + data = {} + for key in keys: + arr = src_root["data"][key] + data[key] = arr[:] + + root = {"meta": meta, "data": data} + else: + root = zarr.group(store=store) + # copy without recompression + n_copied, n_skipped, n_bytes_copied = zarr.copy_store( + source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists + ) + data_group = root.create_group("data", overwrite=True) + if keys is None: + keys = src_root["data"].keys() + for key in keys: + value = src_root["data"][key] + cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value) + cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value) + if cks == value.chunks and cpr == value.compressor: + # copy without recompression + this_path = "/data/" + key + n_copied, n_skipped, n_bytes_copied = zarr.copy_store( + source=src_store, + dest=store, + source_path=this_path, + dest_path=this_path, + if_exists=if_exists, + ) + else: + # copy with recompression + n_copied, n_skipped, n_bytes_copied = zarr.copy( + source=value, + dest=data_group, + name=key, + chunks=cks, + compressor=cpr, + if_exists=if_exists, + ) + buffer = cls(root=root) + return buffer + + @classmethod + def copy_from_path( + cls, + zarr_path, + backend=None, + store=None, + keys=None, + chunks: dict[str, tuple] | None = None, + compressors: dict | str | numcodecs.abc.Codec | None = None, + if_exists="replace", + **kwargs, + ): + """ + Copy a on-disk zarr to in-memory compressed. 
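Editorial aside: a minimal end-to-end sketch of the numpy-backed buffer (add_episode and get_episode are defined further below in this file; illustrative only):

buffer = ReplayBuffer.create_empty_numpy()
buffer.add_episode({
    "img": np.zeros((50, 96, 96, 3), dtype=np.float32),
    "action": np.zeros((50, 2), dtype=np.float32),
})
assert buffer.n_episodes == 1 and buffer.n_steps == 50
episode = buffer.get_episode(0)   # dict of arrays covering steps 0..49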
+ Recommended + """ + if chunks is None: + chunks = {} + if compressors is None: + compressors = {} + if backend == "numpy": + print("backend argument is deprecated!") + store = None + group = zarr.open(os.path.expanduser(zarr_path), "r") + return cls.copy_from_store( + src_store=group.store, + store=store, + keys=keys, + chunks=chunks, + compressors=compressors, + if_exists=if_exists, + **kwargs, + ) + + # ============= save methods =============== + def save_to_store( + self, + store, + chunks: dict[str, tuple] | None = None, + compressors: str | numcodecs.abc.Codec | dict | None = None, + if_exists="replace", + **kwargs, + ): + root = zarr.group(store) + if chunks is None: + chunks = {} + if compressors is None: + compressors = {} + if self.backend == "zarr": + # recompression free copy + n_copied, n_skipped, n_bytes_copied = zarr.copy_store( + source=self.root.store, + dest=store, + source_path="/meta", + dest_path="/meta", + if_exists=if_exists, + ) + else: + meta_group = root.create_group("meta", overwrite=True) + # save meta, no chunking + for key, value in self.root["meta"].items(): + _ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape) + + # save data, chunk + data_group = root.create_group("data", overwrite=True) + for key, value in self.root["data"].items(): + cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value) + cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value) + if isinstance(value, zarr.Array): + if cks == value.chunks and cpr == value.compressor: + # copy without recompression + this_path = "/data/" + key + n_copied, n_skipped, n_bytes_copied = zarr.copy_store( + source=self.root.store, + dest=store, + source_path=this_path, + dest_path=this_path, + if_exists=if_exists, + ) + else: + # copy with recompression + n_copied, n_skipped, n_bytes_copied = zarr.copy( + source=value, + dest=data_group, + name=key, + chunks=cks, + compressor=cpr, + if_exists=if_exists, + ) + else: + # numpy + _ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr) + return store + + def save_to_path( + self, + zarr_path, + chunks: dict[str, tuple] | None = None, + compressors: str | numcodecs.abc.Codec | dict | None = None, + if_exists="replace", + **kwargs, + ): + if chunks is None: + chunks = {} + if compressors is None: + compressors = {} + store = zarr.DirectoryStore(os.path.expanduser(zarr_path)) + return self.save_to_store( + store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs + ) + + @staticmethod + def resolve_compressor(compressor="default"): + if compressor == "default": + compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE) + elif compressor == "disk": + compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE) + return compressor + + @classmethod + def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array): + # allows compressor to be explicitly set to None + cpr = "nil" + if isinstance(compressors, dict): + if key in compressors: + cpr = cls.resolve_compressor(compressors[key]) + elif isinstance(array, zarr.Array): + cpr = array.compressor + else: + cpr = cls.resolve_compressor(compressors) + # backup default + if cpr == "nil": + cpr = cls.resolve_compressor("default") + return cpr + + @classmethod + def _resolve_array_chunks(cls, chunks: dict | tuple, key, array): + cks = None + if isinstance(chunks, dict): + if key in chunks: + cks = chunks[key] + elif isinstance(array, 
zarr.Array): + cks = array.chunks + elif isinstance(chunks, tuple): + cks = chunks + else: + raise TypeError(f"Unsupported chunks type {type(chunks)}") + # backup default + if cks is None: + cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype) + # check + check_chunks_compatible(chunks=cks, shape=array.shape) + return cks + + # ============= properties ================= + @cached_property + def data(self): + return self.root["data"] + + @cached_property + def meta(self): + return self.root["meta"] + + def update_meta(self, data): + # sanitize data + np_data = {} + for key, value in data.items(): + if isinstance(value, np.ndarray): + np_data[key] = value + else: + arr = np.array(value) + if arr.dtype == object: + raise TypeError(f"Invalid value type {type(value)}") + np_data[key] = arr + + meta_group = self.meta + if self.backend == "zarr": + for key, value in np_data.items(): + _ = meta_group.array( + name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True + ) + else: + meta_group.update(np_data) + + return meta_group + + @property + def episode_ends(self): + return self.meta["episode_ends"] + + def get_episode_idxs(self): + import numba + + numba.jit(nopython=True) + + def _get_episode_idxs(episode_ends): + result = np.zeros((episode_ends[-1],), dtype=np.int64) + for i in range(len(episode_ends)): + start = 0 + if i > 0: + start = episode_ends[i - 1] + end = episode_ends[i] + for idx in range(start, end): + result[idx] = i + return result + + return _get_episode_idxs(self.episode_ends) + + @property + def backend(self): + backend = "numpy" + if isinstance(self.root, zarr.Group): + backend = "zarr" + return backend + + # =========== dict-like API ============== + def __repr__(self) -> str: + if self.backend == "zarr": + return str(self.root.tree()) + else: + return super().__repr__() + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def __getitem__(self, key): + return self.data[key] + + def __contains__(self, key): + return key in self.data + + # =========== our API ============== + @property + def n_steps(self): + if len(self.episode_ends) == 0: + return 0 + return self.episode_ends[-1] + + @property + def n_episodes(self): + return len(self.episode_ends) + + @property + def chunk_size(self): + if self.backend == "zarr": + return next(iter(self.data.arrays()))[-1].chunks[0] + return None + + @property + def episode_lengths(self): + ends = self.episode_ends[:] + ends = np.insert(ends, 0, 0) + lengths = np.diff(ends) + return lengths + + def add_episode( + self, + data: dict[str, np.ndarray], + chunks: dict[str, tuple] | None = None, + compressors: str | numcodecs.abc.Codec | dict | None = None, + ): + if chunks is None: + chunks = {} + if compressors is None: + compressors = {} + assert len(data) > 0 + is_zarr = self.backend == "zarr" + + curr_len = self.n_steps + episode_length = None + for value in data.values(): + assert len(value.shape) >= 1 + if episode_length is None: + episode_length = len(value) + else: + assert episode_length == len(value) + new_len = curr_len + episode_length + + for key, value in data.items(): + new_shape = (new_len,) + value.shape[1:] + # create array + if key not in self.data: + if is_zarr: + cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value) + cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value) + arr = self.data.zeros( + name=key, shape=new_shape, chunks=cks, dtype=value.dtype, 
compressor=cpr + ) + else: + # copy data to prevent modify + arr = np.zeros(shape=new_shape, dtype=value.dtype) + self.data[key] = arr + else: + arr = self.data[key] + assert value.shape[1:] == arr.shape[1:] + # same method for both zarr and numpy + if is_zarr: + arr.resize(new_shape) + else: + arr.resize(new_shape, refcheck=False) + # copy data + arr[-value.shape[0] :] = value + + # append to episode ends + episode_ends = self.episode_ends + if is_zarr: + episode_ends.resize(episode_ends.shape[0] + 1) + else: + episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False) + episode_ends[-1] = new_len + + # rechunk + if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]: + rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)) + + def drop_episode(self): + is_zarr = self.backend == "zarr" + episode_ends = self.episode_ends[:].copy() + assert len(episode_ends) > 0 + start_idx = 0 + if len(episode_ends) > 1: + start_idx = episode_ends[-2] + for value in self.data.values(): + new_shape = (start_idx,) + value.shape[1:] + if is_zarr: + value.resize(new_shape) + else: + value.resize(new_shape, refcheck=False) + if is_zarr: + self.episode_ends.resize(len(episode_ends) - 1) + else: + self.episode_ends.resize(len(episode_ends) - 1, refcheck=False) + + def pop_episode(self): + assert self.n_episodes > 0 + episode = self.get_episode(self.n_episodes - 1, copy=True) + self.drop_episode() + return episode + + def extend(self, data): + self.add_episode(data) + + def get_episode(self, idx, copy=False): + idx = list(range(len(self.episode_ends)))[idx] + start_idx = 0 + if idx > 0: + start_idx = self.episode_ends[idx - 1] + end_idx = self.episode_ends[idx] + result = self.get_steps_slice(start_idx, end_idx, copy=copy) + return result + + def get_episode_slice(self, idx): + start_idx = 0 + if idx > 0: + start_idx = self.episode_ends[idx - 1] + end_idx = self.episode_ends[idx] + return slice(start_idx, end_idx) + + def get_steps_slice(self, start, stop, step=None, copy=False): + _slice = slice(start, stop, step) + + result = {} + for key, value in self.data.items(): + x = value[_slice] + if copy and isinstance(value, np.ndarray): + x = x.copy() + result[key] = x + return result + + # =========== chunking ============= + def get_chunks(self) -> dict: + assert self.backend == "zarr" + chunks = {} + for key, value in self.data.items(): + chunks[key] = value.chunks + return chunks + + def set_chunks(self, chunks: dict): + assert self.backend == "zarr" + for key, value in chunks.items(): + if key in self.data: + arr = self.data[key] + if value != arr.chunks: + check_chunks_compatible(chunks=value, shape=arr.shape) + rechunk_recompress_array(self.data, key, chunks=value) + + def get_compressors(self) -> dict: + assert self.backend == "zarr" + compressors = {} + for key, value in self.data.items(): + compressors[key] = value.compressor + return compressors + + def set_compressors(self, compressors: dict): + assert self.backend == "zarr" + for key, value in compressors.items(): + if key in self.data: + arr = self.data[key] + compressor = self.resolve_compressor(value) + if compressor != arr.compressor: + rechunk_recompress_array(self.data, key, compressor=compressor) From 6c867d78efe64ad0bc78d921d9f0ef1eb4a8f629 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 10 Mar 2024 16:33:03 +0100 Subject: [PATCH 3/6] Integrate pusht env from diffusion --- lerobot/common/datasets/pusht.py | 20 +- lerobot/common/envs/factory.py | 2 +- .../common/envs/{pusht.py => 
pusht/env.py} | 7 +- lerobot/common/envs/pusht/pusht_env.py | 378 ++++++++++++++++++ lerobot/common/envs/pusht/pusht_image_env.py | 55 +++ lerobot/common/envs/pusht/pymunk_override.py | 244 +++++++++++ .../policies/diffusion/model/ema_model.py | 84 ++++ .../policies/diffusion/pytorch_utils.py | 30 ++ lerobot/configs/policy/diffusion.yaml | 5 +- tests/test_envs.py | 2 +- 10 files changed, 801 insertions(+), 26 deletions(-) rename lerobot/common/envs/{pusht.py => pusht/env.py} (96%) create mode 100644 lerobot/common/envs/pusht/pusht_env.py create mode 100644 lerobot/common/envs/pusht/pusht_image_env.py create mode 100644 lerobot/common/envs/pusht/pymunk_override.py create mode 100644 lerobot/common/policies/diffusion/model/ema_model.py diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 835f4cb59..ae987ad1e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -5,7 +5,6 @@ import numpy as np import pygame import pymunk -import shapely.geometry as sg import torch import torchrl import tqdm @@ -16,29 +15,16 @@ from lerobot.common.datasets.abstract import AbstractExperienceReplay from lerobot.common.datasets.utils import download_and_extract_zip +from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, -DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS() PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip" PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr") -def pymunk_to_shapely(body, shapes): - geoms = [] - for shape in shapes: - if isinstance(shape, pymunk.shapes.Poly): - verts = [body.local_to_world(v) for v in shape.get_vertices()] - verts += [verts[0]] - geoms.append(sg.Polygon(verts)) - else: - raise RuntimeError(f"Unsupported shape type {type(shape)}") - geom = sg.MultiPolygon(geoms) - return geom - - def get_goal_pose_body(pose): mass = 1 inertia = pymunk.moment_for_box(mass, (50, 100)) @@ -62,8 +48,10 @@ def add_tee( angle, scale=30, color="LightSlateGray", - mask=DEFAULT_TEE_MASK, + mask=None, ): + if mask is None: + mask = pymunk.ShapeFilter.ALL_MASKS() mass = 1 length = 4 vertices1 = [ diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 984b866a4..35ebfa4a9 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -18,7 +18,7 @@ def make_env(cfg, transform=None): kwargs["task"] = cfg.env.task clsfunc = SimxarmEnv elif cfg.env.name == "pusht": - from lerobot.common.envs.pusht import PushtEnv + from lerobot.common.envs.pusht.pusht import PushtEnv # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." 
diff --git a/lerobot/common/envs/pusht.py b/lerobot/common/envs/pusht/env.py similarity index 96% rename from lerobot/common/envs/pusht.py rename to lerobot/common/envs/pusht/env.py index 7f9a3f63c..ff49f791e 100644 --- a/lerobot/common/envs/pusht.py +++ b/lerobot/common/envs/pusht/env.py @@ -17,7 +17,6 @@ from lerobot.common.utils import set_seed _has_gym = importlib.util.find_spec("gym") is not None -_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym class PushtEnv(EnvBase): @@ -45,17 +44,15 @@ def __init__( if from_pixels: assert image_size - if not _has_diffpolicy: - raise ImportError("Cannot import diffusion_policy.") if not _has_gym: raise ImportError("Cannot import gym.") # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on) - # from diffusion_policy.env.pusht.pusht_env import PushTEnv + # from lerobot.common.envs.pusht.pusht_env import PushTEnv if not from_pixels: raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv") - from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv + from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv self._env = PushTImageEnv(render_size=self.image_size) diff --git a/lerobot/common/envs/pusht/pusht_env.py b/lerobot/common/envs/pusht/pusht_env.py new file mode 100644 index 000000000..690bfe12d --- /dev/null +++ b/lerobot/common/envs/pusht/pusht_env.py @@ -0,0 +1,378 @@ +import collections + +import cv2 +import gym +import numpy as np +import pygame +import pymunk +import pymunk.pygame_util +import shapely.geometry as sg +import skimage.transform as st +from gym import spaces +from pymunk.vec2d import Vec2d + +from lerobot.common.envs.pusht.pymunk_override import DrawOptions + + +def pymunk_to_shapely(body, shapes): + geoms = [] + for shape in shapes: + if isinstance(shape, pymunk.shapes.Poly): + verts = [body.local_to_world(v) for v in shape.get_vertices()] + verts += [verts[0]] + geoms.append(sg.Polygon(verts)) + else: + raise RuntimeError(f"Unsupported shape type {type(shape)}") + geom = sg.MultiPolygon(geoms) + return geom + + +class PushTEnv(gym.Env): + metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10} + reward_range = (0.0, 1.0) + + def __init__( + self, + legacy=False, + block_cog=None, + damping=None, + render_action=True, + render_size=96, + reset_to_state=None, + ): + self._seed = None + self.seed() + self.window_size = ws = 512 # The size of the PyGame window + self.render_size = render_size + self.sim_hz = 100 + # Local controller params. + self.k_p, self.k_v = 100, 20 # PD control.z + self.control_hz = self.metadata["video.frames_per_second"] + # legcay set_state for data compatibility + self.legacy = legacy + + # agent_pos, block_pos, block_angle + self.observation_space = spaces.Box( + low=np.array([0, 0, 0, 0, 0], dtype=np.float64), + high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64), + shape=(5,), + dtype=np.float64, + ) + + # positional goal for agent + self.action_space = spaces.Box( + low=np.array([0, 0], dtype=np.float64), + high=np.array([ws, ws], dtype=np.float64), + shape=(2,), + dtype=np.float64, + ) + + self.block_cog = block_cog + self.damping = damping + self.render_action = render_action + + """ + If human-rendering is used, `self.window` will be a reference + to the window that we draw to. `self.clock` will be a clock that is used + to ensure that the environment is rendered at the correct framerate in + human-mode. 
They will remain `None` until human-mode is used for the + first time. + """ + self.window = None + self.clock = None + self.screen = None + + self.space = None + self.teleop = None + self.render_buffer = None + self.latest_action = None + self.reset_to_state = reset_to_state + + def reset(self): + seed = self._seed + self._setup() + if self.block_cog is not None: + self.block.center_of_gravity = self.block_cog + if self.damping is not None: + self.space.damping = self.damping + + # use legacy RandomState for compatibility + state = self.reset_to_state + if state is None: + rs = np.random.RandomState(seed=seed) + state = np.array( + [ + rs.randint(50, 450), + rs.randint(50, 450), + rs.randint(100, 400), + rs.randint(100, 400), + rs.randn() * 2 * np.pi - np.pi, + ] + ) + self._set_state(state) + + observation = self._get_obs() + return observation + + def step(self, action): + dt = 1.0 / self.sim_hz + self.n_contact_points = 0 + n_steps = self.sim_hz // self.control_hz + if action is not None: + self.latest_action = action + for _ in range(n_steps): + # Step PD control. + # self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too. + acceleration = self.k_p * (action - self.agent.position) + self.k_v * ( + Vec2d(0, 0) - self.agent.velocity + ) + self.agent.velocity += acceleration * dt + + # Step physics. + self.space.step(dt) + + # compute reward + goal_body = self._get_goal_pose_body(self.goal_pose) + goal_geom = pymunk_to_shapely(goal_body, self.block.shapes) + block_geom = pymunk_to_shapely(self.block, self.block.shapes) + + intersection_area = goal_geom.intersection(block_geom).area + goal_area = goal_geom.area + coverage = intersection_area / goal_area + reward = np.clip(coverage / self.success_threshold, 0, 1) + done = coverage > self.success_threshold + + observation = self._get_obs() + info = self._get_info() + + return observation, reward, done, info + + def render(self, mode): + return self._render_frame(mode) + + def teleop_agent(self): + TeleopAgent = collections.namedtuple("TeleopAgent", ["act"]) + + def act(obs): + act = None + mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen) + if self.teleop or (mouse_position - self.agent.position).length < 30: + self.teleop = True + act = mouse_position + return act + + return TeleopAgent(act) + + def _get_obs(self): + obs = np.array( + tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),) + ) + return obs + + def _get_goal_pose_body(self, pose): + mass = 1 + inertia = pymunk.moment_for_box(mass, (50, 100)) + body = pymunk.Body(mass, inertia) + # preserving the legacy assignment order for compatibility + # the order here doesn't matter somehow, maybe because CoM is aligned with body origin + body.position = pose[:2].tolist() + body.angle = pose[2] + return body + + def _get_info(self): + n_steps = self.sim_hz // self.control_hz + n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps)) + info = { + "pos_agent": np.array(self.agent.position), + "vel_agent": np.array(self.agent.velocity), + "block_pose": np.array(list(self.block.position) + [self.block.angle]), + "goal_pose": self.goal_pose, + "n_contacts": n_contact_points_per_step, + } + return info + + def _render_frame(self, mode): + if self.window is None and mode == "human": + pygame.init() + pygame.display.init() + self.window = pygame.display.set_mode((self.window_size, self.window_size)) + if self.clock is None and mode == "human": + self.clock = 
pygame.time.Clock() + + canvas = pygame.Surface((self.window_size, self.window_size)) + canvas.fill((255, 255, 255)) + self.screen = canvas + + draw_options = DrawOptions(canvas) + + # Draw goal pose. + goal_body = self._get_goal_pose_body(self.goal_pose) + for shape in self.block.shapes: + goal_points = [ + pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) + for v in shape.get_vertices() + ] + goal_points += [goal_points[0]] + pygame.draw.polygon(canvas, self.goal_color, goal_points) + + # Draw agent and block. + self.space.debug_draw(draw_options) + + if mode == "human": + # The following line copies our drawings from `canvas` to the visible window + self.window.blit(canvas, canvas.get_rect()) + pygame.event.pump() + pygame.display.update() + + # the clock is already ticked during in step for "human" + + img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)) + img = cv2.resize(img, (self.render_size, self.render_size)) + if self.render_action and self.latest_action is not None: + action = np.array(self.latest_action) + coord = (action / 512 * 96).astype(np.int32) + marker_size = int(8 / 96 * self.render_size) + thickness = int(1 / 96 * self.render_size) + cv2.drawMarker( + img, + coord, + color=(255, 0, 0), + markerType=cv2.MARKER_CROSS, + markerSize=marker_size, + thickness=thickness, + ) + return img + + def close(self): + if self.window is not None: + pygame.display.quit() + pygame.quit() + + def seed(self, seed=None): + if seed is None: + seed = np.random.randint(0, 25536) + self._seed = seed + self.np_random = np.random.default_rng(seed) + + def _handle_collision(self, arbiter, space, data): + self.n_contact_points += len(arbiter.contact_point_set.points) + + def _set_state(self, state): + if isinstance(state, np.ndarray): + state = state.tolist() + pos_agent = state[:2] + pos_block = state[2:4] + rot_block = state[4] + self.agent.position = pos_agent + # setting angle rotates with respect to center of mass + # therefore will modify the geometric position + # if not the same as CoM + # therefore should be modified first. + if self.legacy: + # for compatibility with legacy data + self.block.position = pos_block + self.block.angle = rot_block + else: + self.block.angle = rot_block + self.block.position = pos_block + + # Run physics to take effect + self.space.step(1.0 / self.sim_hz) + + def _set_state_local(self, state_local): + agent_pos_local = state_local[:2] + block_pose_local = state_local[2:] + tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2]) + tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2]) + tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params) + agent_pos_new = tf_img_new(agent_pos_local) + new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation]) + self._set_state(new_state) + return new_state + + def _setup(self): + self.space = pymunk.Space() + self.space.gravity = 0, 0 + self.space.damping = 0 + self.teleop = False + self.render_buffer = [] + + # Add walls. + walls = [ + self._add_segment((5, 506), (5, 5), 2), + self._add_segment((5, 5), (506, 5), 2), + self._add_segment((506, 5), (506, 506), 2), + self._add_segment((5, 506), (506, 506), 2), + ] + self.space.add(*walls) + + # Add agent, block, and goal zone. 
+ self.agent = self.add_circle((256, 400), 15) + self.block = self.add_tee((256, 300), 0) + self.goal_color = pygame.Color("LightGreen") + self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) + + # Add collision handling + self.collision_handeler = self.space.add_collision_handler(0, 0) + self.collision_handeler.post_solve = self._handle_collision + self.n_contact_points = 0 + + self.max_score = 50 * 100 + self.success_threshold = 0.95 # 95% coverage. + + def _add_segment(self, a, b, radius): + shape = pymunk.Segment(self.space.static_body, a, b, radius) + shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names + return shape + + def add_circle(self, position, radius): + body = pymunk.Body(body_type=pymunk.Body.KINEMATIC) + body.position = position + body.friction = 1 + shape = pymunk.Circle(body, radius) + shape.color = pygame.Color("RoyalBlue") + self.space.add(body, shape) + return body + + def add_box(self, position, height, width): + mass = 1 + inertia = pymunk.moment_for_box(mass, (height, width)) + body = pymunk.Body(mass, inertia) + body.position = position + shape = pymunk.Poly.create_box(body, (height, width)) + shape.color = pygame.Color("LightSlateGray") + self.space.add(body, shape) + return body + + def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None): + if mask is None: + mask = pymunk.ShapeFilter.ALL_MASKS() + mass = 1 + length = 4 + vertices1 = [ + (-length * scale / 2, scale), + (length * scale / 2, scale), + (length * scale / 2, 0), + (-length * scale / 2, 0), + ] + inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) + vertices2 = [ + (-scale / 2, scale), + (-scale / 2, length * scale), + (scale / 2, length * scale), + (scale / 2, scale), + ] + inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) + body = pymunk.Body(mass, inertia1 + inertia2) + shape1 = pymunk.Poly(body, vertices1) + shape2 = pymunk.Poly(body, vertices2) + shape1.color = pygame.Color(color) + shape2.color = pygame.Color(color) + shape1.filter = pymunk.ShapeFilter(mask=mask) + shape2.filter = pymunk.ShapeFilter(mask=mask) + body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 + body.position = position + body.angle = angle + body.friction = 1 + self.space.add(body, shape1, shape2) + return body diff --git a/lerobot/common/envs/pusht/pusht_image_env.py b/lerobot/common/envs/pusht/pusht_image_env.py new file mode 100644 index 000000000..5f7bc03c8 --- /dev/null +++ b/lerobot/common/envs/pusht/pusht_image_env.py @@ -0,0 +1,55 @@ +import cv2 +import numpy as np +from gym import spaces + +from lerobot.common.envs.pusht.pusht_env import PushTEnv + + +class PushTImageEnv(PushTEnv): + metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10} + + def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96): + super().__init__( + legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False + ) + ws = self.window_size + self.observation_space = spaces.Dict( + { + "image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32), + "agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32), + } + ) + self.render_cache = None + + def _get_obs(self): + img = super()._render_frame(mode="rgb_array") + + agent_pos = np.array(self.agent.position) + img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) + obs = {"image": img_obs, "agent_pos": agent_pos} + + # draw action + if self.latest_action is 
not None: + action = np.array(self.latest_action) + coord = (action / 512 * 96).astype(np.int32) + marker_size = int(8 / 96 * self.render_size) + thickness = int(1 / 96 * self.render_size) + cv2.drawMarker( + img, + coord, + color=(255, 0, 0), + markerType=cv2.MARKER_CROSS, + markerSize=marker_size, + thickness=thickness, + ) + self.render_cache = img + + return obs + + def render(self, mode): + assert mode == "rgb_array" + + if self.render_cache is None: + self._get_obs() + + return self.render_cache diff --git a/lerobot/common/envs/pusht/pymunk_override.py b/lerobot/common/envs/pusht/pymunk_override.py new file mode 100644 index 000000000..7ad762372 --- /dev/null +++ b/lerobot/common/envs/pusht/pymunk_override.py @@ -0,0 +1,244 @@ +# ---------------------------------------------------------------------------- +# pymunk +# Copyright (c) 2007-2016 Victor Blomqvist +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ---------------------------------------------------------------------------- + +"""This submodule contains helper functions to help with quick prototyping +using pymunk together with pygame. + +Intended to help with debugging and prototyping, not for actual production use +in a full application. The methods contained in this module is opinionated +about your coordinate system and not in any way optimized. +""" + +__docformat__ = "reStructuredText" + +__all__ = [ + "DrawOptions", + "get_mouse_pos", + "to_pygame", + "from_pygame", + # "lighten", + "positive_y_is_up", +] + +from typing import Sequence, Tuple + +import numpy as np +import pygame +import pymunk +from pymunk.space_debug_draw_options import SpaceDebugColor +from pymunk.vec2d import Vec2d + +positive_y_is_up: bool = False +"""Make increasing values of y point upwards. + +When True:: + + y + ^ + | . (3, 3) + | + | . (2, 2) + | + +------ > x + +When False:: + + +------ > x + | + | . (2, 2) + | + | . (3, 3) + v + y + +""" + + +class DrawOptions(pymunk.SpaceDebugDrawOptions): + def __init__(self, surface: pygame.Surface) -> None: + """Draw a pymunk.Space on a pygame.Surface object. 
+ + Typical usage:: + + >>> import pymunk + >>> surface = pygame.Surface((10,10)) + >>> space = pymunk.Space() + >>> options = pymunk.pygame_util.DrawOptions(surface) + >>> space.debug_draw(options) + + You can control the color of a shape by setting shape.color to the color + you want it drawn in:: + + >>> c = pymunk.Circle(None, 10) + >>> c.color = pygame.Color("pink") + + See pygame_util.demo.py for a full example + + Since pygame uses a coordinate system where y points down (in contrast + to many other cases), you either have to make the physics simulation + with Pymunk also behave in that way, or flip everything when you draw. + + The easiest is probably to just make the simulation behave the same + way as Pygame does. In that way all coordinates used are in the same + orientation and easy to reason about:: + + >>> space = pymunk.Space() + >>> space.gravity = (0, -1000) + >>> body = pymunk.Body() + >>> body.position = (0, 0) # will be positioned in the top left corner + >>> space.debug_draw(options) + + To flip the drawing its possible to set the module property + :py:data:`positive_y_is_up` to True. Then the pygame drawing will flip + the simulation upside down before drawing:: + + >>> positive_y_is_up = True + >>> body = pymunk.Body() + >>> body.position = (0, 0) + >>> # Body will be position in bottom left corner + + :Parameters: + surface : pygame.Surface + Surface that the objects will be drawn on + """ + self.surface = surface + super().__init__() + + def draw_circle( + self, + pos: Vec2d, + angle: float, + radius: float, + outline_color: SpaceDebugColor, + fill_color: SpaceDebugColor, + ) -> None: + p = to_pygame(pos, self.surface) + + pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0) + pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0) + + # circle_edge = pos + Vec2d(radius, 0).rotated(angle) + # p2 = to_pygame(circle_edge, self.surface) + # line_r = 2 if radius > 20 else 1 + # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r) + + def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None: + p1 = to_pygame(a, self.surface) + p2 = to_pygame(b, self.surface) + + pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2]) + + def draw_fat_segment( + self, + a: Tuple[float, float], + b: Tuple[float, float], + radius: float, + outline_color: SpaceDebugColor, + fill_color: SpaceDebugColor, + ) -> None: + p1 = to_pygame(a, self.surface) + p2 = to_pygame(b, self.surface) + + r = round(max(1, radius * 2)) + pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r) + if r > 2: + orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])] + if orthog[0] == 0 and orthog[1] == 0: + return + scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5 + orthog[0] = round(orthog[0] * scale) + orthog[1] = round(orthog[1] * scale) + points = [ + (p1[0] - orthog[0], p1[1] - orthog[1]), + (p1[0] + orthog[0], p1[1] + orthog[1]), + (p2[0] + orthog[0], p2[1] + orthog[1]), + (p2[0] - orthog[0], p2[1] - orthog[1]), + ] + pygame.draw.polygon(self.surface, fill_color.as_int(), points) + pygame.draw.circle( + self.surface, + fill_color.as_int(), + (round(p1[0]), round(p1[1])), + round(radius), + ) + pygame.draw.circle( + self.surface, + fill_color.as_int(), + (round(p2[0]), round(p2[1])), + round(radius), + ) + + def draw_polygon( + self, + verts: Sequence[Tuple[float, float]], + radius: float, + outline_color: SpaceDebugColor, + fill_color: SpaceDebugColor, + ) -> None: 
+        ps = [to_pygame(v, self.surface) for v in verts]
+        ps += [ps[0]]
+
+        # The incoming radius is ignored; a fixed outline radius of 2 is used for drawing.
+        radius = 2
+        pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
+
+        if radius > 0:
+            for i in range(len(verts)):
+                a = verts[i]
+                b = verts[(i + 1) % len(verts)]
+                self.draw_fat_segment(a, b, radius, fill_color, fill_color)
+
+    def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
+        p = to_pygame(pos, self.surface)
+        pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
+
+
+def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
+    """Get position of the mouse pointer in pymunk coordinates."""
+    p = pygame.mouse.get_pos()
+    return from_pygame(p, surface)
+
+
+def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
+    """Convenience method to convert pymunk coordinates to pygame surface
+    local coordinates.
+
+    Note that in case positive_y_is_up is False, this function won't actually do
+    anything except convert the point to integers.
+    """
+    if positive_y_is_up:
+        return round(p[0]), surface.get_height() - round(p[1])
+    else:
+        return round(p[0]), round(p[1])
+
+
+def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
+    """Convenience method to convert pygame surface local coordinates to
+    pymunk coordinates.
+    """
+    return to_pygame(p, surface)
+
+
+def light_color(color: SpaceDebugColor):
+    color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
+    color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
+    return color
diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py
new file mode 100644
index 000000000..6dc128dec
--- /dev/null
+++ b/lerobot/common/policies/diffusion/model/ema_model.py
@@ -0,0 +1,84 @@
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of model weights
+    """
+
+    def __init__(
+        self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+        at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
+
+        self.averaged_model = model
+        self.averaged_model.eval()
+        self.averaged_model.requires_grad_(False)
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
+        self.optimization_step = 0
+
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
+    @torch.no_grad()
+    def step(self, new_model):
+        self.decay = self.get_decay(self.optimization_step)
+
+        # old_all_dataptrs = set()
+        # for param in new_model.parameters():
+        #     data_ptr = param.data_ptr()
+        #     if data_ptr != 0:
+        #         old_all_dataptrs.add(data_ptr)
+
+        # all_dataptrs = set()
+        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
+            for param, ema_param in zip(
+                module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
+            ):
+                # iterate over immediate parameters only.
+                if isinstance(param, dict):
+                    raise RuntimeError("Dict parameter not supported")
+
+                # data_ptr = param.data_ptr()
+                # if data_ptr != 0:
+                #     all_dataptrs.add(data_ptr)
+
+                if isinstance(module, _BatchNorm):
+                    # skip batchnorms
+                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
+                elif not param.requires_grad:
+                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
+                else:
+                    ema_param.mul_(self.decay)
+                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+        # verify that iterating over module and then parameters is identical to parameters recursively.
+        # assert old_all_dataptrs == all_dataptrs
+        self.optimization_step += 1
diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py
index c1444b062..ed5dc23ac 100644
--- a/lerobot/common/policies/diffusion/pytorch_utils.py
+++ b/lerobot/common/policies/diffusion/pytorch_utils.py
@@ -2,6 +2,36 @@
 
 import torch
 import torch.nn as nn
+import torchvision
+
+
+def get_resnet(name, weights=None, **kwargs):
+    """
+    name: resnet18, resnet34, resnet50
+    weights: "IMAGENET1K_V1", "r3m"
+    """
+    # load r3m weights
+    if (weights == "r3m") or (weights == "R3M"):
+        return get_r3m(name=name, **kwargs)
+
+    func = getattr(torchvision.models, name)
+    resnet = func(weights=weights, **kwargs)
+    resnet.fc = torch.nn.Identity()
+    return resnet
+
+
+def get_r3m(name, **kwargs):
+    """
+    name: resnet18, resnet34, resnet50
+    """
+    import r3m
+
+    r3m.device = "cpu"
+    model = r3m.load_r3m(name)
+    r3m_model = model.module
+    resnet_model = r3m_model.convnet
+    resnet_model = resnet_model.to("cpu")
+    return resnet_model
 
 
 def dict_apply(
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 1adc8a9eb..0dae5056d 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -74,7 +74,6 @@ noise_scheduler:
   prediction_type: epsilon # or sample
 
 obs_encoder:
-  # _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
   shape_meta: ${shape_meta}
   # resize_shape: null
   # crop_shape: [76, 76]
@@ -85,12 +84,12 @@ obs_encoder:
   imagenet_norm: True
 
 rgb_model:
-  _target_: diffusion_policy.model.vision.model_getter.get_resnet
+  _target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
   name: resnet18
   weights: null
 
 ema:
-  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
   update_after_step: 0
   inv_gamma: 1.0
   power: 0.75
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 3382125a0..48e637d5f 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -3,7 +3,7 @@
 from torchrl.envs.utils import check_env_specs, step_mdp
 
 from 
lerobot.common.envs.factory import make_env -from lerobot.common.envs.pusht import PushtEnv +from lerobot.common.envs.pusht.pusht import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv from .utils import init_config From 7982425670fff9580165451399bfaf24a3b4597b Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 10 Mar 2024 16:36:30 +0100 Subject: [PATCH 4/6] Remove diffusion-policy dependency --- .github/workflows/test.yml | 5 +---- .gitignore | 3 --- .pre-commit-config.yaml | 2 +- README.md | 6 ------ poetry.lock | 17 +---------------- pyproject.toml | 1 - 6 files changed, 3 insertions(+), 31 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ad773f3f..e9edad051 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -69,10 +69,7 @@ jobs: key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }} - name: Install dependencies if: steps.restore-dependencies-cache.outputs.cache-hit != 'true' - run: | - poetry install --no-interaction --no-root - git clone https://github.com/real-stanford/diffusion_policy - cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/ + run: poetry install --no-interaction --no-root - name: Save cached venv if: | steps.restore-dependencies-cache.outputs.cache-hit != 'true' && diff --git a/.gitignore b/.gitignore index e2d4c0ec3..ad9892d4f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Custom -diffusion_policy - # Logging logs tmp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c23a42bcb..2b79434a4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: ^(data/|tests/|diffusion_policy/) +exclude: ^(data/|tests/) default_language_version: python: python3.10 repos: diff --git a/README.md b/README.md index 795bacaf0..551270b58 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,6 @@ mkdir ~/tmp export TMPDIR='~/tmp' ``` -Install `diffusion_policy` #HACK -``` -# from this directory -git clone https://github.com/real-stanford/diffusion_policy -cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/ -``` ## Usage diff --git a/poetry.lock b/poetry.lock index 9a35071b4..db4f8f3e7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -477,21 +477,6 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.11.0)", "torch (>=1.4,<2.2.0)"] training = ["Jinja2", "accelerate (>=0.11.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] -[[package]] -name = "diffusion_policy" -version = "0.0.0" -description = "" -optional = false -python-versions = "*" -files = [] -develop = false - -[package.source] -type = "git" -url = "https://github.com/real-stanford/diffusion_policy" -reference = "HEAD" -resolved_reference = "548a52bbb105518058e27bf34dcf90bf6f73681a" - [[package]] name = "distlib" version = "0.3.8" @@ -3140,4 +3125,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "9c3e86956dd11bc8d7823e5e6c5e74a073051b495f71f96179113d99791f7ca0" +content-hash = "c4d83579aed1c8c2e54cad7c8ec81b95a09ab8faff74fc9a4cb20bd00e4ddec6" diff --git a/pyproject.toml b/pyproject.toml index 64cbb8506..ebce8f326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,6 @@ mujoco = "^3.1.2" mujoco-py = "^2.1.2.14" gym = "^0.26.2" opencv-python = 
"^4.9.0.80" -diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"} diffusers = "^0.26.3" torchvision = "^0.17.1" h5py = "^3.10.0" From 134009f337ded7ead878631ecec9c950076fab3c Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 10 Mar 2024 16:38:49 +0100 Subject: [PATCH 5/6] Remove init files --- lerobot/common/__init__.py | 0 lerobot/common/datasets/__init__.py | 0 lerobot/common/envs/__init__.py | 0 lerobot/common/policies/__init__.py | 0 lerobot/common/policies/tdmpc/__init__.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 lerobot/common/__init__.py delete mode 100644 lerobot/common/datasets/__init__.py delete mode 100644 lerobot/common/envs/__init__.py delete mode 100644 lerobot/common/policies/__init__.py delete mode 100644 lerobot/common/policies/tdmpc/__init__.py diff --git a/lerobot/common/__init__.py b/lerobot/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lerobot/common/datasets/__init__.py b/lerobot/common/datasets/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lerobot/common/envs/__init__.py b/lerobot/common/envs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lerobot/common/policies/tdmpc/__init__.py b/lerobot/common/policies/tdmpc/__init__.py deleted file mode 100644 index e69de29bb..000000000 From f54ee7cda0b194b27da9b2ea98368ab79512a420 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 10 Mar 2024 16:51:50 +0100 Subject: [PATCH 6/6] Fix paths --- lerobot/common/envs/factory.py | 2 +- tests/test_envs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 35ebfa4a9..dd8ab2f7d 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -18,7 +18,7 @@ def make_env(cfg, transform=None): kwargs["task"] = cfg.env.task clsfunc = SimxarmEnv elif cfg.env.name == "pusht": - from lerobot.common.envs.pusht.pusht import PushtEnv + from lerobot.common.envs.pusht.env import PushtEnv # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." diff --git a/tests/test_envs.py b/tests/test_envs.py index 48e637d5f..b51c441b6 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -3,7 +3,7 @@ from torchrl.envs.utils import check_env_specs, step_mdp from lerobot.common.envs.factory import make_env -from lerobot.common.envs.pusht.pusht import PushtEnv +from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv from .utils import init_config