diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 58001ccafd67..1cd069fb2c46 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -119,7 +119,7 @@ def f(): ray.get(bar.remote(*[f() for _ in range(200)])) -def test_default_worker_import_dependency(): +def test_default_worker_import_dependency(shutdown_only): """ Test ray's python worker import doesn't import the not-allowed dependencies. """ @@ -129,6 +129,11 @@ def test_default_worker_import_dependency(): # See https://github.com/ray-project/ray/issues/33891 blocked_deps = ["numpy"] + # Ray should not be importing pydantic (used in serialization) eagerly. + # This introduces regression in worker start up time. + # https://github.com/ray-project/ray/issues/41338 + blocked_deps += ["pydantic"] + # Remove the ray module and the blocked deps from sys.modules. sys.modules.pop("ray", None) assert "ray" not in sys.modules @@ -146,6 +151,19 @@ def test_default_worker_import_dependency(): for dep in blocked_deps: assert dep not in sys.modules + # Test starting a ray workers should not see unwanted deps loaded eagerly. + ray.init() + + @ray.remote + def f(): + import ray # noqa: F401 + + assert "ray" in sys.modules + for x in blocked_deps: + assert x not in sys.modules + + ray.get(f.remote()) + # https://github.com/ray-project/ray/issues/7287 def test_omp_threads_set(ray_start_cluster, monkeypatch): diff --git a/python/ray/util/serialization_addons.py b/python/ray/util/serialization_addons.py index cdf0b50a7a22..0f7390a29b84 100644 --- a/python/ray/util/serialization_addons.py +++ b/python/ray/util/serialization_addons.py @@ -6,7 +6,6 @@ import sys from ray.util.annotations import DeveloperAPI -from ray._private.pydantic_compat import register_pydantic_serializers @DeveloperAPI @@ -27,6 +26,8 @@ def register_starlette_serializer(serialization_context): @DeveloperAPI def apply(serialization_context): + from ray._private.pydantic_compat import register_pydantic_serializers + register_pydantic_serializers(serialization_context) register_starlette_serializer(serialization_context)