Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

refactor(security): add join point to filter the sending message #629

Merged
merged 15 commits into from
Sep 22, 2020
11 changes: 10 additions & 1 deletion include/dsn/tool-api/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class rpc_session : public ref_counter
static join_point<void, rpc_session *> on_rpc_session_connected;
static join_point<void, rpc_session *> on_rpc_session_disconnected;
static join_point<bool, message_ex *> on_rpc_recv_message;
static join_point<bool, message_ex *> on_rpc_send_message;
/*@}*/
public:
rpc_session(connection_oriented_network &net,
Expand All @@ -232,6 +233,11 @@ class rpc_session : public ref_counter
bool cancel(message_ex *request);
bool delay_recv(int delay_ms);
bool on_recv_message(message_ex *msg, int delay_ms);
/// ret value:
/// true - pend succeed
/// false - pend failed
bool try_pend_message(message_ex *msg);
void clear_pending_messages();

/// interfaces for security authentication,
/// you can ignore them if you don't enable auth
Expand Down Expand Up @@ -275,7 +281,10 @@ class rpc_session : public ref_counter
volatile session_state _connect_state;

bool negotiation_succeed = false;
// TODO(zlw): add send pending message
// when the negotiation of a session isn't succeed,
// all messages are queued in _pending_messages.
// after connected, all of them are moved to "_messages"
std::vector<message_ex *> _pending_messages;
levy5307 marked this conversation as resolved.
Show resolved Hide resolved

// messages are sent in batch, firstly all messages are linked together
// in a doubly-linked list "_messages".
Expand Down
60 changes: 48 additions & 12 deletions src/runtime/rpc/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
* THE SOFTWARE.
*/

#include "runtime/security/negotiation_utils.h"
#include "message_parser_manager.h"
#include "runtime/rpc/rpc_engine.h"

Expand All @@ -40,9 +39,12 @@ namespace dsn {
rpc_session::on_rpc_session_disconnected("rpc.session.disconnected");
/*static*/ join_point<bool, message_ex *>
rpc_session::on_rpc_recv_message("rpc.session.recv.message");
/*static*/ join_point<bool, message_ex *>
rpc_session::on_rpc_send_message("rpc.session.send.message");

rpc_session::~rpc_session()
{
clear_pending_messages();
clear_send_queue(false);

{
Expand Down Expand Up @@ -250,9 +252,14 @@ int rpc_session::prepare_parser()
void rpc_session::send_message(message_ex *msg)
{
msg->add_ref(); // released in on_send_completed

msg->io_session = this;

// ignore msg if join point return false
if (dsn_unlikely(!on_rpc_send_message.execute(msg, true))) {
msg->release_ref();
return;
}

dassert(_parser, "parser should not be null when send");
_parser->prepare_on_send(msg);

Expand All @@ -262,11 +269,7 @@ void rpc_session::send_message(message_ex *msg)
msg->dl.insert_before(&_messages);
++_message_count;

// Attention: here we only allow two cases to send message:
// case 1: session's state is SS_CONNECTED
// case 2: session is sending negotiation message
if ((SS_CONNECTED == _connect_state || security::is_negotiation_message(msg->rpc_code())) &&
!_is_sending_next) {
if ((SS_CONNECTED == _connect_state) && !_is_sending_next) {
_is_sending_next = true;
sig = _message_sent + 1;
unlink_message_for_send();
Expand Down Expand Up @@ -397,7 +400,7 @@ bool rpc_session::on_recv_message(message_ex *msg, int delay_ms)
msg->io_session = this;

// ignore msg if join point return false
if (!on_rpc_recv_message.execute(msg, true)) {
if (dsn_unlikely(!on_rpc_recv_message.execute(msg, true))) {
delete msg;
return false;
}
Expand Down Expand Up @@ -437,12 +440,45 @@ bool rpc_session::on_recv_message(message_ex *msg, int delay_ms)
return true;
}

void rpc_session::set_negotiation_succeed()
bool rpc_session::try_pend_message(message_ex *msg)
{
// if negotiation is not succeed, we should pend msg,
// in order to resend it when the negotiation is succeed
if (dsn_unlikely(!negotiation_succeed)) {
utils::auto_lock<utils::ex_lock_nr> l(_lock);
if (!negotiation_succeed) {
msg->add_ref();
_pending_messages.push_back(msg);
return true;
}
}
return false;
}

void rpc_session::clear_pending_messages()
{
utils::auto_lock<utils::ex_lock_nr> l(_lock);
negotiation_succeed = true;
for (auto msg : _pending_messages) {
msg->release_ref();
}
_pending_messages.clear();
}

void rpc_session::set_negotiation_succeed()
{
std::vector<message_ex *> swapped_pending_msgs;
{
utils::auto_lock<utils::ex_lock_nr> l(_lock);
negotiation_succeed = true;

_pending_messages.swap(swapped_pending_msgs);
}

// todo(zlw): resend pending messages when negotiation is succeed
// resend the pending messages
for (auto msg : swapped_pending_msgs) {
send_message(msg);
msg->release_ref();
hycdong marked this conversation as resolved.
Show resolved Hide resolved
}
}

bool rpc_session::is_negotiation_succeed() const
Expand All @@ -451,7 +487,7 @@ bool rpc_session::is_negotiation_succeed() const
// Because negotiation_succeed only transfered from false to true.
// So if it is true now, it will not change in the later.
// But if it is false now, maybe it will change soon. So we should use lock to protect it.
if (negotiation_succeed) {
if (dsn_likely(negotiation_succeed)) {
return negotiation_succeed;
} else {
utils::auto_lock<utils::ex_lock_nr> l(_lock);
Expand Down
20 changes: 16 additions & 4 deletions src/runtime/security/negotiation_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ namespace dsn {
namespace security {
DSN_DECLARE_bool(enable_auth);

inline bool is_negotiation_message(dsn::task_code code)
{
return code == RPC_NEGOTIATION || code == RPC_NEGOTIATION_ACK;
}

inline bool in_white_list(task_code code)
{
return is_negotiation_message(code) || fd::is_failure_detector_message(code);
}

negotiation_map negotiation_service::_negotiations;
zrwlock_nr negotiation_service::_lock;
utils::rw_lock_nr negotiation_service::_lock;

negotiation_service::negotiation_service() : serverlet("negotiation_service") {}

Expand All @@ -56,7 +61,7 @@ void negotiation_service::on_negotiation_request(negotiation_rpc rpc)

server_negotiation *srv_negotiation = nullptr;
{
zauto_read_lock l(_lock);
utils::auto_read_lock l(_lock);
srv_negotiation =
static_cast<server_negotiation *>(_negotiations[rpc.dsn_request()->io_session].get());
}
Expand All @@ -72,15 +77,15 @@ void negotiation_service::on_rpc_connected(rpc_session *session)
std::unique_ptr<negotiation> nego = security::create_negotiation(session->is_client(), session);
nego->start();
{
zauto_write_lock l(_lock);
utils::auto_write_lock l(_lock);
_negotiations[session] = std::move(nego);
}
}

void negotiation_service::on_rpc_disconnected(rpc_session *session)
{
{
zauto_write_lock l(_lock);
utils::auto_write_lock l(_lock);
_negotiations.erase(session);
}
}
Expand All @@ -90,13 +95,20 @@ bool negotiation_service::on_rpc_recv_msg(message_ex *msg)
return in_white_list(msg->rpc_code()) || msg->io_session->is_negotiation_succeed();
}

bool negotiation_service::on_rpc_send_msg(message_ex *msg)
{
// if try_pend_message return true, it means the msg is pended to the resend message queue
return in_white_list(msg->rpc_code()) || !msg->io_session->try_pend_message(msg);
}

void init_join_point()
{
rpc_session::on_rpc_session_connected.put_back(negotiation_service::on_rpc_connected,
"security");
rpc_session::on_rpc_session_disconnected.put_back(negotiation_service::on_rpc_disconnected,
"security");
rpc_session::on_rpc_recv_message.put_native(negotiation_service::on_rpc_recv_msg);
rpc_session::on_rpc_send_message.put_native(negotiation_service::on_rpc_send_msg);
}
} // namespace security
} // namespace dsn
4 changes: 2 additions & 2 deletions src/runtime/security/negotiation_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "server_negotiation.h"

#include <dsn/cpp/serverlet.h>
#include <dsn/tool-api/zlocks.h>

namespace dsn {
namespace security {
Expand All @@ -33,6 +32,7 @@ class negotiation_service : public serverlet<negotiation_service>,
static void on_rpc_connected(rpc_session *session);
static void on_rpc_disconnected(rpc_session *session);
static bool on_rpc_recv_msg(message_ex *msg);
static bool on_rpc_send_msg(message_ex *msg);

void open_service();

Expand All @@ -42,7 +42,7 @@ class negotiation_service : public serverlet<negotiation_service>,
friend class utils::singleton<negotiation_service>;
friend class negotiation_service_test;

static zrwlock_nr _lock; // [
static utils::rw_lock_nr _lock; // [
static negotiation_map _negotiations;
//]
};
Expand Down
6 changes: 0 additions & 6 deletions src/runtime/security/negotiation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,5 @@ inline const char *enum_to_string(negotiation_status::type s)
}

DEFINE_TASK_CODE_RPC(RPC_NEGOTIATION, TASK_PRIORITY_COMMON, dsn::THREAD_POOL_DEFAULT)

inline bool is_negotiation_message(dsn::task_code code)
{
return code == RPC_NEGOTIATION || code == RPC_NEGOTIATION_ACK;
}

} // namespace security
} // namespace dsn
32 changes: 32 additions & 0 deletions src/runtime/test/negotiation_service_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class negotiation_service_test : public testing::Test
{
return negotiation_service::instance().on_rpc_recv_msg(msg);
}

bool on_rpc_send_msg(message_ex *msg)
{
return negotiation_service::instance().on_rpc_send_msg(msg);
}
};

TEST_F(negotiation_service_test, disable_auth)
Expand Down Expand Up @@ -98,5 +103,32 @@ TEST_F(negotiation_service_test, on_rpc_recv_msg)
ASSERT_EQ(test.return_value, on_rpc_recv_msg(msg));
}
}

TEST_F(negotiation_service_test, on_rpc_send_msg)
{
struct
{
task_code rpc_code;
bool negotiation_succeed;
bool return_value;
} tests[] = {{RPC_NEGOTIATION, true, true},
{RPC_NEGOTIATION_ACK, true, true},
{fd::RPC_FD_FAILURE_DETECTOR_PING, true, true},
{fd::RPC_FD_FAILURE_DETECTOR_PING_ACK, true, true},
{RPC_NEGOTIATION, false, true},
{RPC_HTTP_SERVICE, true, true},
{RPC_HTTP_SERVICE, false, false}};

for (const auto &test : tests) {
message_ptr msg = dsn::message_ex::create_request(test.rpc_code, 0, 0);
auto sim_session = create_fake_session();
msg->io_session = sim_session;
if (test.negotiation_succeed) {
sim_session->set_negotiation_succeed();
}

ASSERT_EQ(test.return_value, on_rpc_send_msg(msg));
}
}
} // namespace security
} // namespace dsn