Incorporate SerialEnv and introduce multistep policy logic #26

Merged
merged 26 commits on Mar 20, 2024
Changes from 2 commits
Commits (26)
2a01487
early training loss as expected
alexander-soare Mar 11, 2024
304355c
Merge remote-tracking branch 'origin/main' into train_pusht
alexander-soare Mar 11, 2024
87fcc53
wip - still need to verify full training run
alexander-soare Mar 11, 2024
9512d1d
Merge branch 'main' into user/alexander-soare/train_pusht
alexander-soare Mar 12, 2024
98484ac
ready for review
alexander-soare Mar 12, 2024
ba91976
wip: still needs batch logic for act and tdmp
alexander-soare Mar 14, 2024
4822d63
Merge branch 'main' into user/alexander-soare/multistep_policy_and_se…
alexander-soare Mar 14, 2024
736bc96
Merge branch 'main' into user/alexander-soare/train_pusht
alexander-soare Mar 14, 2024
a222c88
Merge branch 'user/alexander-soare/train_pusht' into user/alexander-s…
alexander-soare Mar 14, 2024
a45896d
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
3124f71
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
bae7e7b
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
09ddd9b
Merge branch 'main' into user/alexander-soare/multistep_policy_and_se…
alexander-soare Mar 18, 2024
8834796
revert dp changes, make act and tdmpc batch friendly
alexander-soare Mar 18, 2024
ea17f4c
backup wip
alexander-soare Mar 19, 2024
896a11f
backup wip
alexander-soare Mar 19, 2024
46ac87d
ready for review
alexander-soare Mar 19, 2024
b54cdc9
break_when_any_done==True for batch_size==1
alexander-soare Mar 19, 2024
18fa884
Move reset_warning_issued flag to class attribute
alexander-soare Mar 20, 2024
c5010fe
fix seeding
alexander-soare Mar 20, 2024
4f1955e
Clear action queue when environment is reset
alexander-soare Mar 20, 2024
52e149f
Only save video frames in first rollout
alexander-soare Mar 20, 2024
d16f6a9
Merge remote-tracking branch 'upstream/main' into user/alexander-soar…
alexander-soare Mar 20, 2024
b1ec3da
remove internal rendering hooks
alexander-soare Mar 20, 2024
5332766
revision
alexander-soare Mar 20, 2024
4b7ec81
remove abstracmethods, fix online training
alexander-soare Mar 20, 2024
43 changes: 16 additions & 27 deletions lerobot/common/envs/aloha/env.py
@@ -168,42 +168,31 @@ def _reset(self, tensordict: Optional[TensorDict] = None):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()

num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
_, reward, _, raw_obs = self._env.step(action)

# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
8 changes: 5 additions & 3 deletions lerobot/common/envs/factory.py
@@ -1,15 +1,17 @@
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


def make_env(cfg, transform=None):
def make_env(cfg, seed=None, transform=None):
"""
Provide seed to override the seed in the cfg (useful for batched environments).
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
"seed": seed if seed is not None else cfg.seed,
}

if cfg.env.name == "simxarm":
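The new `seed` argument above, combined with the `rollout_batch_size` option added to lerobot/configs/default.yaml further down, suggests one way to build a batched rollout environment in which sub-env i receives seed `cfg.seed + i`. The sketch below is illustrative only, assuming TorchRL's SerialEnv accepts a list of env constructors; `make_batched_env` is a hypothetical helper and not part of this diff.

from functools import partial

from torchrl.envs import SerialEnv

from lerobot.common.envs.factory import make_env


def make_batched_env(cfg, transform=None):
    # Hypothetical helper (not in this PR): one sub-env per rollout slot,
    # each seeded with cfg.seed + env_index as described in default.yaml.
    return SerialEnv(
        cfg.rollout_batch_size,
        create_env_fn=[
            partial(make_env, cfg, seed=cfg.seed + env_index, transform=transform)
            for env_index in range(cfg.rollout_batch_size)
        ],
    )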
41 changes: 15 additions & 26 deletions lerobot/common/envs/pusht/env.py
@@ -2,7 +2,6 @@
from collections import deque
from typing import Optional

import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -120,40 +119,30 @@ def _reset(self, tensordict: Optional[TensorDict] = None):
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
raw_obs, reward, done, info = self._env.step(action)

num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)

obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},
54 changes: 54 additions & 0 deletions lerobot/common/policies/abstract.py
@@ -0,0 +1,54 @@
from abc import abstractmethod
from collections import deque

import torch
from torch import Tensor, nn


class AbstractPolicy(nn.Module):
@abstractmethod
def update(self, replay_buffer, step):
"""One step of the policy's learning algorithm."""
pass

def save(self, fp):
torch.save(self.state_dict(), fp)

def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)

@abstractmethod
def select_action(self, observation) -> Tensor:
"""Select an action (or trajectory of actions) based on an observation during rollout.

Should return a (batch_size, n_action_steps, *) tensor of actions.
"""
pass

def forward(self, *args, **kwargs):
"""Inference step that makes multi-step policies compatible with their single-step environments.

WARNING: In general, this should not be overridden.

Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
the subclass doesn't have to.

This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
the action trajectory horizon and * is the action dimensions.
2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute defined.
"""
n_action_steps_attr = "n_action_steps"
if not hasattr(self, n_action_steps_attr):
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
if not hasattr(self, "_action_queue"):
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
if len(self._action_queue) == 0:
# Each element in the queue has shape (B, *).
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))

return self._action_queue.popleft()
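To make the queueing contract of `forward` concrete, here is a minimal, self-contained sketch. `DummyPolicy` is a made-up subclass used purely for illustration: `select_action` returns a (batch_size, n_action_steps, action_dim) trajectory, and the inherited `forward` hands back one (batch_size, action_dim) step per call, repopulating the queue only when it runs empty.

import torch
from torch import Tensor

from lerobot.common.policies.abstract import AbstractPolicy


class DummyPolicy(AbstractPolicy):
    # Illustrative subclass only: a real policy would implement learning in `update`.
    def __init__(self, n_action_steps: int = 4):
        super().__init__()
        self.n_action_steps = n_action_steps  # required by AbstractPolicy.forward

    def update(self, replay_buffer, step):
        pass  # no learning in this toy example

    def select_action(self, observation) -> Tensor:
        # Pretend to plan a trajectory of shape (batch_size=2, horizon=n_action_steps, action_dim=3).
        return torch.zeros(2, self.n_action_steps, 3)


policy = DummyPolicy()
a0 = policy(observation=None)  # calls select_action once and queues n_action_steps steps
a1 = policy(observation=None)  # served straight from the queue, no new planning call
assert a0.shape == a1.shape == (2, 3)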
6 changes: 3 additions & 3 deletions lerobot/common/policies/act/policy.py
@@ -2,10 +2,10 @@
import time

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms

from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build


@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
return total_kld, dimension_wise_kld, mean_kld


class ActionChunkingTransformerPolicy(nn.Module):
class ActionChunkingTransformerPolicy(AbstractPolicy):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
@@ -147,7 +147,7 @@ def compute_loss(self, batch):
return loss

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

14 changes: 5 additions & 9 deletions lerobot/common/policies/diffusion/policy.py
@@ -3,14 +3,14 @@

import hydra
import torch
import torch.nn as nn

from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder


class DiffusionPolicy(nn.Module):
class DiffusionPolicy(AbstractPolicy):
def __init__(
self,
cfg,
@@ -44,6 +44,7 @@ def __init__(
**cfg_obs_encoder,
)

self.n_action_steps = n_action_steps # needed for the parent class
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
@@ -93,21 +94,16 @@ def __init__(
)

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count

# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)

obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)

action = out["action"].squeeze(0)
action = out["action"]
return action

def update(self, replay_buffer, step):
5 changes: 3 additions & 2 deletions lerobot/common/policies/tdmpc/policy.py
@@ -9,6 +9,7 @@
import torch.nn as nn

import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.abstract import AbstractPolicy

FIRST_FRAME = 0

@@ -85,7 +86,7 @@ def Q(self, z, a, return_type):  # noqa: N802
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2


class TDMPC(nn.Module):
class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""

def __init__(self, cfg, device):
Expand Down Expand Up @@ -124,7 +125,7 @@ def load(self, fp):
self.model_target.load_state_dict(d["model_target"])

@torch.no_grad()
def forward(self, observation, step_count):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0

# TODO(rcadene): remove unsqueeze hack...
2 changes: 2 additions & 0 deletions lerobot/configs/default.yaml
@@ -10,6 +10,8 @@ hydra:
name: default

seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
rollout_batch_size: 10
device: cuda # cpu
prefetch: 4
eval_freq: ???
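For context, a hedged sketch of how `rollout_batch_size` might be consumed at evaluation time. `make_batched_env` is the hypothetical helper sketched earlier, the `max_steps` value is an assumed horizon, and the `break_when_any_done` choice follows commit b54cdc9 (break early only when the batch size is 1). A policy would normally be passed via `rollout(policy=...)` once wrapped so TorchRL can call it with a TensorDict; it is omitted here to keep the sketch minimal.

env = make_batched_env(cfg)  # SerialEnv with cfg.rollout_batch_size sub-envs (see earlier sketch)
rollout_td = env.rollout(
    max_steps=300,  # assumed episode horizon, not a value from this diff
    break_when_any_done=cfg.rollout_batch_size == 1,  # per commit b54cdc9
)
assert rollout_td.batch_size[0] == cfg.rollout_batch_size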