From 6d4497202bd207e8d7733580100187c8aa9c1c9d Mon Sep 17 00:00:00 2001 From: hjiang Date: Tue, 10 Dec 2024 00:21:52 +0000 Subject: [PATCH] allow uv cache Signed-off-by: hjiang --- python/ray/_private/runtime_env/uv.py | 8 ++-- python/ray/_private/runtime_env/validation.py | 10 ++++- python/ray/tests/test_runtime_env_uv.py | 41 +++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/python/ray/_private/runtime_env/uv.py b/python/ray/_private/runtime_env/uv.py index 78d4ad2c55f4..7820add75af8 100644 --- a/python/ray/_private/runtime_env/uv.py +++ b/python/ray/_private/runtime_env/uv.py @@ -166,19 +166,19 @@ async def _install_uv_packages( # # Difference with pip: # 1. `--disable-pip-version-check` has no effect for uv. - # 2. `--no-cache-dir` for `pip` maps to `--no-cache` for uv. - pip_install_cmd = [ + # 2. Provide option to enable caching for package installation. + uv_install_cmd = [ python, "-m", "uv", "pip", "install", - "--no-cache", + "" if self._uv_config.get("enable_uv_cache", False) else "--no-cache", "-r", requirements_file, ] logger.info("Installing python requirements to %s", virtualenv_path) - await check_output_cmd(pip_install_cmd, logger=logger, cwd=cwd, env=pip_env) + await check_output_cmd(uv_install_cmd, logger=logger, cwd=cwd, env=pip_env) # Check python environment for conflicts. if self._uv_config.get("uv_check", False): diff --git a/python/ray/_private/runtime_env/validation.py b/python/ray/_private/runtime_env/validation.py index a2f37aa0816e..1c53c31445ce 100644 --- a/python/ray/_private/runtime_env/validation.py +++ b/python/ray/_private/runtime_env/validation.py @@ -146,10 +146,10 @@ def parse_and_validate_uv(uv: Union[str, List[str], Dict]) -> Optional[Dict]: elif isinstance(uv, list) and all(isinstance(dep, str) for dep in uv): result = dict(packages=uv, uv_check=False) elif isinstance(uv, dict): - if set(uv.keys()) - {"packages", "uv_check", "uv_version"}: + if set(uv.keys()) - {"packages", "uv_check", "uv_version", "enable_uv_cache"}: raise ValueError( "runtime_env['uv'] can only have these fields: " - "packages, uv_check and uv_version, but got: " + "packages, uv_check, uv_version and enable_uv_cache, but got: " f"{list(uv.keys())}" ) if "packages" not in uv: @@ -166,9 +166,15 @@ def parse_and_validate_uv(uv: Union[str, List[str], Dict]) -> Optional[Dict]: "runtime_env['uv']['uv_version'] must be of type str, " f"got {type(uv['uv_version'])}" ) + if "enable_uv_cache" in uv and not isinstance(uv["enable_uv_cache"], bool): + raise TypeError( + "runtime_env['uv']['enable_uv_cache'] must be of type bool, " + f"got {type(uv['enable_uv_cache'])}" + ) result = uv.copy() result["uv_check"] = uv.get("uv_check", False) + result["enable_uv_cache"] = uv.get("enable_uv_cache", False) if not isinstance(uv["packages"], list): raise ValueError( "runtime_env['uv']['packages'] must be of type list, " diff --git a/python/ray/tests/test_runtime_env_uv.py b/python/ray/tests/test_runtime_env_uv.py index b698cd41ba3a..8fc97626130c 100644 --- a/python/ray/tests/test_runtime_env_uv.py +++ b/python/ray/tests/test_runtime_env_uv.py @@ -108,6 +108,47 @@ def f(): assert ray.get(f.remote()) == "2.3.0" +# Install different versions of the same package across different tasks, used to check +# uv cache doesn't break runtime env requirement. +def test_package_install_with_different_versions(shutdown_only): + @ray.remote(runtime_env={"uv": {"packages": ["requests==2.3.0"]}}) + def f(): + import requests + + assert requests.__version__ == "2.3.0" + + @ray.remote(runtime_env={"uv": {"packages": ["requests==2.2.0"]}}) + def g(): + import requests + + assert requests.__version__ == "2.2.0" + + ray.get(f.remote()) + ray.get(g.remote()) + + +# Install packages with cache enabled. +def test_package_install_with_cache_enabled(shutdown_only): + @ray.remote( + runtime_env={"uv": {"packages": ["requests==2.3.0"], "enable_uv_cache": True}} + ) + def f(): + import requests + + assert requests.__version__ == "2.3.0" + + @ray.remote( + runtime_env={"uv": {"packages": ["requests==2.2.0"], "enable_uv_cache": True}} + ) + def g(): + import requests + + assert requests.__version__ == "2.2.0" + + ray.get(f.remote()) + ray.get(g.remote()) + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))