Skip to content

Commit

Permalink
revert dp changes, make act and tdmpc batch friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare committed Mar 18, 2024
1 parent 09ddd9b commit 8834796
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 58 deletions.
12 changes: 8 additions & 4 deletions lerobot/common/policies/abstract.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import deque

import torch
from torch import Tensor, nn


class AbstractPolicy(nn.Module):
class AbstractPolicy(nn.Module, ABC):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
"""

@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass

def save(self, fp):
torch.save(self.state_dict(), fp)
Expand All @@ -24,7 +29,6 @@ def select_action(self, observation) -> Tensor:
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass

def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.
Expand Down
9 changes: 1 addition & 8 deletions lerobot/common/policies/act/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ def select_action(self, observation, step_count):

self.eval()

# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)

# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
Expand All @@ -180,11 +176,8 @@ def select_action(self, observation, step_count):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)

# remove bsize=1
action = action.squeeze(0)

# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
action = action[: self.n_action_steps]
return action

def _forward(self, qpos, image, actions=None, is_pad=None):
Expand Down
31 changes: 5 additions & 26 deletions lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,15 @@
import copy
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Tuple, Union

import timm
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import SpatialSoftmax

from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules


class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""

def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
"""
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)

def forward(self, x):
return torch.flatten(self.pool(self.backbone(x)), start_dim=1)


class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
Expand All @@ -46,7 +24,7 @@ def __init__(
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
norm_mean_std: Optional[tuple[float, float]] = None,
imagenet_norm: bool = False,
):
"""
Assumes rgb input: B,C,H,W
Expand Down Expand Up @@ -120,9 +98,10 @@ def __init__(
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if norm_mean_std is not None:
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize(
mean=norm_mean_std[0], std=norm_mean_std[1]
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
Expand Down
4 changes: 2 additions & 2 deletions lerobot/common/policies/diffusion/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder


class DiffusionPolicy(AbstractPolicy):
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
self.cfg = cfg

noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
Expand Down
7 changes: 1 addition & 6 deletions lerobot/common/policies/tdmpc/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,6 @@ def load(self, fp):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0

# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)

obs = {
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
Expand All @@ -149,7 +144,7 @@ def act(self, obs, t0=False, step=None):
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
a = self.model.pi(z, self.cfg.min_std * self.model.training)
return a

@torch.no_grad()
Expand Down
20 changes: 10 additions & 10 deletions lerobot/configs/policy/diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ policy:
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
diffusion_step_embed_dim: 256 # before 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
Expand Down Expand Up @@ -81,12 +81,12 @@ obs_encoder:
# random_crop: True
use_group_norm: True
share_rgb_model: False
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
imagenet_norm: True

rgb_model:
model_name: resnet18
pretrained: false
num_keypoints: 32
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
name: resnet18
weights: null

ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
Expand All @@ -109,13 +109,13 @@ training:
debug: False
resume: True
# optimization
lr_scheduler: cosine
lr_warmup_steps: 500
num_epochs: 500
# lr_scheduler: cosine
# lr_warmup_steps: 500
num_epochs: 8000
# gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm.
use_ema: True
# use_ema: True
freeze_encoder: False
# training loop control
# in epochs
Expand Down
4 changes: 2 additions & 2 deletions lerobot/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def eval(cfg: dict, out_dir=None):
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
{"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(self):
self.n_action_steps = n_action_steps
self.n_policy_invocations = 0

def update(self):
pass

def select_action(self):
self.n_policy_invocations += 1
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
Expand Down

0 comments on commit 8834796

Please sign in to comment.