From f4c517bc802a09c730b2e474b4918735a28cb037 Mon Sep 17 00:00:00 2001 From: Yaroslav Date: Sat, 7 May 2022 09:48:27 +0300 Subject: [PATCH] Add support for listening to unix socket (#531) Two processes cannot bind (and listen) to the same unix socket. That's why only one worker binds the socket. Introduce a separate ConfigField and function to parse unix socket permissions from config file. This number is defined in octal form (e.g. 770 or 0777). --- kvrocks.conf | 9 +++++++ src/config.cc | 6 +++-- src/config.h | 2 ++ src/config_type.h | 32 ++++++++++++++++++++-- src/redis_cmd.cc | 4 +-- src/scripting.cc | 8 +++--- src/server.cc | 11 ++++++++ src/util.cc | 14 +++++++++- src/util.h | 3 ++- src/worker.cc | 68 ++++++++++++++++++++++++++++++++++++++++++----- src/worker.h | 10 ++++--- 11 files changed, 146 insertions(+), 21 deletions(-) diff --git a/kvrocks.conf b/kvrocks.conf index 912e2bcfa4a..805b84bfdce 100644 --- a/kvrocks.conf +++ b/kvrocks.conf @@ -11,6 +11,15 @@ # bind 127.0.0.1 bind 0.0.0.0 +# Unix socket. +# +# Specify the path for the unix socket that will be used to listen for +# incoming connections. There is no default, so kvrocks will not listen +# on a unix socket when not specified. +# +# unixsocket /tmp/kvrocks.sock +# unixsocketperm 777 + # Accept connections on the specified port, default is 6666. port 6666 diff --git a/src/config.cc b/src/config.cc index 8942ba63355..34a7dd7b07d 100644 --- a/src/config.cc +++ b/src/config.cc @@ -133,6 +133,8 @@ Config::Config() { {"migrate-speed", false, new IntField(&migrate_speed, 4096, 0, INT_MAX)}, {"migrate-pipeline-size", false, new IntField(&pipeline_size, 16, 1, INT_MAX)}, {"migrate-sequence-gap", false, new IntField(&sequence_gap, 10000, 1, INT_MAX)}, + {"unixsocket", true, new StringField(&unixsocket, "")}, + {"unixsocketperm", true, new OctalField(&unixsocketperm, 0777, 1, INT_MAX)}, /* rocksdb options */ {"rocksdb.compression", false, new EnumField(&RocksDB.compression, compression_type_enum, 0)}, @@ -227,9 +229,9 @@ void Config::initFieldValidator() { return Status(Status::NotOK, "invalid range format, the range should be between 0 and 24"); } int64_t start, stop; - Status s = Util::StringToNum(args[0], &start, 0, 24); + Status s = Util::DecimalStringToNum(args[0], &start, 0, 24); if (!s.IsOK()) return s; - s = Util::StringToNum(args[1], &stop, 0, 24); + s = Util::DecimalStringToNum(args[1], &stop, 0, 24); if (!s.IsOK()) return s; if (start > stop) return Status(Status::NotOK, "invalid range format, start should be smaller than stop"); compaction_checker_range.Start = start; diff --git a/src/config.h b/src/config.h index 20458e7ea92..fcc4867a2e7 100644 --- a/src/config.h +++ b/src/config.h @@ -100,6 +100,8 @@ struct Config{ std::string masterauth; std::string requirepass; std::string master_host; + std::string unixsocket; + int unixsocketperm = 0777; int master_port = 0; Cron compact_cron; Cron bgsave_cron; diff --git a/src/config_type.h b/src/config_type.h index 0e83ffc716b..6cf811e537d 100644 --- a/src/config_type.h +++ b/src/config_type.h @@ -89,7 +89,35 @@ class IntField : public ConfigField { } Status Set(const std::string &v) override { int64_t n; - auto s = Util::StringToNum(v, &n, min_, max_); + auto s = Util::DecimalStringToNum(v, &n, min_, max_); + if (!s.IsOK()) return s; + *receiver_ = static_cast(n); + return Status::OK(); + } + + private: + int *receiver_; + int min_ = INT_MIN; + int max_ = INT_MAX; +}; + +class OctalField : public ConfigField { + public: + OctalField(int *receiver, int n, int min, int max) + : receiver_(receiver), min_(min), max_(max) { + *receiver_ = n; + } + ~OctalField() override = default; + std::string ToString() override { + return std::to_string(*receiver_); + } + Status ToNumber(int64_t *n) override { + *n = *receiver_; + return Status::OK(); + } + Status Set(const std::string &v) override { + int64_t n; + auto s = Util::OctalStringToNum(v, &n, min_, max_); if (!s.IsOK()) return s; *receiver_ = static_cast(n); return Status::OK(); @@ -117,7 +145,7 @@ class Int64Field : public ConfigField { } Status Set(const std::string &v) override { int64_t n; - auto s = Util::StringToNum(v, &n, min_, max_); + auto s = Util::DecimalStringToNum(v, &n, min_, max_); if (!s.IsOK()) return s; *receiver_ = n; return Status::OK(); diff --git a/src/redis_cmd.cc b/src/redis_cmd.cc index 7caaa20302f..cd7e32992a5 100644 --- a/src/redis_cmd.cc +++ b/src/redis_cmd.cc @@ -3658,7 +3658,7 @@ class CommandPerfLog : public Commander { if (args[2] == "*") { cnt_ = 0; } else { - Status s = Util::StringToNum(args[2], &cnt_); + Status s = Util::DecimalStringToNum(args[2], &cnt_); return s; } } @@ -3694,7 +3694,7 @@ class CommandSlowlog : public Commander { if (args[2] == "*") { cnt_ = 0; } else { - Status s = Util::StringToNum(args[2], &cnt_); + Status s = Util::DecimalStringToNum(args[2], &cnt_); return s; } } diff --git a/src/scripting.cc b/src/scripting.cc index d6e49f0c89d..15384cc6b94 100644 --- a/src/scripting.cc +++ b/src/scripting.cc @@ -245,7 +245,7 @@ namespace Lua { Server *srv = conn->GetServer(); lua_State *lua = srv->Lua(); - auto s = Util::StringToNum(args[2], &numkeys); + auto s = Util::DecimalStringToNum(args[2], &numkeys); if (!s.IsOK()) { return s; } @@ -604,7 +604,7 @@ const char *redisProtocolToLuaType_Int(lua_State *lua, const char *reply) { const char *p = strchr(reply+1, '\r'); int64_t value; - Util::StringToNum(std::string(reply+1, p-reply-1), &value); + Util::DecimalStringToNum(std::string(reply+1, p-reply-1), &value); lua_pushnumber(lua, static_cast(value)); return p+2; } @@ -613,7 +613,7 @@ const char *redisProtocolToLuaType_Bulk(lua_State *lua, const char *reply) { const char *p = strchr(reply+1, '\r'); int64_t bulklen; - Util::StringToNum(std::string(reply+1, p-reply-1), &bulklen); + Util::DecimalStringToNum(std::string(reply+1, p-reply-1), &bulklen); if (bulklen == -1) { lua_pushboolean(lua, 0); return p+2; @@ -648,7 +648,7 @@ const char *redisProtocolToLuaType_Aggregate(lua_State *lua, const char *reply, int64_t mbulklen; int j = 0; - Util::StringToNum(std::string(reply+1, p-reply-1), &mbulklen); + Util::DecimalStringToNum(std::string(reply+1, p-reply-1), &mbulklen); p += 2; if (mbulklen == -1) { lua_pushboolean(lua, 0); diff --git a/src/server.cc b/src/server.cc index 1ea5b23ccac..08ca7cdc0f4 100644 --- a/src/server.cc +++ b/src/server.cc @@ -55,6 +55,17 @@ Server::Server(Engine::Storage *storage, Config *config) : for (int i = 0; i < config->workers; i++) { auto worker = new Worker(this, config); + // multiple workers can't listen to the same unix socket, so + // listen unix socket only from a single worker - the first one + if (!config->unixsocket.empty() && i == 0) { + Status s = worker->ListenUnixSocket(config->unixsocket, config->unixsocketperm, config->backlog); + if (!s.IsOK()) { + LOG(ERROR) << "[server] Failed to listen on unix socket: "<< config->unixsocket + << ", encounter error: " << s.Msg(); + delete worker; + exit(1); + } + } worker_threads_.emplace_back(new WorkerThread(worker)); } AdjustOpenFilesLimit(); diff --git a/src/util.cc b/src/util.cc index 36c327cfa6e..9bb138c8629 100644 --- a/src/util.cc +++ b/src/util.cc @@ -323,7 +323,7 @@ int GetPeerAddr(int fd, std::string *addr, uint32_t *port) { return -2; // only support AF_INET currently } -Status StringToNum(const std::string &str, int64_t *n, int64_t min, int64_t max) { +Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t max) { try { *n = static_cast(std::stoll(str)); if (max > min && (*n < min || *n > max)) { @@ -335,6 +335,18 @@ Status StringToNum(const std::string &str, int64_t *n, int64_t min, int64_t max) return Status::OK(); } +Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min, int64_t max) { + try { + *n = static_cast(std::stoll(str, nullptr, 8)); + if (max > min && (*n < min || *n > max)) { + return Status(Status::NotOK, "value shoud between "+std::to_string(min)+" and "+std::to_string(max)); + } + } catch (std::exception &e) { + return Status(Status::NotOK, "value is not an integer or out of range"); + } + return Status::OK(); +} + std::string ToLower(std::string in) { std::transform(in.begin(), in.end(), in.begin(), [](char c) -> char { return static_cast(std::tolower(c)); }); diff --git a/src/util.h b/src/util.h index 843b4260919..8e5c3f3c100 100644 --- a/src/util.h +++ b/src/util.h @@ -71,7 +71,8 @@ int GetPeerAddr(int fd, std::string *addr, uint32_t *port); bool IsPortInUse(int port); // string util -Status StringToNum(const std::string &str, int64_t *n, int64_t min = INT64_MIN, int64_t max = INT64_MAX); +Status DecimalStringToNum(const std::string &str, int64_t *n, int64_t min = INT64_MIN, int64_t max = INT64_MAX); +Status OctalStringToNum(const std::string &str, int64_t *n, int64_t min = INT64_MIN, int64_t max = INT64_MAX); const std::string Float2String(double d); std::string ToLower(std::string in); void BytesToHuman(char *buf, size_t size, uint64_t n); diff --git a/src/worker.cc b/src/worker.cc index 39830b9cd93..6caf2fada5d 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -22,8 +22,11 @@ #include #include +#include +#include #include #include +#include #include #include #include @@ -44,7 +47,7 @@ Worker::Worker(Server *svr, Config *config, bool repl) : svr_(svr) { int port = config->port; auto binds = config->binds; for (const auto &bind : binds) { - Status s = listen(bind, port, config->backlog); + Status s = listenTCP(bind, port, config->backlog); if (!s.IsOK()) { LOG(ERROR) << "[worker] Failed to listen on: "<< bind << ":" << port << ", encounter error: " << s.Msg(); @@ -78,8 +81,8 @@ void Worker::TimerCB(int, int16_t events, void *ctx) { worker->KickoutIdleClients(config->timeout); } -void Worker::newConnection(evconnlistener *listener, evutil_socket_t fd, - sockaddr *address, int socklen, void *ctx) { +void Worker::newTCPConnection(evconnlistener *listener, evutil_socket_t fd, + sockaddr *address, int socklen, void *ctx) { auto worker = static_cast(ctx); DLOG(INFO) << "[worker] New connection: fd=" << fd << " from port: " << worker->svr_->GetConfig()->port << " thread #" @@ -106,23 +109,51 @@ void Worker::newConnection(evconnlistener *listener, evutil_socket_t fd, Redis::Connection::OnEvent, conn); bufferevent_enable(bev, EV_READ); Status status = worker->AddConnection(conn); + if (!status.IsOK()) { + std::string err_msg = Redis::Error("ERR " + status.Msg()); + write(fd, err_msg.data(), err_msg.size()); + conn->Close(); + return; + } std::string ip; uint32_t port; if (Util::GetPeerAddr(fd, &ip, &port) == 0) { conn->SetAddr(ip, port); } + if (worker->rate_limit_group_ != nullptr) { + bufferevent_add_to_rate_limit_group(bev, worker->rate_limit_group_); + } +} + +void Worker::newUnixSocketConnection(evconnlistener *listener, evutil_socket_t fd, + sockaddr *address, int socklen, void *ctx) { + auto worker = static_cast(ctx); + DLOG(INFO) << "[worker] New connection: fd=" << fd + << " from unixsocket: " << worker->svr_->GetConfig()->unixsocket << " thread #" + << worker->tid_; + event_base *base = evconnlistener_get_base(listener); + auto evThreadSafeFlags = BEV_OPT_THREADSAFE | BEV_OPT_DEFER_CALLBACKS | BEV_OPT_UNLOCK_CALLBACKS; + bufferevent *bev = bufferevent_socket_new(base, + fd, + evThreadSafeFlags); + auto conn = new Redis::Connection(bev, worker); + bufferevent_setcb(bev, Redis::Connection::OnRead, Redis::Connection::OnWrite, + Redis::Connection::OnEvent, conn); + bufferevent_enable(bev, EV_READ); + Status status = worker->AddConnection(conn); if (!status.IsOK()) { std::string err_msg = Redis::Error("ERR " + status.Msg()); write(fd, err_msg.data(), err_msg.size()); conn->Close(); + return; } - + conn->SetAddr(worker->svr_->GetConfig()->unixsocket, 0); if (worker->rate_limit_group_ != nullptr) { bufferevent_add_to_rate_limit_group(bev, worker->rate_limit_group_); } } -Status Worker::listen(const std::string &host, int port, int backlog) { +Status Worker::listenTCP(const std::string &host, int port, int backlog) { sockaddr_in sin{}; sin.sin_family = AF_INET; evutil_inet_pton(AF_INET, host.data(), &(sin.sin_addr)); @@ -140,9 +171,34 @@ Status Worker::listen(const std::string &host, int port, int backlog) { return Status(Status::NotOK, evutil_socket_error_to_string(EVUTIL_SOCKET_ERROR())); } evutil_make_socket_nonblocking(fd); - auto lev = evconnlistener_new(base_, newConnection, this, + auto lev = evconnlistener_new(base_, newTCPConnection, this, + LEV_OPT_CLOSE_ON_FREE, backlog, fd); + listen_events_.emplace_back(lev); + return Status::OK(); +} + +Status Worker::ListenUnixSocket(const std::string &path, int perm, int backlog) { + unlink(path.c_str()); + sockaddr_un sa{}; + if (path.size() > sizeof(sa.sun_path) - 1) { + return Status(Status::NotOK, "unix socket path too long"); + } + sa.sun_family = AF_LOCAL; + strncpy(sa.sun_path, path.c_str(), sizeof(sa.sun_path) - 1); + int fd = socket(AF_LOCAL, SOCK_STREAM, 0); + if (fd == -1) { + return Status(Status::NotOK, evutil_socket_error_to_string(EVUTIL_SOCKET_ERROR())); + } + if (bind(fd, (struct sockaddr *)&sa, sizeof(sa)) < 0) { + return Status(Status::NotOK, evutil_socket_error_to_string(EVUTIL_SOCKET_ERROR())); + } + evutil_make_socket_nonblocking(fd); + auto lev = evconnlistener_new(base_, newUnixSocketConnection, this, LEV_OPT_CLOSE_ON_FREE, backlog, fd); listen_events_.emplace_back(lev); + if (perm != 0) { + chmod(sa.sun_path, (mode_t)perm); + } return Status::OK(); } diff --git a/src/worker.h b/src/worker.h index ad2f56301a8..14a278dc600 100644 --- a/src/worker.h +++ b/src/worker.h @@ -60,12 +60,16 @@ class Worker { uint64_t type, bool skipme, int64_t *killed); void KickoutIdleClients(int timeout); + Status ListenUnixSocket(const std::string &path, int perm, int backlog); + Server *svr_; private: - Status listen(const std::string &host, int port, int backlog); - static void newConnection(evconnlistener *listener, evutil_socket_t fd, - sockaddr *address, int socklen, void *ctx); + Status listenTCP(const std::string &host, int port, int backlog); + static void newTCPConnection(evconnlistener *listener, evutil_socket_t fd, + sockaddr *address, int socklen, void *ctx); + static void newUnixSocketConnection(evconnlistener *listener, evutil_socket_t fd, + sockaddr *address, int socklen, void *ctx); static void TimerCB(int, int16_t events, void *ctx); Redis::Connection *removeConnection(int fd);