Improves Type Annotations #252

Merged
merged 12 commits on Jun 10, 2024
1 change: 1 addition & 0 deletions lerobot/common/logger.py
@@ -241,5 +241,6 @@ def log_dict(self, d, step, mode="train"):

def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
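The added `assert self._wandb is not None` is a standard way to narrow an Optional attribute so that static type checkers accept the attribute accesses that follow. A minimal, self-contained sketch of the same pattern (the `_Run` class below is a hypothetical stand-in for a wandb run object, not part of this PR):

```python
from typing import Optional


class _Run:
    """Hypothetical stand-in for a wandb run object."""

    def log(self, data: dict, step: int) -> None:
        print(f"step={step}: {data}")


def log_video(run: Optional[_Run], video_path: str, step: int) -> None:
    # Before the assert, `run` has type Optional[_Run], so `run.log(...)` is
    # flagged by mypy/pyright. The assert narrows it to _Run, mirroring the
    # `assert self._wandb is not None` added in Logger.log_video above.
    assert run is not None, "logging backend is not initialised"
    run.log({"eval/video": video_path}, step=step)


log_video(_Run(), "eval_episode_0.mp4", step=100)
```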
6 changes: 5 additions & 1 deletion lerobot/common/policies/policy_protocol.py
@@ -49,20 +49,23 @@ def reset(self):

Does things like clearing caches.
"""
...

def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.

Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
...

def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).

When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
...


@runtime_checkable
@@ -73,3 +76,4 @@ def update(self):
Implements any additional updates the model parameters may need (for example, doing an EMA step for a
target model, or incrementing an internal buffer).
"""
...
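With explicit `...` bodies and return annotations, the `Policy` protocol can be checked statically and, because it is decorated with `@runtime_checkable`, also at runtime with `isinstance`, which is what the `assert isinstance(policy, Policy)` added in eval.py relies on. A small sketch of that mechanism (class names below are illustrative, not from this PR):

```python
from typing import Protocol, runtime_checkable

import torch
from torch import Tensor, nn


@runtime_checkable
class Policy(Protocol):
    """Simplified stand-in for the protocol in policy_protocol.py."""

    def select_action(self, batch: dict[str, Tensor]) -> Tensor: ...


class ZeroPolicy(nn.Module):
    """Toy policy: structurally satisfies the protocol without inheriting from it."""

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        return torch.zeros(batch["observation.state"].shape[0], 2)


policy = ZeroPolicy()
# @runtime_checkable allows isinstance() checks against the protocol. Only the
# presence of the methods is verified at runtime, not their signatures.
assert isinstance(policy, Policy)
print(policy.select_action({"observation.state": torch.randn(4, 6)}).shape)
```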
2 changes: 1 addition & 1 deletion lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -134,7 +134,7 @@ def reset(self):
self._prev_mean: torch.Tensor | None = None

@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
41 changes: 30 additions & 11 deletions lerobot/scripts/eval.py
@@ -61,7 +61,7 @@
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from PIL import Image as PILImage
from torch import Tensor
from torch import Tensor, nn
from tqdm import trange

from lerobot.common.datasets.factory import make_dataset
@@ -99,13 +99,13 @@ def rollout(
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.

Args:
env: The batch of environments.
policy: The policy.
policy: The policy. Must be a PyTorch nn module.
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
specifies the seeds for each of the environments.
return_observations: Whether to include all observations in the returned rollout data. Observations
@@ -116,6 +116,7 @@
Returns:
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
device = get_device_from_parameters(policy)

# Reset the policy and environments.
@@ -231,6 +232,10 @@ def eval_policy(
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")

assert isinstance(policy, Policy)
start = time.time()
policy.eval()

@@ -271,11 +276,16 @@ def render_frame(env: gym.vector.VectorEnv):
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []

seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
if start_seed is None:
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env,
policy,
seeds=seeds,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
enable_progbar=enable_inner_progbar,
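The seed ranges computed a few lines above advance by `env.num_envs` for each rollout batch, so every episode gets a distinct seed. A tiny worked example with illustrative numbers:

```python
# Assume start_seed=1000 and a vectorized env with 4 sub-environments.
start_seed, num_envs = 1000, 4
for batch_ix in range(3):
    seeds = range(start_seed + batch_ix * num_envs, start_seed + (batch_ix + 1) * num_envs)
    print(batch_ix, list(seeds))
# 0 [1000, 1001, 1002, 1003]
# 1 [1004, 1005, 1006, 1007]
# 2 [1008, 1009, 1010, 1011]
```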
@@ -285,7 +295,8 @@ def render_frame(env: gym.vector.VectorEnv):
# this won't be included).
n_steps = rollout_data["done"].shape[1]
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps)
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)

# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
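A minimal sketch (toy tensors; plain broadcasting stands in for `einops.repeat`) of why the cumulative `done` tensor, `argmax`, and the `done_indices + 1` mask behave as described:

```python
import torch

# Cumulative "done": once an episode ends, it stays True until the end.
done = torch.tensor(
    [
        [False, False, True, True, True],    # first done at step 2
        [False, False, False, False, True],  # first done at step 4
    ]
)

# argmax on the int-cast tensor returns the index of the first True per row
# (the first occurrence wins ties).
done_indices = torch.argmax(done.to(int), dim=1)
print(done_indices)  # tensor([2, 4])

# Keep every step up to and including `done_indices + 1`, as in eval.py.
n_steps = done.shape[1]
mask = (torch.arange(n_steps) <= (done_indices + 1).unsqueeze(-1)).int()
print(mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```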
@@ -296,8 +307,12 @@ def render_frame(env: gym.vector.VectorEnv):
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
all_seeds.extend(seeds)
if seeds:
all_seeds.extend(seeds)
else:
all_seeds.append(None)

# FIXME: episode_data is either None or it doesn't exist
if return_episode_data:
this_episode_data = _compile_episode_data(
rollout_data,
@@ -347,6 +362,7 @@
):
if n_episodes_rendered >= max_episodes_rendered:
break

videos_dir.mkdir(parents=True, exist_ok=True)
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
video_paths.append(str(video_path))
@@ -504,16 +520,17 @@ def _compile_episode_data(


def main(
pretrained_policy_path: str | None = None,
pretrained_policy_path: Path | None = None,
hydra_cfg_path: str | None = None,
out_dir: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if hydra_cfg_path is None:
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
if pretrained_policy_path is not None:
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)

if out_dir is None:
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"

@@ -531,10 +548,12 @@ def main(

logging.info("Making policy.")
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)

assert isinstance(policy, nn.Module)
policy.eval()

with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
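The `torch.autocast(...) if use_amp else nullcontext()` construction (used here and in train.py) swaps in a no-op context manager when mixed precision is disabled, so the same `with` statement covers both configurations. A self-contained sketch with illustrative module and flag values:

```python
from contextlib import nullcontext

import torch
from torch import nn

use_amp = torch.cuda.is_available()  # illustrative: enable AMP only on GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)
x = torch.randn(8, 4, device=device)

# nullcontext() does nothing on enter/exit, so the non-AMP path is unchanged.
with torch.no_grad(), torch.autocast(device_type=device.type) if use_amp else nullcontext():
    out = model(x)
print(out.shape, out.dtype)
```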
12 changes: 8 additions & 4 deletions lerobot/scripts/push_dataset_to_hub.py
@@ -66,6 +66,7 @@
import json
import shutil
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import HfApi
@@ -77,7 +78,7 @@
from lerobot.common.datasets.utils import flatten_dict


def get_from_raw_to_lerobot_format_fn(raw_format):
def get_from_raw_to_lerobot_format_fn(raw_format: str):
if raw_format == "pusht_zarr":
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
elif raw_format == "umi_zarr":
@@ -96,7 +97,9 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
return from_raw_to_lerobot_format


def save_meta_data(info, stats, episode_data_index, meta_data_dir):
def save_meta_data(
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
meta_data_dir.mkdir(parents=True, exist_ok=True)

# save info
@@ -114,7 +117,7 @@ def save_meta_data(info, stats, episode_data_index, meta_data_dir):
save_file(episode_data_index, ep_data_idx_path)


def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
"""Expect all meta data files to be all stored in a single "meta_data" directory.
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
"""
@@ -128,7 +131,7 @@ def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
)


def push_videos_to_hub(repo_id, videos_dir, revision):
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
"""Expect mp4 files to be all stored in a single "videos" directory.
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
"""
@@ -209,6 +212,7 @@ def push_dataset_to_hub(
save_meta_data(info, stats, episode_data_index, meta_data_dir)

if not dry_run:
# TODO(rcadene): token needs to be a str | None
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

8 changes: 6 additions & 2 deletions lerobot/scripts/train.py
@@ -24,6 +24,7 @@
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
@@ -292,6 +293,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
@@ -302,7 +304,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)

assert isinstance(policy, nn.Module)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
Expand Down Expand Up @@ -333,6 +335,7 @@ def evaluate_and_checkpoint_if_needed(step):
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
assert eval_env is not None
eval_info = eval_policy(
eval_env,
policy,
@@ -414,7 +417,8 @@ def evaluate_and_checkpoint_if_needed(step):

step += 1

eval_env.close()
if eval_env:
eval_env.close()
logging.info("End of training")


11 changes: 7 additions & 4 deletions lerobot/scripts/visualize_dataset.py
@@ -66,28 +66,31 @@
import logging
import time
from pathlib import Path
from typing import Iterator

import numpy as np
import rerun as rr
import torch
import torch.utils.data
import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)

def __iter__(self):
def __iter__(self) -> Iterator:
return iter(self.frame_ids)

def __len__(self):
def __len__(self) -> int:
return len(self.frame_ids)
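A usage sketch for `EpisodeSampler` with a `DataLoader`. The dataset below is a tiny stand-in exposing only what the sampler needs (`episode_data_index`), not the real `LeRobotDataset`:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class _FakeEpisodeDataset(Dataset):
    """Minimal stand-in: two episodes covering frames [0, 10) and [10, 25)."""

    def __init__(self) -> None:
        self.episode_data_index = {"from": torch.tensor([0, 10]), "to": torch.tensor([10, 25])}

    def __len__(self) -> int:
        return 25

    def __getitem__(self, idx: int) -> dict:
        return {"index": torch.tensor(idx)}


dataset = _FakeEpisodeDataset()
sampler = EpisodeSampler(dataset, episode_index=1)  # frames 10..24, in order
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
print([batch["index"].tolist() for batch in loader])
# [[10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21, 22, 23, 24]]
```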


def to_hwc_uint8_numpy(chw_float32_torch):
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
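The diff cuts off before the body of `to_hwc_uint8_numpy` completes. A plausible completion, under the assumption that inputs are CHW float32 images with values in [0, 1] (the real function may differ in details):

```python
import numpy as np
import torch


def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    # Scale to [0, 255] and convert channel-first (C, H, W) to channel-last
    # (H, W, C) uint8, the layout most image viewers expect.
    hwc_uint8 = (chw_float32_torch * 255).to(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8


frame = to_hwc_uint8_numpy(torch.rand(3, 96, 96))
print(frame.shape, frame.dtype)  # (96, 96, 3) uint8
```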