[Core] GCS FT with redis sentinel (ray-project#47335)
Signed-off-by: Kan Wang <kan.wang@datadoghq.com>
Signed-off-by: Connor Sanders <connor@elastiflow.com>
kanwang authored and jecsand838 committed Dec 4, 2024
1 parent a20b193 commit 3558655
Showing 5 changed files with 344 additions and 36 deletions.
@@ -27,7 +27,7 @@ See {ref}`Ray Serve end-to-end fault tolerance documentation <serve-e2e-ft-guide

* Ray 2.0.0+
* KubeRay 0.6.0+
* Redis: single shard, one or multiple replicas
* Redis: single shard Redis Cluster or Redis Sentinel, one or multiple replicas

## Quickstart

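Not part of this diff, but for orientation: a minimal sketch of how a head node can be pointed at a Sentinel endpoint, using the same RAY_REDIS_ADDRESS environment variable the test fixtures below rely on. The address and port here are hypothetical.

```python
import os

import ray

# Hypothetical Sentinel endpoint; with this change, RAY_REDIS_ADDRESS may point at
# a Redis Sentinel rather than a plain single-shard Redis or Redis Cluster node.
os.environ["RAY_REDIS_ADDRESS"] = "127.0.0.1:26379"

# Start a local head node. The GCS reads RAY_REDIS_ADDRESS and resolves the
# current master through the Sentinel, so a Redis failover needs no config change.
ray.init()
```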
63 changes: 63 additions & 0 deletions python/ray/_private/test_utils.py
@@ -98,6 +98,12 @@ def redis_replicas():
return int(os.environ.get("TEST_EXTERNAL_REDIS_REPLICAS", "1"))


def redis_sentinel_replicas():
import os

return int(os.environ.get("TEST_EXTERNAL_REDIS_SENTINEL_REPLICAS", "2"))


def get_redis_cli(port, enable_tls):
try:
# If there is no redis libs installed, skip the check.
@@ -122,6 +128,63 @@ def get_redis_cli(port, enable_tls):
return redis.Redis("localhost", str(port), **params)


def start_redis_sentinel_instance(
session_dir_path: str,
port: int,
redis_master_port: int,
password: Optional[str] = None,
enable_tls: bool = False,
db_dir=None,
free_port=0,
):
config_file = os.path.join(
session_dir_path, "redis-sentinel-" + uuid.uuid4().hex + ".conf"
)
config_lines = []
# Port for this Sentinel instance
if enable_tls:
config_lines.append(f"port {free_port}")
else:
config_lines.append(f"port {port}")

# Monitor the Redis master
config_lines.append(f"sentinel monitor redis-test 127.0.0.1 {redis_master_port} 1")
    config_lines.append(
        "sentinel down-after-milliseconds redis-test 1000"
    )  # consider the master down after 1 second without a response
    config_lines.append(
        "sentinel failover-timeout redis-test 5000"
    )  # allow a failover up to 5 seconds to complete
config_lines.append("sentinel parallel-syncs redis-test 1")

if password:
config_lines.append(f"sentinel auth-pass redis-test {password}")

if enable_tls:
config_lines.append(f"tls-port {port}")
if Config.REDIS_CA_CERT():
config_lines.append(f"tls-ca-cert-file {Config.REDIS_CA_CERT()}")
# Check and add TLS client certificate file
if Config.REDIS_CLIENT_CERT():
config_lines.append(f"tls-cert-file {Config.REDIS_CLIENT_CERT()}")
# Check and add TLS client key file
if Config.REDIS_CLIENT_KEY():
config_lines.append(f"tls-key-file {Config.REDIS_CLIENT_KEY()}")
config_lines.append("tls-auth-clients no")
config_lines.append("sentinel tls-auth-clients redis-test no")
if db_dir:
config_lines.append(f"dir {db_dir}")

with open(config_file, "w") as f:
f.write("\n".join(config_lines))

command = [REDIS_EXECUTABLE, config_file, "--sentinel"]
process_info = ray._private.services.start_ray_process(
command,
ray_constants.PROCESS_TYPE_REDIS_SERVER,
fate_share=False,
)
return process_info


def start_redis_instance(
session_dir_path: str,
port: int,
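A rough usage sketch of the helper added above (assumptions: this change is applied, a Redis master is already listening on local port 6380, and redis-py is installed; the ports and session directory are made up for illustration):

```python
import redis

from ray._private.test_utils import start_redis_sentinel_instance

# Start a Sentinel on port 26379 that monitors the pre-existing master on port
# 6380 under the group name "redis-test", as hard-coded in the helper's config.
start_redis_sentinel_instance(
    session_dir_path="/tmp/ray/session_latest",  # hypothetical session directory
    port=26379,
    redis_master_port=6380,
)

# Ask the Sentinel which master it is tracking; after a failover the same call
# returns the promoted replica's address instead.
sentinel_cli = redis.Redis("localhost", 26379)
master = sentinel_cli.sentinel_master("redis-test")
print(master["ip"], master["port"])
```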
54 changes: 52 additions & 2 deletions python/ray/tests/conftest.py
@@ -1,6 +1,7 @@
"""
This file defines the common pytest fixtures used in the current directory.
"""

import json
import logging
import os
@@ -34,6 +35,8 @@
redis_replicas,
get_redis_cli,
start_redis_instance,
start_redis_sentinel_instance,
redis_sentinel_replicas,
find_available_port,
wait_for_condition,
find_free_port,
@@ -201,6 +204,34 @@ def redis_alive(port, enable_tls):
return False


def start_redis_with_sentinel(db_dir):
temp_dir = ray._private.utils.get_ray_temp_dir()

redis_ports = find_available_port(49159, 55535, redis_sentinel_replicas() + 1)
sentinel_port = redis_ports[0]
master_port = redis_ports[1]
redis_processes = [
start_redis_instance(temp_dir, p, listen_to_localhost_only=True, db_dir=db_dir)[
1
]
for p in redis_ports[1:]
]

# ensure all redis servers are up
for port in redis_ports[1:]:
wait_for_condition(redis_alive, 3, 100, port=port, enable_tls=False)

# setup replicas of the master
for port in redis_ports[2:]:
redis_cli = get_redis_cli(port, False)
redis_cli.replicaof("127.0.0.1", master_port)
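
    # start a sentinel that monitors the master configured above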
sentinel_process = start_redis_sentinel_instance(
temp_dir, sentinel_port, master_port
)
address_str = f"127.0.0.1:{sentinel_port}"
return address_str, redis_processes + [sentinel_process]


def start_redis(db_dir):
retry_num = 0
while True:
@@ -289,10 +320,14 @@ def kill_all_redis_server():


@contextmanager
def _setup_redis(request):
def _setup_redis(request, with_sentinel=False):
with tempfile.TemporaryDirectory() as tmpdirname:
kill_all_redis_server()
address_str, processes = start_redis(tmpdirname)
address_str, processes = (
start_redis_with_sentinel(tmpdirname)
if with_sentinel
else start_redis(tmpdirname)
)
old_addr = os.environ.get("RAY_REDIS_ADDRESS")
os.environ["RAY_REDIS_ADDRESS"] = address_str
import uuid
@@ -332,6 +367,12 @@ def external_redis(request):
yield


@pytest.fixture
def external_redis_with_sentinel(request):
with _setup_redis(request, True):
yield


@pytest.fixture
def shutdown_only(maybe_external_redis):
yield None
@@ -535,6 +576,15 @@ def ray_start_cluster_head_with_external_redis(request, external_redis):
yield res


@pytest.fixture
def ray_start_cluster_head_with_external_redis_sentinel(
request, external_redis_with_sentinel
):
param = getattr(request, "param", {})
with _ray_start_cluster(do_init=True, num_nodes=1, **param) as res:
yield res


@pytest.fixture
def ray_start_cluster_head_with_env_vars(request, maybe_external_redis, monkeypatch):
param = getattr(request, "param", {})
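With these fixtures, RAY_REDIS_ADDRESS carries the Sentinel's address rather than the master's. A minimal sketch of how a test could opt into the new fixture (the test name and the empty parametrization are made up; the real test added below also passes a system config map):

```python
import os

import pytest


@pytest.mark.parametrize(
    "ray_start_cluster_head_with_external_redis_sentinel",
    [{}],
    indirect=True,
)
def test_gcs_points_at_sentinel(ray_start_cluster_head_with_external_redis_sentinel):
    # The fixture chain above set RAY_REDIS_ADDRESS to the Sentinel's endpoint.
    assert os.environ["RAY_REDIS_ADDRESS"].startswith("127.0.0.1:")
```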
115 changes: 115 additions & 0 deletions python/ray/tests/test_gcs_fault_tolerance.py
@@ -20,6 +20,7 @@
wait_for_condition,
wait_for_pid_to_exit,
run_string_as_driver,
redis_sentinel_replicas,
)
from ray.job_submission import JobSubmissionClient, JobStatus
from ray._raylet import GcsClient
@@ -871,6 +872,120 @@ def f():
wait_for_pid_to_exit(gcs_server_pid, 10000)


@pytest.mark.parametrize(
"ray_start_cluster_head_with_external_redis_sentinel",
[
generate_system_config_map(
gcs_rpc_server_reconnect_timeout_s=60,
gcs_server_request_timeout_seconds=10,
redis_db_connect_retries=50,
)
],
indirect=True,
)
def test_redis_with_sentinel_failover(
ray_start_cluster_head_with_external_redis_sentinel,
):
"""This test is to cover ray cluster's behavior with Redis sentinel.
The expectation is Redis sentinel should manage failover
automatically, and GCS can continue talking to the same address
without any human intervention on Redis.
For this test we ensure:
- When Redis master failed, Ray should crash (TODO: GCS should
autommatically try re-connect to sentinel).
- When restart Ray, it should continue talking to sentinel, which
should return information about new master.
"""
cluster = ray_start_cluster_head_with_external_redis_sentinel
import redis

redis_addr = os.environ.get("RAY_REDIS_ADDRESS")
ip, port = redis_addr.split(":")
redis_cli = redis.Redis(ip, port)
print(redis_cli.info("sentinel"))
redis_name = redis_cli.info("sentinel")["master0"]["name"]

def get_sentinel_nodes():
leader_address = (
redis_cli.sentinel_master(redis_name)["ip"],
redis_cli.sentinel_master(redis_name)["port"],
)
follower_addresses = [
(x["ip"], x["port"]) for x in redis_cli.sentinel_slaves(redis_name)
]
return [leader_address] + follower_addresses

wait_for_condition(lambda: len(get_sentinel_nodes()) == redis_sentinel_replicas())

@ray.remote(max_restarts=-1)
class Counter:
def r(self, v):
return v

def pid(self):
import os

return os.getpid()

c = Counter.options(name="c", namespace="test", lifetime="detached").remote()
c_pid = ray.get(c.pid.remote())
c_process = psutil.Process(pid=c_pid)
r = ray.get(c.r.remote(10))
assert r == 10

head_node = cluster.head_node
gcs_server_process = head_node.all_processes["gcs_server"][0].process
gcs_server_pid = gcs_server_process.pid

leader_cli = redis.Redis(*get_sentinel_nodes()[0])
leader_pid = leader_cli.info()["process_id"]
follower_cli = [redis.Redis(*x) for x in get_sentinel_nodes()[1:]]

# Wait until all data is updated in the replica
leader_cli.set("_hole", "0")
wait_for_condition(lambda: all([b"_hole" in f.keys("*") for f in follower_cli]))
current_leader = get_sentinel_nodes()[0]

# Now kill pid
leader_process = psutil.Process(pid=leader_pid)
leader_process.kill()

print(">>> Waiting gcs server to exit", gcs_server_pid)
wait_for_pid_to_exit(gcs_server_pid, 1000)
print("GCS killed")

wait_for_condition(lambda: current_leader != get_sentinel_nodes()[0])

# Kill Counter actor. It should restart after GCS is back
c_process.kill()
# Cleanup the in memory data and then start gcs
cluster.head_node.kill_gcs_server(False)

print("Start gcs")
sleep(2)
cluster.head_node.start_gcs_server()

assert len(ray.nodes()) == 1
assert ray.nodes()[0]["alive"]

driver_script = f"""
import ray
ray.init('{cluster.address}')
@ray.remote
def f():
return 10
assert ray.get(f.remote()) == 10
c = ray.get_actor("c", namespace="test")
v = ray.get(c.r.remote(10))
assert v == 10
print("DONE")
"""

# Make sure the cluster is usable
wait_for_condition(lambda: "DONE" in run_string_as_driver(driver_script))


@pytest.mark.parametrize(
"ray_start_regular",
[