Skip to content

Commit

Permalink
Fix minimal install test (add require_package_arg)
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Jun 12, 2024
1 parent 71e0689 commit c6bd330
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
18 changes: 10 additions & 8 deletions tests/test_push_dataset_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg


def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
Expand Down Expand Up @@ -250,18 +251,19 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
)


@patch("lerobot.scripts.push_dataset_to_hub.download_raw", _mock_download_raw)
@pytest.mark.parametrize(
"raw_format, repo_id",
"required_packages, raw_format, repo_id",
[
("pusht_zarr", "lerobot/pusht"),
("xarm_pkl", "lerobot/xarm_lift_medium"),
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
("dora_parquet", "cadene/wrist_gripper"),
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
(None, "dora_parquet", "cadene/wrist_gripper"),
],
)
@patch("lerobot.scripts.push_dataset_to_hub.download_raw", _mock_download_raw)
def test_push_dataset_to_hub_format(tmpdir, raw_format, repo_id):
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
num_episodes = 3
tmpdir = Path(tmpdir)

Expand Down
35 changes: 34 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def require_env(func):
"""
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
If 'env_name' is None, this check is skipped.
"""

@wraps(func)
Expand All @@ -91,14 +92,46 @@ def wrapper(*args, **kwargs):

# Perform the package check
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
if env_name is not None and not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")

return func(*args, **kwargs)

return wrapper


def require_package_arg(func):
"""
Decorator that skips the test if the required package is not installed.
This is similar to `require_env` but more general in that it can check any package (not just environments).
As it need 'required_packages' in args, it also checks whether it is provided as an argument.
If 'required_packages' is None, this check is skipped.
"""

@wraps(func)
def wrapper(*args, **kwargs):
# Determine if 'required_packages' is provided and extract its value
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if "required_packages" in arg_names:
# Get the index of 'required_packages' and retrieve the value from args
index = arg_names.index("required_packages")
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
else:
raise ValueError("Function does not have 'required_packages' as an argument.")

if required_packages is None:
return func(*args, **kwargs)

# Perform the package check
for package in required_packages:
if not is_package_available(package):
pytest.skip(f"{package} not installed")

return func(*args, **kwargs)

return wrapper


def require_package(package_name):
"""
Decorator that skips the test if the specified package is not installed.
Expand Down

0 comments on commit c6bd330

Please sign in to comment.