Skip to content

Commit

Permalink
updating tests
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Jun 19, 2024
1 parent a576523 commit 722b765
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 96 deletions.
83 changes: 0 additions & 83 deletions lerobot/configs/policy/act_1000_actions.yaml

This file was deleted.

12 changes: 7 additions & 5 deletions tests/scripts/save_policy_to_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
return output_dict, grad_stats, param_stats, actions


def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"

if env_policy_dir.exists():
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
Expand All @@ -114,9 +114,11 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
("aloha", "act_1000_actions", []),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
)
19 changes: 11 additions & 8 deletions tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,25 +287,26 @@ def test_normalize(insert_temporal_dim):


@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides",
"env_name, policy_name, extra_overrides, file_name_extra",
[
("xarm", "tdmpc", []),
("xarm", "tdmpc", [], ""),
(
"pusht",
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
"",
),
("aloha", "act", ["policy.n_action_steps=10"]),
("aloha", "act_1000_actions", []),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(env_name, policy_name, extra_overrides):
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
Expand All @@ -317,7 +318,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
env_policy_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
)
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
Expand Down

0 comments on commit 722b765

Please sign in to comment.