Skip to content

Commit

Permalink
[core] [1/N] Ray syncer set observer (#49122)
Browse files Browse the repository at this point in the history
Signed-off-by: hjiang <dentinyhao@gmail.com>
  • Loading branch information
dentiny authored Dec 16, 2024
1 parent 4e23798 commit 9f51eb5
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 23 deletions.
11 changes: 11 additions & 0 deletions src/ray/common/ray_syncer/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@

#pragma once

#include <functional>

#include "ray/common/id.h"
#include "src/ray/protobuf/ray_syncer.grpc.pb.h"
#include "src/ray/protobuf/ray_syncer.pb.h"

namespace ray::syncer {

// Array size derived from the protobuf-generated `MessageType_ARRAYSIZE` constant.
// NOTE(review): presumably used to size arrays indexed by sync message type -- confirm
// against the users of this constant.
inline constexpr size_t kComponentArraySize =
static_cast<size_t>(ray::rpc::syncer::MessageType_ARRAYSIZE);

// TODO(hjiang): As of now, only ray syncer uses it, so we put it under the `ray_syncer`
// folder; it would be better to move it into a more common folder if it is used
// elsewhere.
//
// A callback which is called whenever an rpc succeeds (at rpc communication level)
// between the current node and the remote node. It receives a `NodeID` by const
// reference and returns nothing.
using RpcCompletionCallback = std::function<void(const NodeID &)>;

} // namespace ray::syncer
1 change: 1 addition & 0 deletions src/ray/common/ray_syncer/node_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>

Expand Down
18 changes: 13 additions & 5 deletions src/ray/common/ray_syncer/ray_syncer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@
namespace ray::syncer {

RaySyncer::RaySyncer(instrumented_io_context &io_context,
const std::string &local_node_id)
const std::string &local_node_id,
RpcCompletionCallback on_rpc_completion)
: io_context_(io_context),
local_node_id_(local_node_id),
node_state_(std::make_unique<NodeState>()),
timer_(PeriodicalRunner::Create(io_context)) {
timer_(PeriodicalRunner::Create(io_context)),
on_rpc_completion_(std::move(on_rpc_completion)) {
stopped_ = std::make_shared<bool>(false);
}

RaySyncer::~RaySyncer() {
*stopped_ = true;
boost::asio::dispatch(io_context_.get_executor(), [reactors = sync_reactors_]() {
for (auto [_, reactor] : reactors) {
for (auto &[_, reactor] : reactors) {
reactor->Disconnect();
}
});
Expand Down Expand Up @@ -73,7 +75,7 @@ void RaySyncer::Connect(const std::string &node_id,
boost::asio::dispatch(
io_context_.get_executor(), std::packaged_task<void()>([=]() {
auto stub = ray::rpc::syncer::RaySyncer::NewStub(channel);
auto reactor = new RayClientBidiReactor(
auto *reactor = new RayClientBidiReactor(
/* remote_node_id */ node_id,
/* local_node_id */ GetLocalNodeID(),
/* io_context */ io_context_,
Expand Down Expand Up @@ -111,6 +113,11 @@ void RaySyncer::Connect(const std::string &node_id,
}

void RaySyncer::Connect(RaySyncerBidiReactor *reactor) {
// Bind rpc completion callback.
if (on_rpc_completion_) {
reactor->SetRpcCompletionCallbackForOnce(on_rpc_completion_);
}

boost::asio::dispatch(
io_context_.get_executor(), std::packaged_task<void()>([this, reactor]() {
auto [_, is_new] = sync_reactors_.emplace(reactor->GetRemoteNodeID(), reactor);
Expand Down Expand Up @@ -210,7 +217,8 @@ ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *cont
context,
syncer_.GetIOContext(),
syncer_.GetLocalNodeID(),
[this](auto msg) mutable { syncer_.BroadcastMessage(msg); },
/*message_processor=*/[this](auto msg) mutable { syncer_.BroadcastMessage(msg); },
/*cleanup_cb=*/
[this](RaySyncerBidiReactor *reactor, bool reconnect) mutable {
// No need to reconnect for server side.
RAY_CHECK(!reconnect);
Expand Down
9 changes: 8 additions & 1 deletion src/ray/common/ray_syncer/ray_syncer.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ class RaySyncer {
///
/// \param io_context The io context for this component.
/// \param node_id The id of current node.
RaySyncer(instrumented_io_context &io_context, const std::string &node_id);
/// \param on_rpc_completion A callback which is invoked after a sync rpc succeeds.
RaySyncer(instrumented_io_context &io_context,
const std::string &node_id,
RpcCompletionCallback on_rpc_completion = {});
~RaySyncer();

/// Connect to a node.
Expand Down Expand Up @@ -168,6 +171,10 @@ class RaySyncer {
/// Timer is used to do broadcasting.
std::shared_ptr<PeriodicalRunner> timer_;

/// Sync message observer: a callback invoked when a [RaySyncerBidiReactor] receives a
/// message response, so it should be passed to each reactor.
RpcCompletionCallback on_rpc_completion_;

friend class RaySyncerService;
/// Test purpose
friend struct SyncerServerTest;
Expand Down
14 changes: 14 additions & 0 deletions src/ray/common/ray_syncer/ray_syncer_bidi_reactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <string>

#include "src/ray/common/ray_syncer/common.h"
#include "src/ray/protobuf/ray_syncer.grpc.pb.h"

namespace ray::syncer {
Expand Down Expand Up @@ -93,11 +94,24 @@ class RaySyncerBidiReactor {
}
};

/// Set the rpc completion callback, which is called after an rpc read finishes.
/// This function is expected to be called only once: passing an empty callback, or
/// calling it again after a callback has already been stored, fails a RAY_CHECK.
///
/// \param on_rpc_completion The callback to store; must be non-empty.
void SetRpcCompletionCallbackForOnce(RpcCompletionCallback on_rpc_completion) {
// The incoming callback must be callable (non-empty).
RAY_CHECK(on_rpc_completion);
// No callback may have been registered before.
RAY_CHECK(!on_rpc_completion_);
on_rpc_completion_ = std::move(on_rpc_completion);
}

/// Return the shared flag that records whether this reactor is disconnected.
/// Note that this returns the `std::shared_ptr<bool>` itself rather than a plain bool,
/// so callers must dereference it to read the current state.
std::shared_ptr<bool> IsDisconnected() const { return disconnected_; }

// Node id which is communicating with the current reactor.
std::string remote_node_id_;

protected:
/// Sync message observer, which is a callback on received message response.
RpcCompletionCallback on_rpc_completion_;

private:
virtual void DoDisconnect() = 0;
std::shared_ptr<bool> disconnected_ = std::make_shared<bool>(false);
Expand Down
16 changes: 11 additions & 5 deletions src/ray/common/ray_syncer/ray_syncer_bidi_reactor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,21 @@ class RaySyncerBidiReactorBase : public RaySyncerBidiReactor, public T {
if (*disconnected) {
return;
}
if (ok) {
RAY_CHECK(!msg->node_id().empty());
ReceiveUpdate(std::move(msg));
StartPull();
} else {

if (!ok) {
RAY_LOG_EVERY_MS(INFO, 1000) << "Failed to read the message from: "
<< NodeID::FromBinary(GetRemoteNodeID());
Disconnect();
return;
}

// Successful rpc completion callback.
RAY_CHECK(!msg->node_id().empty());
if (on_rpc_completion_) {
on_rpc_completion_(NodeID::FromBinary(remote_node_id_));
}
ReceiveUpdate(std::move(msg));
StartPull();
},
"");
}
Expand Down
61 changes: 49 additions & 12 deletions src/ray/common/test/ray_syncer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// clang-format off
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <chrono>
#include <sstream>
#include <gmock/gmock.h>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/message_differencer.h>
#include <grpc/grpc.h>
#include <grpcpp/create_channel.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/util/json_util.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/security/server_credentials.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>

#include <chrono>
#include <sstream>

// clang-format off
#include "ray/common/ray_syncer/node_state.h"
#include "ray/common/ray_syncer/ray_syncer.h"
#include "ray/common/ray_syncer/ray_syncer_client.h"
Expand Down Expand Up @@ -203,15 +204,23 @@ TEST_F(RaySyncerTest, RaySyncerBidiReactorBase) {
}

struct SyncerServerTest {
SyncerServerTest(std::string port) : work_guard(io_context.get_executor()) {
SyncerServerTest(std::string port)
: SyncerServerTest(
std::move(port), /*node_id=*/NodeID::FromRandom(), /*ray_sync_observer=*/{}) {
}

SyncerServerTest(std::string port,
NodeID node_id,
RpcCompletionCallback ray_sync_observer)
: work_guard(io_context.get_executor()) {
this->server_port = port;
// Setup io context
auto node_id = NodeID::FromRandom();
for (auto &v : local_versions) {
v = 0;
}
// Setup syncer and grpc server
syncer = std::make_unique<RaySyncer>(io_context, node_id.Binary());
syncer = std::make_unique<RaySyncer>(
io_context, node_id.Binary(), std::move(ray_sync_observer));
thread = std::make_unique<std::thread>([this] { io_context.run(); });

auto server_address = std::string("0.0.0.0:") + port;
Expand Down Expand Up @@ -421,6 +430,14 @@ class SyncerTest : public ::testing::Test {
return *servers.back();
}

// Creates a `SyncerServerTest` for the given port, node id and rpc completion
// callback, appends it to `servers`, and returns a reference to the new server.
//
// \param port The port string forwarded to the server under test.
// \param node_id The node id the server's syncer is created with.
// \param on_rpc_completion Callback forwarded to the server's syncer.
// \return A reference to the server just added to `servers`.
SyncerServerTest &MakeServer(std::string port,
                             NodeID node_id,
                             RpcCompletionCallback on_rpc_completion) {
  // All three parameters are taken by value and are not used again here, so move
  // each of them into the new server instead of copying (matches the delegating
  // `SyncerServerTest(std::string port)` constructor, which also moves `port`).
  servers.emplace_back(std::make_unique<SyncerServerTest>(
      std::move(port), std::move(node_id), std::move(on_rpc_completion)));
  return *servers.back();
}

protected:
void TearDown() override {
// Drain all grpc requests.
Expand All @@ -434,9 +451,25 @@ class SyncerTest : public ::testing::Test {
};

TEST_F(SyncerTest, Test1To1) {
auto &s1 = MakeServer("19990");
// Generate node ids for checking.
NodeID node_id1 = NodeID::FromRandom();
NodeID node_id2 = NodeID::FromRandom();

// Used to check the number of messages consumed for two servers.
int s1_observer_cb_call_cnt = 0;
int s2_observer_cb_call_cnt = 0;

// Register observer callback for syncers.
auto syncer_observer_cb = [&](const NodeID &node_id) {
if (node_id == node_id1) {
++s1_observer_cb_call_cnt;
} else if (node_id == node_id2) {
++s2_observer_cb_call_cnt;
}
};

auto &s2 = MakeServer("19991");
auto &s1 = MakeServer("19990", node_id1, syncer_observer_cb);
auto &s2 = MakeServer("19991", node_id2, syncer_observer_cb);

// Make sure the setup is correct
ASSERT_NE(nullptr, s1.receivers[MessageType::RESOURCE_VIEW]);
Expand Down Expand Up @@ -538,6 +571,10 @@ TEST_F(SyncerTest, Test1To1) {
ASSERT_LE(s1.GetNumConsumedMessages(s2.syncer->GetLocalNodeID()), max_sends * 2 + 3);
// s1 has one reporter + 1 for the one send before the measure
ASSERT_LE(s2.GetNumConsumedMessages(s1.syncer->GetLocalNodeID()), max_sends + 3);

// Make sure registered callbacks have been called.
ASSERT_GT(s1_observer_cb_call_cnt, 0);
ASSERT_GT(s2_observer_cb_call_cnt, 0);
}

TEST_F(SyncerTest, Reconnect) {
Expand Down

0 comments on commit 9f51eb5

Please sign in to comment.