Skip to content

Commit

Permalink
Improve RESP handling code in replication (#2334)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored May 29, 2024
1 parent 479338d commit 5e99b27
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 29 deletions.
33 changes: 18 additions & 15 deletions src/cluster/replication.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <string>
#include <thread>

#include "commands/error_constants.h"
#include "event_util.h"
#include "fmt/format.h"
#include "io_util.h"
Expand Down Expand Up @@ -402,13 +403,13 @@ ReplicationThread::CBState ReplicationThread::authWriteCB(bufferevent *bev) {
return CBState::NEXT;
}

inline bool ResponseLineIsOK(const char *line) { return strncmp(line, "+OK", 3) == 0; }
inline bool ResponseLineIsOK(std::string_view line) { return line == RESP_PREFIX_SIMPLE_STRING "OK"; }

ReplicationThread::CBState ReplicationThread::authReadCB(bufferevent *bev) { // NOLINT
auto input = bufferevent_get_input(bev);
UniqueEvbufReadln line(input, EVBUFFER_EOL_CRLF_STRICT);
if (!line) return CBState::AGAIN;
if (!ResponseLineIsOK(line.get())) {
if (!ResponseLineIsOK(line.View())) {
// Auth failed
LOG(ERROR) << "[replication] Auth failed: " << line.get();
return CBState::RESTART;
Expand All @@ -430,7 +431,7 @@ ReplicationThread::CBState ReplicationThread::checkDBNameReadCB(bufferevent *bev
if (!line) return CBState::AGAIN;

if (line[0] == '-') {
if (isRestoringError(line.get())) {
if (isRestoringError(line.View())) {
LOG(WARNING) << "The master was restoring the db, retry later";
} else {
LOG(ERROR) << "Failed to get the db name, " << line.get();
Expand Down Expand Up @@ -468,18 +469,18 @@ ReplicationThread::CBState ReplicationThread::replConfReadCB(bufferevent *bev) {
if (!line) return CBState::AGAIN;

// on unknown option: first try without announce ip, if it fails again - do nothing (to prevent infinite loop)
if (isUnknownOption(line.get()) && !next_try_without_announce_ip_address_) {
if (isUnknownOption(line.View()) && !next_try_without_announce_ip_address_) {
next_try_without_announce_ip_address_ = true;
LOG(WARNING) << "The old version master, can't handle ip-address, "
<< "try without it again";
// Retry previous state, i.e. send replconf again
return CBState::PREV;
}
if (line[0] == '-' && isRestoringError(line.get())) {
if (line[0] == '-' && isRestoringError(line.View())) {
LOG(WARNING) << "The master was restoring the db, retry later";
return CBState::RESTART;
}
if (!ResponseLineIsOK(line.get())) {
if (!ResponseLineIsOK(line.View())) {
LOG(WARNING) << "[replication] Failed to replconf: " << line.get() + 1;
// backward compatible with old version that doesn't support replconf cmd
return CBState::NEXT;
Expand Down Expand Up @@ -530,20 +531,20 @@ ReplicationThread::CBState ReplicationThread::tryPSyncReadCB(bufferevent *bev) {
UniqueEvbufReadln line(input, EVBUFFER_EOL_CRLF_STRICT);
if (!line) return CBState::AGAIN;

if (line[0] == '-' && isRestoringError(line.get())) {
if (line[0] == '-' && isRestoringError(line.View())) {
LOG(WARNING) << "The master was restoring the db, retry later";
return CBState::RESTART;
}

if (line[0] == '-' && isWrongPsyncNum(line.get())) {
if (line[0] == '-' && isWrongPsyncNum(line.View())) {
next_try_old_psync_ = true;
LOG(WARNING) << "The old version master, can't handle new PSYNC, "
<< "try old PSYNC again";
// Retry previous state, i.e. send PSYNC again
return CBState::PREV;
}

if (!ResponseLineIsOK(line.get())) {
if (!ResponseLineIsOK(line.View())) {
// PSYNC isn't OK, we should use FullSync
// Switch to fullsync state machine
fullsync_steps_.Start();
Expand Down Expand Up @@ -844,7 +845,7 @@ Status ReplicationThread::sendAuth(int sock_fd, ssl_st *ssl) {
}
UniqueEvbufReadln line(evbuf.get(), EVBUFFER_EOL_CRLF_STRICT);
if (!line) continue;
if (!ResponseLineIsOK(line.get())) {
if (!ResponseLineIsOK(line.View())) {
return {Status::NotOK, "auth got invalid response"};
}
break;
Expand Down Expand Up @@ -998,15 +999,17 @@ Status ReplicationThread::parseWriteBatch(const std::string &batch_string) {
return Status::OK();
}

bool ReplicationThread::isRestoringError(const char *err) {
return std::string(err) == "-ERR restoring the db from backup";
bool ReplicationThread::isRestoringError(std::string_view err) {
return err == std::string(RESP_PREFIX_ERROR) + redis::errRestoringBackup;
}

bool ReplicationThread::isWrongPsyncNum(const char *err) {
return std::string(err) == "-ERR wrong number of arguments";
bool ReplicationThread::isWrongPsyncNum(std::string_view err) {
return err == std::string(RESP_PREFIX_ERROR) + redis::errWrongNumArguments;
}

bool ReplicationThread::isUnknownOption(const char *err) { return std::string(err) == "-ERR unknown option"; }
bool ReplicationThread::isUnknownOption(std::string_view err) {
return err == fmt::format("{}ERR {}", RESP_PREFIX_ERROR, redis::errUnknownOption);
}

rocksdb::Status WriteBatchHandler::PutCF(uint32_t column_family_id, const rocksdb::Slice &key,
const rocksdb::Slice &value) {
Expand Down
6 changes: 3 additions & 3 deletions src/cluster/replication.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ class ReplicationThread : private EventCallbackBase<ReplicationThread> {
Status fetchFiles(int sock_fd, const std::string &dir, const std::vector<std::string> &files,
const std::vector<uint32_t> &crcs, const FetchFileCallback &fn, ssl_st *ssl);
Status parallelFetchFile(const std::string &dir, const std::vector<std::pair<std::string, uint32_t>> &files);
static bool isRestoringError(const char *err);
static bool isWrongPsyncNum(const char *err);
static bool isUnknownOption(const char *err);
static bool isRestoringError(std::string_view err);
static bool isWrongPsyncNum(std::string_view err);
static bool isUnknownOption(std::string_view err);

Status parseWriteBatch(const std::string &batch_string);
};
Expand Down
5 changes: 3 additions & 2 deletions src/commands/cmd_replication.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "error_constants.h"
#include "io_util.h"
#include "scope_exit.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "thread_util.h"
#include "time_util.h"
Expand Down Expand Up @@ -101,7 +102,7 @@ class CommandPSync : public Commander {
srv->stats.IncrPSyncOKCount();
s = srv->AddSlave(conn, next_repl_seq_);
if (!s.IsOK()) {
std::string err = "-ERR " + s.Msg() + "\r\n";
std::string err = redis::Error(s.Msg());
s = util::SockSend(conn->GetFD(), err, conn->GetBufferEvent());
if (!s.IsOK()) {
LOG(WARNING) << "failed to send error message to the replica: " << s.Msg();
Expand Down Expand Up @@ -229,7 +230,7 @@ class CommandFetchMeta : public Commander {
std::string files;
auto s = engine::Storage::ReplDataManager::GetFullReplDataInfo(srv->storage, &files);
if (!s.IsOK()) {
s = util::SockSend(repl_fd, "-ERR can't create db checkpoint", bev);
s = util::SockSend(repl_fd, redis::Error("can't create db checkpoint"), bev);
if (!s.IsOK()) {
LOG(WARNING) << "[replication] Failed to send error response: " << s.Msg();
}
Expand Down
4 changes: 3 additions & 1 deletion src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "stats/disk_stats.h"
#include "storage/rdb.h"
Expand Down Expand Up @@ -740,7 +741,8 @@ class CommandHello final : public Commander {
// kvrocks only supports REPL2 by now, but for supporting some
// `hello 3`, it will not report error when using 3.
if (protocol < 2 || protocol > 3) {
return {Status::NotOK, "-NOPROTO unsupported protocol version"};
conn->Reply(redis::Error("NOPROTO unsupported protocol version"));
return Status::OK();
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/commands/error_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ inline constexpr const char *errValueIsNotFloat = "value is not a valid float";
inline constexpr const char *errNoMatchingScript = "NOSCRIPT No matching script. Please use EVAL";
inline constexpr const char *errUnknownOption = "unknown option";
inline constexpr const char *errUnknownSubcommandOrWrongArguments = "Unknown subcommand or wrong number of arguments";
inline constexpr const char *errWrongNumArguments = "ERR wrong number of arguments";
inline constexpr const char *errRestoringBackup = "LOADING kvrocks is restoring the db from backup";

} // namespace redis
3 changes: 3 additions & 0 deletions src/common/event_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <cstdlib>
#include <memory>
#include <string_view>
#include <utility>

#include "event2/buffer.h"
Expand All @@ -44,6 +45,8 @@ struct UniqueEvbufReadln : UniqueFreePtr<char[]> {
: UniqueFreePtr(evbuffer_readln(buffer, &length, eol_style)) {}

size_t length;

std::string_view View() { return {get(), length}; }
};

using StaticEvbufFree = StaticFunction<decltype(evbuffer_free), evbuffer_free>;
Expand Down
8 changes: 5 additions & 3 deletions src/server/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#include <shared_mutex>

#include "commands/commander.h"
#include "commands/error_constants.h"
#include "fmt/format.h"
#include "server/redis_reply.h"
#include "string_util.h"
#ifdef ENABLE_OPENSSL
#include <event2/bufferevent_ssl.h>
Expand Down Expand Up @@ -138,7 +140,7 @@ std::string Connection::Bool(bool b) const {
}

std::string Connection::MultiBulkString(const std::vector<std::string> &values) const {
std::string result = "*" + std::to_string(values.size()) + CRLF;
std::string result = MultiLen(values.size());
for (const auto &value : values) {
if (value.empty()) {
result += NilString();
Expand All @@ -151,7 +153,7 @@ std::string Connection::MultiBulkString(const std::vector<std::string> &values)

std::string Connection::MultiBulkString(const std::vector<std::string> &values,
const std::vector<rocksdb::Status> &statuses) const {
std::string result = "*" + std::to_string(values.size()) + CRLF;
std::string result = MultiLen(values.size());
for (size_t i = 0; i < values.size(); i++) {
if (i < statuses.size() && !statuses[i].ok()) {
result += NilString();
Expand Down Expand Up @@ -470,7 +472,7 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
}

if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) {
Reply(redis::Error("LOADING kvrocks is restoring the db from backup"));
Reply(redis::Error(errRestoringBackup));
if (is_multi_exec) multi_error_ = true;
continue;
}
Expand Down
4 changes: 2 additions & 2 deletions src/server/redis_reply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ std::string BulkString(const std::string &data) { return "$" + std::to_string(da

std::string Array(const std::vector<std::string> &list) {
size_t n = std::accumulate(list.begin(), list.end(), 0, [](size_t n, const std::string &s) { return n + s.size(); });
std::string result = "*" + std::to_string(list.size()) + CRLF;
std::string result = MultiLen(list.size());
std::string::size_type final_size = result.size() + n;
result.reserve(final_size);
for (const auto &i : list) result += i;
return result;
}

std::string ArrayOfBulkStrings(const std::vector<std::string> &elems) {
std::string result = "*" + std::to_string(elems.size()) + CRLF;
std::string result = MultiLen(elems.size());
for (const auto &elem : elems) {
result += BulkString(elem);
}
Expand Down
4 changes: 3 additions & 1 deletion src/server/redis_reply.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include <string>
#include <vector>

#define CRLF "\r\n" // NOLINT
#define CRLF "\r\n" // NOLINT
#define RESP_PREFIX_ERROR "-" // NOLINT
#define RESP_PREFIX_SIMPLE_STRING "+" // NOLINT

namespace redis {

Expand Down
4 changes: 2 additions & 2 deletions tests/gocase/unit/hello/hello_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestHello(t *testing.T) {

t.Run("hello with wrong protocol", func(t *testing.T) {
r := rdb.Do(ctx, "HELLO", "1")
require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
require.ErrorContains(t, r.Err(), "NOPROTO unsupported protocol version")
})

t.Run("hello with protocol 2", func(t *testing.T) {
Expand All @@ -61,7 +61,7 @@ func TestHello(t *testing.T) {

t.Run("hello with wrong protocol", func(t *testing.T) {
r := rdb.Do(ctx, "HELLO", "5")
require.ErrorContains(t, r.Err(), "-NOPROTO unsupported protocol version")
require.ErrorContains(t, r.Err(), "NOPROTO unsupported protocol version")
})

t.Run("hello with non protocol", func(t *testing.T) {
Expand Down

0 comments on commit 5e99b27

Please sign in to comment.