Skip to content

Commit

Permalink
temporary_fixes for gym-lowcostrobot
Browse files Browse the repository at this point in the history
  • Loading branch information
michel-aractingi committed Oct 22, 2024
1 parent 04029f5 commit 9a5356d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 29 deletions.
47 changes: 24 additions & 23 deletions lerobot/common/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import torch
from torch import Tensor

##############################################
### TODO this script is modified to hackathon purposes and should be reset after.
##############################################

PIXELS_KEY="image_front"

def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Expand All @@ -28,28 +33,24 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}

for imgkey, img in imgs.items():
img = torch.from_numpy(img)

# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255

return_observations[imgkey] = img
#if PIXELS_KEY in observations:
# if isinstance(observations[PIXELS_KEY], dict):
# imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
# else:
# imgs = {"observation.image": observations["pixels"]}
imgs = {"observation.images.image_front": observations["image_front"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img

if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
Expand All @@ -58,5 +59,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten

# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = torch.from_numpy(observations["arm_qpos"]).float()
return return_observations
8 changes: 5 additions & 3 deletions lerobot/common/policies/tdmpc/modeling_tdmpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
#TODO michel_aractingi temp fix to remove before merge
del batch[self.input_image_key]

self._queues = populate_queues(self._queues, batch)

Expand Down Expand Up @@ -343,7 +345,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
batch[key] = batch[key].transpose(1, 0)

action = batch["action"] # (t, b, action_dim)
reward = batch["next.reward"] # (t, b)
reward = batch["reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

# Apply random image augmentations.
Expand Down Expand Up @@ -420,7 +422,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
* ~batch["reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
Expand All @@ -441,7 +443,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
Expand Down
1 change: 0 additions & 1 deletion lerobot/scripts/control_sim_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,6 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
env.reset(seed=seeds[from_idx].item())

logging.info("Replaying episode")
say("Replaying episode", blocking=True)
for idx in range(from_idx, to_idx):
Expand Down
4 changes: 2 additions & 2 deletions lerobot/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ def rollout(
action = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"

# Apply the next action.
# Apply the next action. TODO (michel_aractingi) temp fix
observation, reward, terminated, truncated, info = env.step(action)
if render_callback is not None:
render_callback(env)

# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
if False and "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
else:
successes = [False] * env.num_envs
Expand Down

0 comments on commit 9a5356d

Please sign in to comment.