Bug fix: fix setting different learning rates between backbone and main model in ACT policy (#280)
thomwolf authored Jun 18, 2024
1 parent b72d574 commit 11f1cb5
Showing 2 changed files with 32 additions and 2 deletions.
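Root cause of the bug: the ACT policy keeps its network under a `model` attribute, so `policy.named_parameters()` yields names such as "model.backbone.…" rather than "backbone.…". The old filter `n.startswith("backbone")` therefore matched nothing, every parameter fell into the first (default-learning-rate) group, and `training.lr_backbone` was silently ignored. Below is a minimal sketch of how that prefix arises; the Tiny* classes are hypothetical stand-ins, not the real ACT modules.

from torch import nn


class TinyBackbone(nn.Module):  # hypothetical stand-in for the vision backbone
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)


class TinyModel(nn.Module):  # hypothetical stand-in for the ACT model
    def __init__(self):
        super().__init__()
        self.backbone = TinyBackbone()
        self.head = nn.Linear(8, 2)


class TinyPolicy(nn.Module):  # hypothetical stand-in for the policy wrapper
    def __init__(self):
        super().__init__()
        self.model = TinyModel()


policy = TinyPolicy()
print([n for n, _ in policy.named_parameters()])
# ['model.backbone.conv.weight', 'model.backbone.conv.bias',
#  'model.head.weight', 'model.head.bias']
# No name starts with "backbone", so the old `n.startswith("backbone")` filter
# selected nothing and the backbone parameter group stayed empty.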
6 changes: 4 additions & 2 deletions lerobot/scripts/train.py
@@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy):
                 "params": [
                     p
                     for n, p in policy.named_parameters()
-                    if not n.startswith("backbone") and p.requires_grad
+                    if not n.startswith("model.backbone") and p.requires_grad
                 ]
             },
             {
                 "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                    p
+                    for n, p in policy.named_parameters()
+                    if n.startswith("model.backbone") and p.requires_grad
                 ],
                 "lr": cfg.training.lr_backbone,
             },
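For readers less familiar with PyTorch parameter groups, here is a hedged, self-contained sketch of how the two groups above end up with different learning rates. The stand-in `policy`, the literal learning rates, and the choice of AdamW are illustrative assumptions only; the real values come from `cfg.training`, and the actual optimizer class is whatever make_optimizer_and_scheduler constructs.

import torch
from torch import nn

# Tiny stand-in policy (hypothetical, not the real ACT policy) whose parameter
# names carry the "model.backbone" prefix, mirroring the filter in the hunk above.
policy = nn.Module()
policy.model = nn.Module()
policy.model.backbone = nn.Linear(4, 4)
policy.model.head = nn.Linear(4, 2)

lr = 0.01            # stands in for cfg.training.lr
lr_backbone = 0.001  # stands in for cfg.training.lr_backbone

optimizer_params_dicts = [
    {  # group 0: everything except the backbone; inherits the optimizer's default lr
        "params": [
            p
            for n, p in policy.named_parameters()
            if not n.startswith("model.backbone") and p.requires_grad
        ]
    },
    {  # group 1: backbone parameters, with their own learning rate
        "params": [
            p
            for n, p in policy.named_parameters()
            if n.startswith("model.backbone") and p.requires_grad
        ],
        "lr": lr_backbone,
    },
]
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=lr)

# A group that omits "lr" inherits the default passed to the optimizer, so:
print(optimizer.param_groups[0]["lr"])  # 0.01  (default lr)
print(optimizer.param_groups[1]["lr"])  # 0.001 (backbone lr)
# This is exactly what the new test below asserts.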
28 changes: 28 additions & 0 deletions tests/test_policies.py
@@ -30,6 +30,7 @@
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel

@@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides):
         env.step(action)
 
 
+def test_act_backbone_lr():
+    """
+    Test that the ACT policy can be instantiated with a different learning rate for the backbone.
+    """
+    cfg = init_hydra_config(
+        DEFAULT_CONFIG_PATH,
+        overrides=[
+            "env=aloha",
+            "policy=act",
+            f"device={DEVICE}",
+            "training.lr_backbone=0.001",
+            "training.lr=0.01",
+        ],
+    )
+    assert cfg.training.lr == 0.01
+    assert cfg.training.lr_backbone == 0.001
+
+    dataset = make_dataset(cfg)
+    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
+    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+    assert len(optimizer.param_groups) == 2
+    assert optimizer.param_groups[0]["lr"] == cfg.training.lr
+    assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
+    assert len(optimizer.param_groups[0]["params"]) == 133
+    assert len(optimizer.param_groups[1]["params"]) == 20
+
+
 @pytest.mark.parametrize("policy_name", available_policies)
 def test_policy_defaults(policy_name: str):
     """Check that the policy can be instantiated with defaults."""
