diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py index 9257c6696..b32a49408 100644 --- a/tests/test_push_dataset_to_hub.py +++ b/tests/test_push_dataset_to_hub.py @@ -20,6 +20,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently from lerobot.common.datasets.video_utils import encode_video_frames from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub +from tests.utils import require_package_arg def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3): @@ -250,18 +251,19 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): ) +@patch("lerobot.scripts.push_dataset_to_hub.download_raw", _mock_download_raw) @pytest.mark.parametrize( - "raw_format, repo_id", + "required_packages, raw_format, repo_id", [ - ("pusht_zarr", "lerobot/pusht"), - ("xarm_pkl", "lerobot/xarm_lift_medium"), - ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), - ("umi_zarr", "lerobot/umi_cup_in_the_wild"), - ("dora_parquet", "cadene/wrist_gripper"), + (["gym-pusht"], "pusht_zarr", "lerobot/pusht"), + (None, "xarm_pkl", "lerobot/xarm_lift_medium"), + (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), + (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"), + (None, "dora_parquet", "cadene/wrist_gripper"), ], ) -@patch("lerobot.scripts.push_dataset_to_hub.download_raw", _mock_download_raw) -def test_push_dataset_to_hub_format(tmpdir, raw_format, repo_id): +@require_package_arg +def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id): num_episodes = 3 tmpdir = Path(tmpdir) diff --git a/tests/utils.py b/tests/utils.py index ba49ee706..c1575656c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,6 +76,7 @@ def require_env(func): """ Decorator that skips the test if the required environment package is not installed. As it need 'env_name' in args, it also checks whether it is provided as an argument. + If 'env_name' is None, this check is skipped. """ @wraps(func) @@ -91,7 +92,7 @@ def wrapper(*args, **kwargs): # Perform the package check package_name = f"gym_{env_name}" - if not is_package_available(package_name): + if env_name is not None and not is_package_available(package_name): pytest.skip(f"gym-{env_name} not installed") return func(*args, **kwargs) @@ -99,6 +100,38 @@ def wrapper(*args, **kwargs): return wrapper +def require_package_arg(func): + """ + Decorator that skips the test if the required package is not installed. + This is similar to `require_env` but more general in that it can check any package (not just environments). + As it need 'required_packages' in args, it also checks whether it is provided as an argument. + If 'required_packages' is None, this check is skipped. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # Determine if 'required_packages' is provided and extract its value + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + if "required_packages" in arg_names: + # Get the index of 'required_packages' and retrieve the value from args + index = arg_names.index("required_packages") + required_packages = args[index] if len(args) > index else kwargs.get("required_packages") + else: + raise ValueError("Function does not have 'required_packages' as an argument.") + + if required_packages is None: + return func(*args, **kwargs) + + # Perform the package check + for package in required_packages: + if not is_package_available(package): + pytest.skip(f"{package} not installed") + + return func(*args, **kwargs) + + return wrapper + + def require_package(package_name): """ Decorator that skips the test if the specified package is not installed.