Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

refactor(security): add join point to start negotiation #626

Merged
merged 16 commits into from
Sep 17, 2020
4 changes: 2 additions & 2 deletions include/dsn/dist/failure_detector/fd.code.definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ DEFINE_THREAD_POOL_CODE(THREAD_POOL_DEFAULT)
DEFINE_TASK_CODE_RPC(RPC_FD_FAILURE_DETECTOR_PING, TASK_PRIORITY_COMMON, THREAD_POOL_DEFAULT)
// test timer task code
DEFINE_TASK_CODE(LPC_FD_TEST_TIMER, TASK_PRIORITY_COMMON, THREAD_POOL_DEFAULT)
}
}
} // namespace fd
} // namespace dsn
20 changes: 8 additions & 12 deletions include/dsn/tool-api/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ class rpc_session : public ref_counter
bool delay_recv(int delay_ms);
bool on_recv_message(message_ex *msg, int delay_ms);

/// for negotiation
void start_negotiation();
security::negotiation *get_negotiation() const;
/// interfaces for security authentication,
/// you can ignore them if you don't enable auth
void set_negotiation_succeed();
bool is_negotiation_succeed() const;

public:
///
Expand All @@ -258,7 +259,6 @@ class rpc_session : public ref_counter
virtual void send(uint64_t signature) = 0;
void on_send_completed(uint64_t signature = 0);
virtual void on_failure(bool is_write = false);
virtual void on_success();

protected:
///
Expand All @@ -267,13 +267,15 @@ class rpc_session : public ref_counter
enum session_state
{
SS_CONNECTING,
SS_NEGOTIATING,
SS_CONNECTED,
SS_DISCONNECTED
};
::dsn::utils::ex_lock_nr _lock; // [
utils::ex_lock_nr _lock; // [
volatile session_state _connect_state;

bool negotiation_succeed = false;
acelyc111 marked this conversation as resolved.
Show resolved Hide resolved
// TODO(zlw): add send pending message

// messages are sent in batch, firstly all messages are linked together
// in a doubly-linked list "_messages".
// if no messages are on-the-flying, a batch of messages are fetch from the "_messages"
Expand All @@ -298,7 +300,6 @@ class rpc_session : public ref_counter
bool set_connecting();
// return true when it is permitted
bool set_disconnected();
void set_negotiation();
void set_connected();

void clear_send_queue(bool resend_msgs);
Expand All @@ -313,16 +314,11 @@ class rpc_session : public ref_counter
message_reader _reader;
message_parser_ptr _parser;

private:
void auth_negotiation();

private:
const bool _is_client;
rpc_client_matcher *_matcher;

std::atomic_int _delay_server_receive_ms;

std::unique_ptr<security::negotiation> _negotiation;
};

// --------- inline implementation --------------
Expand Down
3 changes: 0 additions & 3 deletions src/runtime/rpc/asio_net_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ void asio_network_provider::do_accept()
null_parser,
false);

// start negotiation when server accepts the connection
s->start_negotiation();

// when server connection threshold is hit, close the session, otherwise accept it
if (check_if_conn_threshold_exceeded(s->remote_address())) {
dwarn("close rpc connection from %s to %s due to hitting server "
Expand Down
5 changes: 2 additions & 3 deletions src/runtime/rpc/asio_rpc_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,8 @@ void asio_rpc_session::connect()
dinfo("client session %s connected", _remote_addr.to_string());

set_options();

// start auth negotiation when client is connecting to server
start_negotiation();
set_connected();
on_send_completed();
start_read_next();
} else {
derror("client session connect to %s failed, error = %s",
Expand Down
48 changes: 7 additions & 41 deletions src/runtime/rpc/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ void rpc_session::set_connected()

{
utils::auto_lock<utils::ex_lock_nr> l(_lock);
dassert((_connect_state == SS_NEGOTIATING && security::FLAGS_enable_auth) ||
(_connect_state == SS_CONNECTING && !security::FLAGS_enable_auth),
"wrong session state");
dcheck_eq(_connect_state, SS_CONNECTING);
_connect_state = SS_CONNECTED;
}

Expand All @@ -86,17 +84,6 @@ void rpc_session::set_connected()
on_rpc_session_connected.execute(this);
}

void rpc_session::set_negotiation()
{
dassert(is_client(), "must be client session");

{
utils::auto_lock<utils::ex_lock_nr> l(_lock);
dassert(_connect_state == SS_CONNECTING, "session must be connecting");
_connect_state = SS_NEGOTIATING;
}
}

bool rpc_session::set_disconnected()
{
{
Expand Down Expand Up @@ -400,7 +387,7 @@ bool rpc_session::on_disconnected(bool is_write)

bool rpc_session::is_auth_success(message_ex *msg)
{
if (security::FLAGS_enable_auth && !_negotiation->negotiation_succeed()) {
if (security::FLAGS_enable_auth && !is_negotiation_succeed()) {
Comment on lines 388 to +390
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my point, if you want to decouple the negotiation entirely with rpc_session, you should remove this function and modify rpc_session::on_recv_message, where a session rejects the incoming requests when the negotiation is unfinished.

Nevertheless, there's no joint point for recv_message as far as I see. (There's one joint point at task_spec::on_rpc_call, but not suitable since it targets at customization of tasks, rather than general rpc control).

Maybe you can add such joint point later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will do it in the next pull request. I split the refactor of network into 3 pull requests.

dwarn_f("reject message({}) from {}, session {} client",
msg->rpc_code().to_string(),
_remote_addr.to_string(),
Expand All @@ -418,14 +405,6 @@ void rpc_session::on_failure(bool is_write)
}
}

void rpc_session::on_success()
{
if (is_client()) {
set_connected();
on_send_completed();
}
}

bool rpc_session::on_recv_message(message_ex *msg, int delay_ms)
{
if (msg->header->from_address.is_invalid())
Expand Down Expand Up @@ -478,28 +457,15 @@ bool rpc_session::on_recv_message(message_ex *msg, int delay_ms)
return true;
}

void rpc_session::start_negotiation()
void rpc_session::set_negotiation_succeed()
{
if (security::FLAGS_enable_auth) {
// set the negotiation state if it's a client rpc_session
if (is_client()) {
set_negotiation();
}

auth_negotiation();
} else {
// set negotiation success if auth is disabled
on_success();
}
}
utils::auto_lock<utils::ex_lock_nr> l(_lock);
negotiation_succeed = true;

void rpc_session::auth_negotiation()
{
_negotiation = security::create_negotiation(is_client(), this);
_negotiation->start();
// todo(zlw): resend pending messages when negotiation is succeed
}

security::negotiation *rpc_session::get_negotiation() const { return _negotiation.get(); }
bool rpc_session::is_negotiation_succeed() const { return negotiation_succeed; }

////////////////////////////////////////////////////////////////////////////////////////////////
network::network(rpc_engine *srv, network *inner_provider)
Expand Down
4 changes: 0 additions & 4 deletions src/runtime/rpc/network.sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ class sim_client_session : public rpc_session
virtual void close() override {}

virtual void on_failure(bool is_write = false) override {}

virtual void on_success() override {}
};

class sim_server_session : public rpc_session
Expand All @@ -79,8 +77,6 @@ class sim_server_session : public rpc_session

virtual void on_failure(bool is_write = false) override {}

virtual void on_success() override {}

private:
rpc_session_ptr _client;
};
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void client_negotiation::send(negotiation_status::type status, const blob &msg)
void client_negotiation::succ_negotiation()
{
_status = negotiation_status::type::SASL_SUCC;
_session->on_success();
_session->set_negotiation_succeed();
}
} // namespace security
} // namespace dsn
2 changes: 2 additions & 0 deletions src/runtime/security/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "kinit_context.h"
#include "sasl_init.h"
#include "negotiation_service.h"

#include <dsn/dist/fmt_logging.h>
#include <dsn/utility/flags.h>
Expand Down Expand Up @@ -63,6 +64,7 @@ bool init(bool is_server)
}
ddebug("initialize sasl succeed");

init_join_point();
return true;
}

Expand Down
41 changes: 39 additions & 2 deletions src/runtime/security/negotiation_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
#include "server_negotiation.h"

#include <dsn/utility/flags.h>
#include <dsn/tool-api/zlocks.h>

namespace dsn {
namespace security {
DSN_DECLARE_bool(enable_auth);

negotiation_map negotiation_service::_negotiations;
zrwlock_nr negotiation_service::_lock;

negotiation_service::negotiation_service() : serverlet("negotiation_service") {}

void negotiation_service::open_service()
Expand All @@ -44,10 +48,43 @@ void negotiation_service::on_negotiation_request(negotiation_rpc rpc)
return;
}

server_negotiation *srv_negotiation =
static_cast<server_negotiation *>(rpc.dsn_request()->io_session->get_negotiation());
server_negotiation *srv_negotiation = nullptr;
{
zauto_read_lock l(_lock);
srv_negotiation =
levy5307 marked this conversation as resolved.
Show resolved Hide resolved
static_cast<server_negotiation *>(_negotiations[rpc.dsn_request()->io_session].get());
}

dassert(srv_negotiation != nullptr,
"negotiation is null for msg: {}",
rpc.dsn_request()->rpc_code().to_string());
srv_negotiation->handle_request(rpc);
}

void negotiation_service::on_rpc_connected(rpc_session *session)
{
std::unique_ptr<negotiation> nego = security::create_negotiation(session->is_client(), session);
nego->start();
{
zauto_write_lock l(_lock);
_negotiations[session] = std::move(nego);
acelyc111 marked this conversation as resolved.
Show resolved Hide resolved
}
}

void negotiation_service::on_rpc_disconnected(rpc_session *session)
{
{
zauto_write_lock l(_lock);
_negotiations.erase(session);
}
}

void init_join_point()
{
rpc_session::on_rpc_session_connected.put_back(negotiation_service::on_rpc_connected,
"security");
rpc_session::on_rpc_session_disconnected.put_back(negotiation_service::on_rpc_disconnected,
"security");
Comment on lines +84 to +87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}
} // namespace security
} // namespace dsn
10 changes: 10 additions & 0 deletions src/runtime/security/negotiation_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,32 @@
#include "server_negotiation.h"

#include <dsn/cpp/serverlet.h>
#include <dsn/tool-api/zlocks.h>

namespace dsn {
namespace security {
typedef std::unordered_map<rpc_session *, std::unique_ptr<negotiation>> negotiation_map;

class negotiation_service : public serverlet<negotiation_service>,
public utils::singleton<negotiation_service>
{
public:
static void on_rpc_connected(rpc_session *session);
static void on_rpc_disconnected(rpc_session *session);

void open_service();

private:
negotiation_service();
void on_negotiation_request(negotiation_rpc rpc);
friend class utils::singleton<negotiation_service>;
friend class negotiation_service_test;

static zrwlock_nr _lock; // [
static negotiation_map _negotiations;
//]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add '//]' here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

//[ and //] indicates that the _lock is used for protecting the members locates between [ and ]

};

void init_join_point();
} // namespace security
} // namespace dsn
10 changes: 8 additions & 2 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,19 @@ void server_negotiation::do_challenge(negotiation_rpc rpc, error_s err_s, const
}

if (err_s.is_ok()) {
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SUCC;
succ_negotiation(rpc);
} else {
negotiation_response &challenge = rpc.response();
_status = challenge.status = negotiation_status::type::SASL_CHALLENGE;
challenge.msg = resp_msg;
}
}

void server_negotiation::succ_negotiation(negotiation_rpc rpc)
{
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SUCC;
_session->set_negotiation_succeed();
}
} // namespace security
} // namespace dsn
1 change: 1 addition & 0 deletions src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class server_negotiation : public negotiation
void on_challenge_resp(negotiation_rpc rpc);

void do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg);
void succ_negotiation(negotiation_rpc rpc);

friend class server_negotiation_test;
};
Expand Down