Commit
wip: still needs batch logic for act and tdmp
alexander-soare committed Mar 14, 2024
1 parent 8c56770 commit ba91976
Showing 11 changed files with 240 additions and 100 deletions.
43 changes: 16 additions & 27 deletions lerobot/common/envs/aloha/env.py
@@ -168,42 +168,31 @@ def _reset(self, tensordict: Optional[TensorDict] = None):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()

num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
_, reward, _, raw_obs = self._env.step(action)

# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
8 changes: 5 additions & 3 deletions lerobot/common/envs/factory.py
@@ -1,15 +1,17 @@
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


def make_env(cfg, transform=None):
def make_env(cfg, seed=None, transform=None):
"""
Provide seed to override the seed in the cfg (useful for batched environments).
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
"seed": seed if seed is not None else cfg.seed,
}

if cfg.env.name == "simxarm":
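The factory change above adds an optional seed argument so that each copy of an environment can be seeded independently of the config (seed=None keeps the old behaviour of using cfg.seed). A minimal usage sketch, my own illustration rather than part of this commit, assuming cfg is the usual Hydra config object:

from lerobot.common.envs.factory import make_env

# Build a few identically configured envs with distinct seeds, e.g. for batched rollouts.
# The seed = cfg.seed + index convention mirrors the comment added to
# lerobot/configs/default.yaml further down.
envs = [make_env(cfg, seed=cfg.seed + i) for i in range(4)]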
41 changes: 15 additions & 26 deletions lerobot/common/envs/pusht/env.py
@@ -2,7 +2,6 @@
from collections import deque
from typing import Optional

import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -120,40 +119,30 @@ def _reset(self, tensordict: Optional[TensorDict] = None):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
raw_obs, reward, done, info = self._env.step(action)

num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)

obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},
54 changes: 54 additions & 0 deletions lerobot/common/policies/abstract.py
@@ -0,0 +1,54 @@
from abc import abstractmethod
from collections import deque

import torch
from torch import Tensor, nn


class AbstractPolicy(nn.Module):
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass

def save(self, fp):
torch.save(self.state_dict(), fp)

def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)

@abstractmethod
def select_action(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.
Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass

def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.
WARNING: In general, this should not be overridden.
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))

return self._action_queue.popleft()
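
AbstractPolicy.forward wraps select_action with an action queue so that a policy which plans n_action_steps actions at once can still be queried one environment step at a time. A minimal sketch of that contract, using a hypothetical ToyPolicy that is not part of this commit and only illustrates the expected shapes:

import torch

from lerobot.common.policies.abstract import AbstractPolicy


class ToyPolicy(AbstractPolicy):
    """Hypothetical subclass used only to illustrate the queueing contract."""

    def __init__(self, n_action_steps=4, action_dim=2):
        super().__init__()
        self.n_action_steps = n_action_steps  # required by AbstractPolicy.forward
        self.action_dim = action_dim

    def update(self, replay_buffer, step):
        pass  # no learning in this sketch

    def select_action(self, observation):
        # Return a whole trajectory with shape (batch_size, n_action_steps, action_dim).
        batch_size = observation.shape[0]
        return torch.zeros(batch_size, self.n_action_steps, self.action_dim)


policy = ToyPolicy(n_action_steps=4)
obs = torch.randn(1, 8)  # (batch_size, obs_dim)
for _ in range(6):
    # The first call fills the queue with 4 per-step batches of shape (1, 2) and pops one;
    # later calls keep popping until the queue empties and select_action is called again.
    action = policy(obs)
    assert action.shape == (1, 2)

This is also why DiffusionPolicy.select_action below stops squeezing out the batch dimension: forward expects (B, n_action_steps, *) from select_action and pops (B, *) actions from the queue.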
6 changes: 3 additions & 3 deletions lerobot/common/policies/act/policy.py
@@ -2,10 +2,10 @@
import time

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms

from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build


@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld


class ActionChunkingTransformerPolicy(nn.Module):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
@@ -147,7 +147,7 @@ def compute_loss(self, batch):
return loss

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

14 changes: 5 additions & 9 deletions lerobot/common/policies/diffusion/policy.py
@@ -3,14 +3,14 @@

import hydra
import torch
import torch.nn as nn

from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder


class DiffusionPolicy(nn.Module):
class DiffusionPolicy(AbstractPolicy):
def __init__(
self,
cfg,
@@ -44,6 +44,7 @@ def __init__(
**cfg_obs_encoder,
)

self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -93,21 +94,16 @@ def __init__(
)

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)

obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)

action = out["action"].squeeze(0)
action = out["action"]
return action

def update(self, replay_buffer, step):
5 changes: 3 additions & 2 deletions lerobot/common/policies/tdmpc/policy.py
@@ -9,6 +9,7 @@
import torch.nn as nn

import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy

FIRST_FRAME = 0

@@ -85,7 +86,7 @@ def Q(self, z, a, return_type): # noqa: N802
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2


class TDMPC(nn.Module):
class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""

def __init__(self, cfg, device):
@@ -124,7 +125,7 @@ def load(self, fp):
self.model_target.load_state_dict(d["model_target"])

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0

# TODO(rcadene): remove unsqueeze hack...
2 changes: 2 additions & 0 deletions lerobot/configs/default.yaml
@@ -10,6 +10,8 @@ hydra:
name: default

seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
rollout_batch_size: 10
device: cuda # cpu
prefetch: 4
eval_freq: ???
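The new rollout_batch_size option is documented as the batch size for a TorchRL SerialEnv, with each sub-environment seeded as seed + env_index. This commit does not yet build the batched env, so the following is only a sketch of how that wiring could look, assuming make_env from the factory above and a hypothetical make_batched_env helper:

from torchrl.envs import SerialEnv

from lerobot.common.envs.factory import make_env


def make_batched_env(cfg):
    # One constructor per sub-environment, each seeded with cfg.seed + env_index.
    return SerialEnv(
        cfg.rollout_batch_size,
        create_env_fn=[
            (lambda i=i: make_env(cfg, seed=cfg.seed + i))
            for i in range(cfg.rollout_batch_size)
        ],
    )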