diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 01b2ef4f4..693ff40c9 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -53,12 +53,14 @@ def make_optimizer_and_scheduler(cfg, policy): "params": [ p for n, p in policy.named_parameters() - if not n.startswith("backbone") and p.requires_grad + if not n.startswith("model.backbone") and p.requires_grad ] }, { "params": [ - p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad + p + for n, p in policy.named_parameters() + if n.startswith("model.backbone") and p.requires_grad ], "lr": cfg.training.lr_backbone, }, diff --git a/tests/test_policies.py b/tests/test_policies.py index c099bef00..95da20c9f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -30,6 +30,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.policy_protocol import Policy from lerobot.common.utils.utils import init_hydra_config +from lerobot.scripts.train import make_optimizer_and_scheduler from tests.scripts.save_policy_to_safetensors import get_policy_stats from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -174,6 +175,33 @@ def test_policy(env_name, policy_name, extra_overrides): env.step(action) +def test_act_backbone_lr(): + """ + Test that the ACT policy can be instantiated with a different learning rate for the backbone. + """ + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[ + "env=aloha", + "policy=act", + f"device={DEVICE}", + "training.lr_backbone=0.001", + "training.lr=0.01", + ], + ) + assert cfg.training.lr == 0.01 + assert cfg.training.lr_backbone == 0.001 + + dataset = make_dataset(cfg) + policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats) + optimizer, _ = make_optimizer_and_scheduler(cfg, policy) + assert len(optimizer.param_groups) == 2 + assert optimizer.param_groups[0]["lr"] == cfg.training.lr + assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone + assert len(optimizer.param_groups[0]["params"]) == 133 + assert len(optimizer.param_groups[1]["params"]) == 20 + + @pytest.mark.parametrize("policy_name", available_policies) def test_policy_defaults(policy_name: str): """Check that the policy can be instantiated with defaults."""