Skip to content

Commit

Permalink
Support TLS for replication (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored Sep 12, 2023
1 parent 2ecaef2 commit f3d796d
Show file tree
Hide file tree
Showing 13 changed files with 279 additions and 39 deletions.
7 changes: 7 additions & 0 deletions kvrocks.conf
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ redis-cursor-compatible no
#
# tls-session-cache-timeout 60

# By default, a replica does not attempt to establish a TLS connection
# with its master.
#
# Use the following directive to enable TLS on replication links.
#
# tls-replication yes

################################## SLOW LOG ###################################

# The Kvrocks Slow Log is a mechanism to log queries that exceeded a specified
Expand Down
74 changes: 57 additions & 17 deletions src/cluster/replication.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "fmt/format.h"
#include "io_util.h"
#include "rocksdb_crc32c.h"
#include "scope_exit.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "status.h"
Expand All @@ -45,6 +46,12 @@
#include "time_util.h"
#include "unique_fd.h"

#ifdef ENABLE_OPENSSL
#include <event2/bufferevent_ssl.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#endif

Status FeedSlaveThread::Start() {
auto s = util::CreateThread("feed-replica", [this] {
sigset_t mask, omask;
Expand All @@ -54,7 +61,7 @@ Status FeedSlaveThread::Start() {
sigaddset(&mask, SIGHUP);
sigaddset(&mask, SIGPIPE);
pthread_sigmask(SIG_BLOCK, &mask, &omask);
auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"));
auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"), conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "failed to send OK response to the replica: " << s.Msg();
return;
Expand Down Expand Up @@ -85,7 +92,7 @@ void FeedSlaveThread::Join() {
void FeedSlaveThread::checkLivenessIfNeed() {
if (++interval_ % 1000) return;
const auto ping_command = redis::BulkString("ping");
auto s = util::SockSend(conn_->GetFD(), ping_command);
auto s = util::SockSend(conn_->GetFD(), ping_command, conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "Ping slave[" << conn_->GetAddr() << "] err: " << s.Msg() << ", would stop the thread";
Stop();
Expand Down Expand Up @@ -134,7 +141,7 @@ void FeedSlaveThread::loop() {
if (is_first_repl_batch || batches_bulk.size() >= kMaxDelayBytes || updates_in_batches >= kMaxDelayUpdates ||
srv_->storage->LatestSeqNumber() - batch.sequence <= kMaxDelayUpdates) {
// Send the entire bulk, which contains multiple batches
auto s = util::SockSend(conn_->GetFD(), batches_bulk);
auto s = util::SockSend(conn_->GetFD(), batches_bulk, conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "Write error while sending batch to slave: " << s.Msg() << ". batches: 0x"
<< util::StringToHex(batches_bulk);
Expand Down Expand Up @@ -257,12 +264,35 @@ void ReplicationThread::CallbacksStateMachine::Start() {
LOG(ERROR) << "[replication] Failed to connect the master, err: " << cfd.Msg();
continue;
}
#ifdef ENABLE_OPENSSL
SSL *ssl = nullptr;
if (repl_->srv_->GetConfig()->tls_replication) {
ssl = SSL_new(repl_->srv_->ssl_ctx.get());
if (!ssl) {
LOG(ERROR) << "Failed to construct SSL structure for new connection: " << SSLErrors{};
evutil_closesocket(*cfd);
return;
}
bev = bufferevent_openssl_socket_new(repl_->base_, *cfd, ssl, BUFFEREVENT_SSL_CONNECTING, BEV_OPT_CLOSE_ON_FREE);
} else {
bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
}
#else
bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
#endif
if (bev == nullptr) {
#ifdef ENABLE_OPENSSL
if (ssl) SSL_free(ssl);
#endif
close(*cfd);
LOG(ERROR) << "[replication] Failed to create the event socket";
continue;
}
#ifdef ENABLE_OPENSSL
if (repl_->srv_->GetConfig()->tls_replication) {
bufferevent_openssl_set_allow_dirty_shutdown(bev, 1);
}
#endif
}
if (bev == nullptr) { // failed to connect the master and received the stop signal
return;
Expand Down Expand Up @@ -728,9 +758,19 @@ Status ReplicationThread::parallelFetchFile(const std::string &dir,
if (this->stop_flag_) {
return {Status::NotOK, "replication thread was stopped"};
}
int sock_fd = GET_OR_RET(util::SockConnect(this->host_, this->port_).Prefixed("connect the server err"));
ssl_st *ssl = nullptr;
#ifdef ENABLE_OPENSSL
if (this->srv_->GetConfig()->tls_replication) {
ssl = SSL_new(this->srv_->ssl_ctx.get());
}
auto exit = MakeScopeExit([ssl] { SSL_free(ssl); });
#endif
int sock_fd = GET_OR_RET(util::SockConnect(this->host_, this->port_, ssl).Prefixed("connect the server err"));
#ifdef ENABLE_OPENSSL
exit.Disable();
#endif
UniqueFD unique_fd{sock_fd};
auto s = this->sendAuth(sock_fd);
auto s = this->sendAuth(sock_fd, ssl);
if (!s.IsOK()) {
return s.Prefixed("send the auth command err");
}
Expand Down Expand Up @@ -770,12 +810,12 @@ Status ReplicationThread::parallelFetchFile(const std::string &dir,
// command, so we need to fetch all files by multiple command interactions.
if (srv_->GetConfig()->master_use_repl_port) {
for (unsigned i = 0; i < fetch_files.size(); i++) {
s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]}, fn);
s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]}, fn, ssl);
if (!s.IsOK()) break;
}
} else {
if (!fetch_files.empty()) {
s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn);
s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn, ssl);
}
}
return s;
Expand All @@ -790,13 +830,13 @@ Status ReplicationThread::parallelFetchFile(const std::string &dir,
return Status::OK();
}

Status ReplicationThread::sendAuth(int sock_fd) {
Status ReplicationThread::sendAuth(int sock_fd, ssl_st *ssl) {
// Send auth when needed
std::string auth = srv_->GetConfig()->masterauth;
if (!auth.empty()) {
UniqueEvbuf evbuf;
const auto auth_command = redis::MultiBulkString({"AUTH", auth});
auto s = util::SockSend(sock_fd, auth_command);
auto s = util::SockSend(sock_fd, auth_command, ssl);
if (!s.IsOK()) return s.Prefixed("send auth command err");
while (true) {
if (evbuffer_read(evbuf.get(), sock_fd, -1) <= 0) {
Expand All @@ -814,15 +854,15 @@ Status ReplicationThread::sendAuth(int sock_fd) {
}

Status ReplicationThread::fetchFile(int sock_fd, evbuffer *evbuf, const std::string &dir, const std::string &file,
uint32_t crc, const FetchFileCallback &fn) {
uint32_t crc, const FetchFileCallback &fn, ssl_st *ssl) {
size_t file_size = 0;

// Read file size line
while (true) {
UniqueEvbufReadln line(evbuf, EVBUFFER_EOL_CRLF_STRICT);
if (!line) {
if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
return {Status::NotOK, fmt::format("read size: {}", strerror(errno))};
if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
return std::move(s).Prefixed("read size");
}
continue;
}
Expand Down Expand Up @@ -854,8 +894,8 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer *evbuf, const std::str
tmp_crc = rocksdb::crc32c::Extend(tmp_crc, data, data_len);
remain -= data_len;
} else {
if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
return {Status::NotOK, fmt::format("read sst file: {}", strerror(errno))};
if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
return std::move(s).Prefixed("read sst file");
}
}
}
Expand All @@ -873,7 +913,7 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer *evbuf, const std::str
}

Status ReplicationThread::fetchFiles(int sock_fd, const std::string &dir, const std::vector<std::string> &files,
const std::vector<uint32_t> &crcs, const FetchFileCallback &fn) {
const std::vector<uint32_t> &crcs, const FetchFileCallback &fn, ssl_st *ssl) {
std::string files_str;
for (const auto &file : files) {
files_str += file;
Expand All @@ -882,13 +922,13 @@ Status ReplicationThread::fetchFiles(int sock_fd, const std::string &dir, const
files_str.pop_back();

const auto fetch_command = redis::MultiBulkString({"_fetch_file", files_str});
auto s = util::SockSend(sock_fd, fetch_command);
auto s = util::SockSend(sock_fd, fetch_command, ssl);
if (!s.IsOK()) return s.Prefixed("send fetch file command");

UniqueEvbuf evbuf;
for (unsigned i = 0; i < files.size(); i++) {
DLOG(INFO) << "[fetch] Start to fetch file " << files[i];
s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn);
s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn, ssl);
if (!s.IsOK()) {
s = Status(Status::NotOK, "fetch file err: " + s.Msg());
LOG(WARNING) << "[fetch] Fail to fetch file " << files[i] << ", err: " << s.Msg();
Expand Down
7 changes: 4 additions & 3 deletions src/cluster/replication.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <vector>

#include "event_util.h"
#include "io_util.h"
#include "server/redis_connection.h"
#include "status.h"
#include "storage/storage.h"
Expand Down Expand Up @@ -197,11 +198,11 @@ class ReplicationThread : private EventCallbackBase<ReplicationThread> {
CBState fullSyncReadCB(bufferevent *bev);

// Synchronized-Blocking ops
Status sendAuth(int sock_fd);
Status sendAuth(int sock_fd, ssl_st *ssl);
Status fetchFile(int sock_fd, evbuffer *evbuf, const std::string &dir, const std::string &file, uint32_t crc,
const FetchFileCallback &fn);
const FetchFileCallback &fn, ssl_st *ssl);
Status fetchFiles(int sock_fd, const std::string &dir, const std::vector<std::string> &files,
const std::vector<uint32_t> &crcs, const FetchFileCallback &fn);
const std::vector<uint32_t> &crcs, const FetchFileCallback &fn, ssl_st *ssl);
Status parallelFetchFile(const std::string &dir, const std::vector<std::pair<std::string, uint32_t>> &files);
static bool isRestoringError(const char *err);
static bool isWrongPsyncNum(const char *err);
Expand Down
10 changes: 5 additions & 5 deletions src/commands/cmd_replication.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class CommandPSync : public Commander {
s = svr->AddSlave(conn, next_repl_seq_);
if (!s.IsOK()) {
std::string err = "-ERR " + s.Msg() + "\r\n";
s = util::SockSend(conn->GetFD(), err);
s = util::SockSend(conn->GetFD(), err, conn->GetBufferEvent());
if (!s.IsOK()) {
LOG(WARNING) << "failed to send error message to the replica: " << s.Msg();
}
Expand Down Expand Up @@ -229,15 +229,15 @@ class CommandFetchMeta : public Commander {
std::string files;
auto s = engine::Storage::ReplDataManager::GetFullReplDataInfo(svr->storage, &files);
if (!s.IsOK()) {
s = util::SockSend(repl_fd, "-ERR can't create db checkpoint");
s = util::SockSend(repl_fd, "-ERR can't create db checkpoint", bev);
if (!s.IsOK()) {
LOG(WARNING) << "[replication] Failed to send error response: " << s.Msg();
}
LOG(WARNING) << "[replication] Failed to get full data file info: " << s.Msg();
return;
}
// Send full data file info
if (util::SockSend(repl_fd, files + CRLF).IsOK()) {
if (util::SockSend(repl_fd, files + CRLF, bev).IsOK()) {
LOG(INFO) << "[replication] Succeed sending full data file info to " << ip;
} else {
LOG(WARNING) << "[replication] Fail to send full data file info " << ip << ", error: " << strerror(errno);
Expand Down Expand Up @@ -291,8 +291,8 @@ class CommandFetchFile : public Commander {
if (!fd) break;

// Send file size and content
if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF).IsOK() &&
util::SockSendFile(repl_fd, *fd, file_size).IsOK()) {
if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF, bev).IsOK() &&
util::SockSendFile(repl_fd, *fd, file_size, bev).IsOK()) {
LOG(INFO) << "[replication] Succeed sending file " << file << " to " << ip;
} else {
LOG(WARNING) << "[replication] Fail to send file " << file << " to " << ip << ", error: " << strerror(errno);
Expand Down
Loading

0 comments on commit f3d796d

Please sign in to comment.