def redis_alive(port, enable_tls):
    """Return whether the Redis server on ``localhost:port`` answers a PING.

    Args:
        port: TCP port of the Redis server on ``localhost``.
        enable_tls: When true, connect over TLS, adding whichever of the CA
            certificate, client certificate and client key values
            ``ray._raylet.Config`` reports as non-empty.

    Returns:
        ``True`` when the ``redis`` package cannot be imported (the check is
        skipped), otherwise the result of ``PING``; ``False`` when the ping
        raises for any reason.
    """
    try:
        # Minimal test environments do not install the redis client
        # library; skip the liveness check there instead of failing.
        import redis
    except ImportError:
        return True

    params = {}
    if enable_tls:
        from ray._raylet import Config

        params = {"ssl": True, "ssl_cert_reqs": "required"}
        if Config.REDIS_CA_CERT():
            params["ssl_ca_certs"] = Config.REDIS_CA_CERT()
        if Config.REDIS_CLIENT_CERT():
            params["ssl_certfile"] = Config.REDIS_CLIENT_CERT()
        if Config.REDIS_CLIENT_KEY():
            params["ssl_keyfile"] = Config.REDIS_CLIENT_KEY()

    cli = redis.Redis("localhost", port, **params)
    try:
        return cli.ping()
    except Exception:
        # This function is used as a polling predicate, so any failure
        # (refused connection, handshake error, ...) is reported as
        # "not alive" rather than propagated.
        return False
    finally:
        # A new client is built on every poll; drop its pooled sockets so
        # repeated polling does not accumulate open connections.
        cli.connection_pool.disconnect()
def kill_all_redis_server():
    """Kill every process that psutil reports as a ``redis-server``.

    A process is selected when its name equals ``"redis-server"`` and
    ``"redis-server"`` is one of the entries of its command line. Processes
    that have already exited by the time they are killed are ignored.
    """
    import psutil

    # Find Redis server processes. psutil fills an attribute with None when
    # it cannot be read, so treat a missing command line as empty instead of
    # failing on the membership test.
    redis_procs = [
        proc
        for proc in psutil.process_iter(["name", "cmdline"])
        if proc.info["name"] == "redis-server"
        and "redis-server" in (proc.info["cmdline"] or [])
    ]

    # Kill Redis server processes.
    for proc in redis_procs:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            # The process exited between enumeration and kill; it is
            # already gone, which is the desired end state.
            continue