Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05…
Browse files Browse the repository at this point in the history
…_28_compile_torchvision
  • Loading branch information
aliberts committed Jun 19, 2024
2 parents bf3dbbd + 56199fb commit f29d24a
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
55 changes: 54 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,73 @@ wandb login

Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.

You can also locally visualize episodes from a dataset by executing our script from the command line:
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```

or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```


It will open `rerun.io` and display the camera streams, robot states and actions, like this:

https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144


Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.

### The `LeRobotDataset` format

A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.

A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.

Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.

Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:

```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high (VideoFrame):
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
│ ├ observation.state (list of float32): position of an arm joints (for instance)
│ ... (more observations)
│ ├ action (list of float32): goal position of an arm joints (for instance)
│ ├ episode_index (int64): index of the episode for this sample
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
│ ├ timestamp (float32): timestamp in the episode
│ ├ next.done (bool): indicates the end of en episode ; True for the last frame in each episode
│ └ index (int64): general index in the whole dataset
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0
│ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ ...
├ info: a dictionary of metadata on the dataset
│ ├ fps (float): frame per second the dataset is recorded/synchronized to
│ └ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
└ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```

A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space or png files
- episode_data_index saved using `safetensor` tensor serialization format
- stats saved using `safetensor` tensor serialization format
- info are saved using JSON

Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.

### Evaluate a pretrained policy

Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
Expand Down
6 changes: 4 additions & 2 deletions lerobot/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy):
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("backbone") and p.requires_grad
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
p
for n, p in policy.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
Expand Down
28 changes: 28 additions & 0 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel

Expand Down Expand Up @@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)


def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
"env=aloha",
"policy=act",
f"device={DEVICE}",
"training.lr_backbone=0.001",
"training.lr=0.01",
],
)
assert cfg.training.lr == 0.01
assert cfg.training.lr_backbone == 0.001

dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
assert len(optimizer.param_groups[0]["params"]) == 133
assert len(optimizer.param_groups[1]["params"]) == 20


@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str):
"""Check that the policy can be instantiated with defaults."""
Expand Down

0 comments on commit f29d24a

Please sign in to comment.