Skip to content

Commit

Permalink
Merge pull request #26 from Cadene/user/alexander-soare/multistep_pol…
Browse files Browse the repository at this point in the history
…icy_and_serial_env

Incorporate SerialEnv and introduct multistep policy logic
  • Loading branch information
Cadene authored Mar 20, 2024
2 parents 3910c48 + 4b7ec81 commit ec536ef
Show file tree
Hide file tree
Showing 19 changed files with 465 additions and 288 deletions.
6 changes: 3 additions & 3 deletions lerobot/common/datasets/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def __init__(
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> c",
}

@property
Expand Down
6 changes: 3 additions & 3 deletions lerobot/common/datasets/aloha.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def __init__(
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action",): "b c -> 1 c",
("observation", "state"): "b c -> c",
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d

@property
Expand Down
30 changes: 6 additions & 24 deletions lerobot/common/envs/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
from collections import deque
from typing import Optional

Expand Down Expand Up @@ -27,7 +26,6 @@ def __init__(
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []

if pixels_only:
assert from_pixels
Expand All @@ -45,36 +43,20 @@ def __init__(
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)

def register_rendering_hook(self, func):
self._rendering_hooks.append(func)

def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)

def reset_rendering_hooks(self):
self._rendering_hooks = []

@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()
raise NotImplementedError("Abstract method")
126 changes: 57 additions & 69 deletions lerobot/common/envs/aloha/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@


class AlohaEnv(AbstractEnv):
_reset_warning_issued = False

def __init__(
self,
task,
Expand Down Expand Up @@ -120,90 +122,76 @@ def _format_raw_obs(self, raw_obs):
return obs

def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)

# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs.observation)
if tensordict is not None and not AlohaEnv._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
AlohaEnv._reset_warning_issued = True

# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)

# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed

obs = self._format_raw_obs(raw_obs.observation)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)

self.call_rendering_hooks()
return td

def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
assert action.ndim == 1
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0

if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()

num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
_, reward, _, raw_obs = self._env.step(action)

# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)

if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
# TODO(rcadene): add an enum
success = done = reward == 4
obs = self._format_raw_obs(raw_obs)

self.call_rendering_hooks()
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs

td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
Expand Down
45 changes: 30 additions & 15 deletions lerobot/common/envs/factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from torchrl.envs import SerialEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv


def make_env(cfg, transform=None):
"""
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
environments. The env therefore returns batches.`
"""

kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
}

Expand All @@ -31,22 +35,33 @@ def make_env(cfg, transform=None):
else:
raise ValueError(cfg.env.name)

env = clsfunc(**kwargs)
def _make_env(seed):
nonlocal kwargs
kwargs["seed"] = seed
env = clsfunc(**kwargs)

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))

# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()

if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env

return env
return SerialEnv(
cfg.rollout_batch_size,
create_env_fn=_make_env,
create_env_kwargs=[
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)


# def make_env(env_name, frame_skip, device, is_test=False):
Expand Down
Loading

0 comments on commit ec536ef

Please sign in to comment.