[Core] Fix check failure: sync_reactors_.find(reactor->GetRemoteNodeID()) == sync_reactors_.end() (#47861)

Signed-off-by: Jiajun Yao <jeromeyjj@gmail.com>
jjyao authored Oct 4, 2024
1 parent b86d85f commit 2532cca
Showing 5 changed files with 152 additions and 46 deletions.
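
In short: when a worker hits a transient network error and reconnects, StartSync can run while the previous reactor for that node is still registered (its OnDone has not fired yet), so the RAY_CHECK in RaySyncer::Connect that the node ID is not already in sync_reactors_ would fail. The commit now disconnects any existing reactor before registering the new one, and passes the reactor pointer (rather than just the node ID) into the cleanup callback so the old reactor's late OnDone can tell that a newer reactor already owns the slot and must not erase it. The C++ sketch below is a minimal, self-contained illustration of that second idea only; SyncRegistry and FakeReactor are invented names for this example, not Ray classes.

// Minimal sketch (not Ray code): cleanup keyed on the reactor pointer survives a
// reconnect race, whereas erasing by node ID alone would drop the new connection.
#include <cassert>
#include <string>
#include <unordered_map>

struct FakeReactor {
  std::string remote_node_id;
  const std::string &GetRemoteNodeID() const { return remote_node_id; }
};

class SyncRegistry {
 public:
  void Connect(FakeReactor *reactor) {
    // Mirrors the RAY_CHECK in RaySyncer::Connect: the slot must be free.
    assert(reactors_.find(reactor->GetRemoteNodeID()) == reactors_.end());
    reactors_[reactor->GetRemoteNodeID()] = reactor;
  }

  void Disconnect(const std::string &node_id) { reactors_.erase(node_id); }

  // The cleanup now receives the reactor, not just its node ID.
  void Cleanup(FakeReactor *reactor) {
    const std::string &node_id = reactor->GetRemoteNodeID();
    auto it = reactors_.find(node_id);
    if (it != reactors_.end() && it->second != reactor) {
      // A newer reactor for this node already reconnected; leave it registered.
      return;
    }
    reactors_.erase(node_id);
  }

  bool Has(const std::string &node_id) const { return reactors_.count(node_id) > 0; }

 private:
  std::unordered_map<std::string, FakeReactor *> reactors_;
};

int main() {
  SyncRegistry registry;
  FakeReactor old_reactor{"node-1"};
  FakeReactor new_reactor{"node-1"};

  registry.Connect(&old_reactor);
  // Transient network error: the client reconnects before the old reactor's
  // OnDone has run (steps 1-3 in the cleanup_cb comment in ray_syncer.cc).
  registry.Disconnect(old_reactor.GetRemoteNodeID());
  registry.Connect(&new_reactor);
  // Step 4: the old reactor's OnDone finally fires. Keyed on the pointer, the
  // cleanup notices a different reactor now owns "node-1" and does nothing.
  registry.Cleanup(&old_reactor);
  assert(registry.Has("node-1"));
  return 0;
}

The actual change threads this same check through the cleanup callbacks in RaySyncer::Connect (client side) and RaySyncerService::StartSync (server side) in the ray_syncer.cc diff below, and the new test_transient_network_error exercises it end to end by force-disconnecting and reconnecting the worker container.
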
20 changes: 16 additions & 4 deletions python/ray/tests/conftest_docker.py
@@ -3,6 +3,7 @@
from pytest_docker_tools import container, fetch, network, volume
from pytest_docker_tools import wrappers
import subprocess
import docker
from typing import List

# If you need to debug tests using fixtures in this file,
@@ -65,7 +66,13 @@ def print_logs(self):
print(content.decode())


gcs_network = network(driver="bridge")
# This allows us to assign static ips to docker containers
ipam_config = docker.types.IPAMConfig(
pool_configs=[
docker.types.IPAMPool(subnet="192.168.52.0/24", gateway="192.168.52.254")
]
)
gcs_network = network(driver="bridge", ipam=ipam_config)

redis_image = fetch(repository="redis:latest")

@@ -96,6 +103,8 @@ def gen_head_node(envs):
# ip:port is treated as a different raylet.
"--node-manager-port",
"9379",
"--dashboard-host",
"0.0.0.0",
],
volumes={"{head_node_vol.name}": {"bind": "/tmp", "mode": "rw"}},
environment=envs,
@@ -109,7 +118,7 @@
)


def gen_worker_node(envs):
def gen_worker_node(envs, num_cpus):
return container(
image="rayproject/ray:ha_integration",
network="{gcs_network.name}",
@@ -123,6 +132,8 @@ def gen_worker_node(envs):
# ip:port is treated as a different raylet.
"--node-manager-port",
"9379",
"--num-cpus",
f"{num_cpus}",
],
volumes={"{worker_node_vol.name}": {"bind": "/tmp", "mode": "rw"}},
environment=envs,
@@ -145,11 +156,12 @@ def gen_worker_node(envs):
)

worker_node = gen_worker_node(
{
envs={
"RAY_REDIS_ADDRESS": "{redis.ips.primary}:6379",
"RAY_raylet_client_num_connect_attempts": "10",
"RAY_raylet_client_connect_timeout_milliseconds": "100",
}
},
num_cpus=8,
)


76 changes: 74 additions & 2 deletions python/ray/tests/test_network_failure_e2e.py
@@ -31,15 +31,16 @@ def f():
)

worker = gen_worker_node(
{
envs={
"RAY_grpc_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_timeout_ms": "1000",
"RAY_health_check_initial_delay_ms": "1000",
"RAY_health_check_period_ms": "1000",
"RAY_health_check_timeout_ms": "1000",
"RAY_health_check_failure_threshold": "2",
}
},
num_cpus=8,
)


@@ -124,6 +125,77 @@ def check_task_pending(n=0):
wait_for_condition(lambda: check_task_pending(2))


head2 = gen_head_node(
{
"RAY_grpc_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_timeout_ms": "1000",
"RAY_health_check_initial_delay_ms": "1000",
"RAY_health_check_period_ms": "1000",
"RAY_health_check_timeout_ms": "100000",
"RAY_health_check_failure_threshold": "20",
}
)

worker2 = gen_worker_node(
envs={
"RAY_grpc_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_time_ms": "1000",
"RAY_grpc_client_keepalive_timeout_ms": "1000",
"RAY_health_check_initial_delay_ms": "1000",
"RAY_health_check_period_ms": "1000",
"RAY_health_check_timeout_ms": "100000",
"RAY_health_check_failure_threshold": "20",
},
num_cpus=2,
)


def test_transient_network_error(head2, worker2, gcs_network):
# Test to make sure the head node and worker node
# connection can be recovered from transient network error.
network = gcs_network

check_two_nodes = """
import sys
import ray
from ray._private.test_utils import wait_for_condition
ray.init()
wait_for_condition(lambda: len(ray.nodes()) == 2)
"""
result = head2.exec_run(cmd=f"python -c '{check_two_nodes}'")
assert result.exit_code == 0, result.output.decode("utf-8")

# Simulate transient network error
worker_ip = worker2._container.attrs["NetworkSettings"]["Networks"][network.name][
"IPAddress"
]
network.disconnect(worker2.name, force=True)
sleep(2)
network.connect(worker2.name, ipv4_address=worker_ip)

# Make sure the connection is recovered by scheduling
# an actor.
check_actor_scheduling = """
import ray
from ray._private.test_utils import wait_for_condition
ray.init()
@ray.remote(num_cpus=1)
class Actor:
def ping(self):
return 1
actor = Actor.remote()
ray.get(actor.ping.remote())
wait_for_condition(lambda: ray.available_resources()["CPU"] == 1.0)
"""
result = head2.exec_run(cmd=f"python -c '{check_actor_scheduling}'")
assert result.exit_code == 0, result.output.decode("utf-8")


if __name__ == "__main__":
import os

8 changes: 4 additions & 4 deletions src/ray/common/ray_syncer/ray_syncer-inl.h
@@ -369,7 +369,7 @@ class RayServerBidiReactor : public RaySyncerBidiReactorBase<ServerBidiReactor>
instrumented_io_context &io_context,
const std::string &local_node_id,
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
std::function<void(const std::string &, bool)> cleanup_cb);
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb);

~RayServerBidiReactor() override = default;

@@ -379,7 +379,7 @@
void OnDone() override;

/// Cleanup callback when the call ends.
const std::function<void(const std::string &, bool)> cleanup_cb_;
const std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb_;

/// grpc callback context
grpc::CallbackServerContext *server_context_;
@@ -395,7 +395,7 @@ class RayClientBidiReactor : public RaySyncerBidiReactorBase<ClientBidiReactor>
const std::string &local_node_id,
instrumented_io_context &io_context,
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
std::function<void(const std::string &, bool)> cleanup_cb,
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb,
std::unique_ptr<ray::rpc::syncer::RaySyncer::Stub> stub);

~RayClientBidiReactor() override = default;
@@ -406,7 +406,7 @@
void OnDone(const grpc::Status &status) override;

/// Cleanup callback when the call ends.
const std::function<void(const std::string &, bool)> cleanup_cb_;
const std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb_;

/// grpc callback context
grpc::ClientContext client_context_;
41 changes: 31 additions & 10 deletions src/ray/common/ray_syncer/ray_syncer.cc
@@ -96,7 +96,7 @@ RayServerBidiReactor::RayServerBidiReactor(
instrumented_io_context &io_context,
const std::string &local_node_id,
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
std::function<void(const std::string &, bool)> cleanup_cb)
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb)
: RaySyncerBidiReactorBase<ServerBidiReactor>(
io_context,
GetNodeIDFromServerContext(server_context),
@@ -122,7 +122,7 @@ void RayServerBidiReactor::OnCancel() {
void RayServerBidiReactor::OnDone() {
io_context_.dispatch(
[this, cleanup_cb = cleanup_cb_, remote_node_id = GetRemoteNodeID()]() {
cleanup_cb(remote_node_id, false);
cleanup_cb(this, false);
delete this;
},
"");
@@ -133,7 +133,7 @@ RayClientBidiReactor::RayClientBidiReactor(
const std::string &local_node_id,
instrumented_io_context &io_context,
std::function<void(std::shared_ptr<const RaySyncMessage>)> message_processor,
std::function<void(const std::string &, bool)> cleanup_cb,
std::function<void(RaySyncerBidiReactor *, bool)> cleanup_cb,
std::unique_ptr<ray::rpc::syncer::RaySyncer::Stub> stub)
: RaySyncerBidiReactorBase<ClientBidiReactor>(
io_context, remote_node_id, std::move(message_processor)),
@@ -151,7 +151,7 @@ RayClientBidiReactor::RayClientBidiReactor(
void RayClientBidiReactor::OnDone(const grpc::Status &status) {
io_context_.dispatch(
[this, status]() {
cleanup_cb_(GetRemoteNodeID(), !status.ok());
cleanup_cb_(this, !status.ok());
delete this;
},
"");
@@ -221,7 +221,7 @@ void RaySyncer::Connect(const std::string &node_id,
/* io_context */ io_context_,
/* message_processor */ [this](auto msg) { BroadcastRaySyncMessage(msg); },
/* cleanup_cb */
[this, channel](const std::string &node_id, bool restart) {
[this, channel](RaySyncerBidiReactor *reactor, bool restart) {
const std::string &node_id = reactor->GetRemoteNodeID();
if (sync_reactors_.contains(node_id) &&
sync_reactors_.at(node_id) != reactor) {
// The client is already reconnected.
return;
}
sync_reactors_.erase(node_id);
if (restart) {
execute_after(
@@ -247,7 +253,7 @@ void RaySyncer::Connect(RaySyncerBidiReactor *reactor) {
boost::asio::dispatch(
io_context_.get_executor(), std::packaged_task<void()>([this, reactor]() {
RAY_CHECK(sync_reactors_.find(reactor->GetRemoteNodeID()) == sync_reactors_.end())
<< reactor->GetRemoteNodeID();
<< NodeID::FromBinary(reactor->GetRemoteNodeID());
sync_reactors_[reactor->GetRemoteNodeID()] = reactor;
// Send the view for new connections.
for (const auto &[_, messages] : node_state_->GetClusterView()) {
@@ -274,9 +280,7 @@ void RaySyncer::Disconnect(const std::string &node_id) {
}

auto reactor = iter->second;
if (iter != sync_reactors_.end()) {
sync_reactors_.erase(iter);
}
sync_reactors_.erase(iter);
reactor->Disconnect();
});
boost::asio::dispatch(io_context_.get_executor(), std::move(task)).get();
@@ -350,14 +354,31 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont
syncer_.GetIOContext(),
syncer_.GetLocalNodeID(),
[this](auto msg) mutable { syncer_.BroadcastMessage(msg); },
[this](const std::string &node_id, bool reconnect) mutable {
[this](RaySyncerBidiReactor *reactor, bool reconnect) mutable {
// No need to reconnect for server side.
RAY_CHECK(!reconnect);
const auto &node_id = reactor->GetRemoteNodeID();
if (syncer_.sync_reactors_.contains(node_id) &&
syncer_.sync_reactors_.at(node_id) != reactor) {
// There is a new connection to the node, no need to clean up.
// This can happen when there is transient network error and the client
// reconnects. The sequence of events is:
// 1. Client reconnects, StartSync is called
// 2. syncer_.Disconnect is called and the old reactor is removed from
// sync_reactors_
// 3. syncer_.Connect is called and the new reactor is added to sync_reactors_
// 4. OnDone method of the old reactor is called which calls this cleanup_cb_
return;
}
syncer_.sync_reactors_.erase(node_id);
syncer_.node_state_->RemoveNode(node_id);
});
RAY_LOG(INFO).WithField(kLogKeyNodeID, NodeID::FromBinary(reactor->GetRemoteNodeID()))
<< "Get connection";
// Disconnect existing connection if there is any.
// This can happen when there is transient network error
// and the client reconnects.
syncer_.Disconnect(reactor->GetRemoteNodeID());
syncer_.Connect(reactor);
return reactor;
}