
Incorporate SerialEnv and introduce multistep policy logic #26

Merged · 26 commits · Mar 20, 2024
Changes from 17 commits

Commits
2a01487
early training loss as expected
alexander-soare Mar 11, 2024
304355c
Merge remote-tracking branch 'origin/main' into train_pusht
alexander-soare Mar 11, 2024
87fcc53
wip - still need to verify full training run
alexander-soare Mar 11, 2024
9512d1d
Merge branch 'main' into user/alexander-soare/train_pusht
alexander-soare Mar 12, 2024
98484ac
ready for review
alexander-soare Mar 12, 2024
ba91976
wip: still needs batch logic for act and tdmpc
alexander-soare Mar 14, 2024
4822d63
Merge branch 'main' into user/alexander-soare/multistep_policy_and_se…
alexander-soare Mar 14, 2024
736bc96
Merge branch 'main' into user/alexander-soare/train_pusht
alexander-soare Mar 14, 2024
a222c88
Merge branch 'user/alexander-soare/train_pusht' into user/alexander-s…
alexander-soare Mar 14, 2024
a45896d
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
3124f71
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
bae7e7b
Merge remote-tracking branch 'origin/main' into user/alexander-soare/…
alexander-soare Mar 15, 2024
09ddd9b
Merge branch 'main' into user/alexander-soare/multistep_policy_and_se…
alexander-soare Mar 18, 2024
8834796
revert dp changes, make act and tdmpc batch friendly
alexander-soare Mar 18, 2024
ea17f4c
backup wip
alexander-soare Mar 19, 2024
896a11f
backup wip
alexander-soare Mar 19, 2024
46ac87d
ready for review
alexander-soare Mar 19, 2024
b54cdc9
break_when_any_done==True for batch_size==1
alexander-soare Mar 19, 2024
18fa884
Move reset_warning_issued flag to class attribute
alexander-soare Mar 20, 2024
c5010fe
fix seeding
alexander-soare Mar 20, 2024
4f1955e
Clear action queue when environment is reset
alexander-soare Mar 20, 2024
52e149f
Only save video frames in first rollout
alexander-soare Mar 20, 2024
d16f6a9
Merge remote-tracking branch 'upstream/main' into user/alexander-soar…
alexander-soare Mar 20, 2024
b1ec3da
remove internal rendering hooks
alexander-soare Mar 20, 2024
5332766
revision
alexander-soare Mar 20, 2024
4b7ec81
remove abstractmethods, fix online training
alexander-soare Mar 20, 2024
6 changes: 3 additions & 3 deletions lerobot/common/datasets/abstract.py
@@ -49,9 +49,9 @@ def __init__(
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> c",
}

@property
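For context, a minimal sketch of what these reduction-pattern changes do (toy tensors and shapes, with `einops` assumed as in the repo; the same change applies to `aloha.py` below): dropping the leading `1` makes the aggregated stats plain per-channel tensors instead of keeping a singleton batch dimension.

```python
import torch
from einops import reduce

batch_state = torch.randn(32, 4)          # toy (b, c) batch
batch_image = torch.randn(32, 3, 96, 96)  # toy (b, c, h, w) batch

# Old pattern "b c -> 1 c" kept a singleton batch dim in the result;
# the new "b c -> c" yields a plain per-channel vector.
state_mean = reduce(batch_state, "b c -> c", "mean")          # shape: (4,)
image_mean = reduce(batch_image, "b c h w -> c 1 1", "mean")  # shape: (3, 1, 1)
```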
6 changes: 3 additions & 3 deletions lerobot/common/datasets/aloha.py
@@ -113,11 +113,11 @@ def __init__(
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d

@property
124 changes: 57 additions & 67 deletions lerobot/common/envs/aloha/env.py
@@ -58,6 +58,7 @@ def __init__(
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False

def _make_env(self):
if not _has_gym:
@@ -120,90 +121,79 @@ def _format_raw_obs(self, raw_obs):
return obs

def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)

# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs.observation)
if tensordict is not None and not self._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True

# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)

# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs.observation)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)

self.call_rendering_hooks()
return td

def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
_, reward, _, raw_obs = self._env.step(action)

num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)

# TODO(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
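As a side note, the observation stacking above hinges on a fixed-length `deque`: reset pre-fills it, and each step appends while `maxlen` evicts the oldest frame. A self-contained toy version (shapes and values are illustrative, not the env's API):

```python
from collections import deque

import torch

num_prev_obs = 2
first_obs = torch.zeros(4)  # stand-in for the observation returned by reset

# On reset: pre-fill with copies so torch.stack always sees num_prev_obs + 1 frames.
queue = deque([first_obs] * (num_prev_obs + 1), maxlen=num_prev_obs + 1)

# On each step: append the newest frame; maxlen silently drops the oldest.
for t in range(1, 6):
    queue.append(torch.full((4,), float(t)))

stacked = torch.stack(list(queue))  # shape: (num_prev_obs + 1, 4), the last 3 frames
```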
45 changes: 31 additions & 14 deletions lerobot/common/envs/factory.py
@@ -1,13 +1,18 @@
from torchrl.envs import SerialEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


def make_env(cfg, transform=None):
"""
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
environments. The env therefore returns batches.
"""

kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}
@@ -31,22 +36,34 @@ def make_env(cfg, transform=None):
else:
raise ValueError(cfg.env.name)

env = clsfunc(**kwargs)
def _make_env(seed):
nonlocal kwargs
kwargs["seed"] = seed
env = clsfunc(**kwargs)

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()

if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env

return env
return SerialEnv(
cfg.rollout_batch_size,
create_env_fn=_make_env,
create_env_kwargs={
"seed": env_seed # noqa: B035
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
},
)


# def make_env(env_name, frame_skip, device, is_test=False):
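For reference, a hedged sketch of the `SerialEnv` batching this file introduces, using torchrl's `GymEnv` with a stand-in task id in place of the repo's env classes:

```python
from torchrl.envs import GymEnv, SerialEnv

def make_one(seed: int = 0):
    env = GymEnv("Pendulum-v1")  # stand-in; the repo builds its own envs here
    env.set_seed(seed)
    return env

env = SerialEnv(
    2,  # plays the role of cfg.rollout_batch_size
    create_env_fn=make_one,
    create_env_kwargs=[{"seed": 0}, {"seed": 1}],  # one kwargs dict per sub-env
)
td = env.reset()
print(td.batch_size)  # torch.Size([2]); reset/step now return batched tensordicts
```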
107 changes: 49 additions & 58 deletions lerobot/common/envs/pusht/env.py
@@ -1,8 +1,8 @@
import importlib
import logging
from collections import deque
from typing import Optional

import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -43,6 +43,7 @@ def __init__(
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False

def _make_env(self):
if not _has_gym:
@@ -80,80 +81,70 @@ def _format_raw_obs(self, raw_obs):
return obs

def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs)
if tensordict is not None and not self._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True

# we need to handle seed iteration, since self._env.reset() relies on an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)

self.call_rendering_hooks()
return td

def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()

num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
raw_obs, reward, done, info = self._env.step(action)

obs = self._format_raw_obs(raw_obs)
obs = self._format_raw_obs(raw_obs)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

self.call_rendering_hooks()
self.call_rendering_hooks()

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"reward": torch.tensor([reward], dtype=torch.float32),
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},
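Finally, a sketch of the "multistep policy logic" named in the PR title, which these env changes enable: the policy predicts a chunk of actions once, queues them, and the env consumes a single 1-D action per `_step` (class and method names here are illustrative, not the repo's API):

```python
from collections import deque

import torch

class ChunkPolicy:
    """Toy multistep policy: predict a chunk, then drain it one action at a time."""

    def __init__(self, n_action_steps: int, action_dim: int):
        self.n_action_steps = n_action_steps
        self.action_dim = action_dim
        self._queue: deque = deque()

    def reset(self):
        # Mirrors "Clear action queue when environment is reset" (commit 4f1955e).
        self._queue.clear()

    def select_action(self, obs: torch.Tensor) -> torch.Tensor:
        if not self._queue:
            chunk = self._predict_chunk(obs)  # (n_action_steps, action_dim)
            self._queue.extend(chunk)        # rows become individual 1-D actions
        return self._queue.popleft()         # a single 1-D action, as _step now asserts

    def _predict_chunk(self, obs: torch.Tensor) -> torch.Tensor:
        # Stand-in for the model forward pass.
        return torch.zeros(self.n_action_steps, self.action_dim)
```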