diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc index 45272eef2ca..6ec61eea5d3 100644 --- a/src/commands/cmd_pubsub.cc +++ b/src/commands/cmd_pubsub.cc @@ -138,6 +138,44 @@ class CommandPUnSubscribe : public Commander { } }; +class CommandSSubscribe : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + uint16_t slot = 0; + if (srv->GetConfig()->cluster_enabled) { + slot = GetSlotIdFromKey(args_[1]); + for (unsigned int i = 2; i < args_.size(); i++) { + if (GetSlotIdFromKey(args_[i]) != slot) { + return {Status::RedisExecErr, "CROSSSLOT Keys in request don't hash to the same slot"}; + } + } + } + + for (unsigned int i = 1; i < args_.size(); i++) { + conn->SSubscribeChannel(args_[i], slot); + SubscribeCommandReply(output, "ssubscribe", args_[i], conn->SSubscriptionsCount()); + } + return Status::OK(); + } +}; + +class CommandSUnSubscribe : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + if (args_.size() == 1) { + conn->SUnsubscribeAll([output](const std::string &sub_name, int num) { + SubscribeCommandReply(output, "sunsubscribe", sub_name, num); + }); + } else { + for (size_t i = 1; i < args_.size(); i++) { + conn->SUnsubscribeChannel(args_[i], srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(args_[i]) : 0); + SubscribeCommandReply(output, "sunsubscribe", args_[i], conn->SSubscriptionsCount()); + } + } + return Status::OK(); + } +}; + class CommandPubSub : public Commander { public: Status Parse(const std::vector &args) override { @@ -146,14 +184,14 @@ class CommandPubSub : public Commander { return Status::OK(); } - if ((subcommand_ == "numsub") && args.size() >= 2) { + if ((subcommand_ == "numsub" || subcommand_ == "shardnumsub") && args.size() >= 2) { if (args.size() > 2) { channels_ = std::vector(args.begin() + 2, args.end()); } return Status::OK(); } - if ((subcommand_ == "channels") && args.size() <= 3) { + if ((subcommand_ == "channels" || subcommand_ == "shardchannels") && args.size() <= 3) { if (args.size() == 3) { pattern_ = args[2]; } @@ -169,9 +207,13 @@ class CommandPubSub : public Commander { return Status::OK(); } - if (subcommand_ == "numsub") { + if (subcommand_ == "numsub" || subcommand_ == "shardnumsub") { std::vector channel_subscribe_nums; - srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums); + if (subcommand_ == "numsub") { + srv->ListChannelSubscribeNum(channels_, &channel_subscribe_nums); + } else { + srv->ListSChannelSubscribeNum(channels_, &channel_subscribe_nums); + } output->append(redis::MultiLen(channel_subscribe_nums.size() * 2)); for (const auto &chan_subscribe_num : channel_subscribe_nums) { @@ -182,9 +224,13 @@ class CommandPubSub : public Commander { return Status::OK(); } - if (subcommand_ == "channels") { + if (subcommand_ == "channels" || subcommand_ == "shardchannels") { std::vector channels; - srv->GetChannelsByPattern(pattern_, &channels); + if (subcommand_ == "channels") { + srv->GetChannelsByPattern(pattern_, &channels); + } else { + srv->GetSChannelsByPattern(pattern_, &channels); + } *output = redis::MultiBulkString(channels); return Status::OK(); } @@ -205,6 +251,8 @@ REDIS_REGISTER_COMMANDS( MakeCmdAttr("unsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), MakeCmdAttr("psubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0), MakeCmdAttr("punsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), + MakeCmdAttr("ssubscribe", -2, "read-only pub-sub no-multi no-script", 0, 0, 0), + MakeCmdAttr("sunsubscribe", -1, "read-only pub-sub no-multi no-script", 0, 0, 0), MakeCmdAttr("pubsub", -2, "read-only pub-sub no-script", 0, 0, 0), ) } // namespace redis diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index a41094d1bd8..d94d81e73ec 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -1160,7 +1160,7 @@ class CommandAnalyze : public Commander { public: Status Parse(const std::vector &args) override { if (args.size() <= 1) return {Status::RedisExecErr, errInvalidSyntax}; - for (int i = 1; i < args.size(); ++i) { + for (unsigned int i = 1; i < args.size(); ++i) { command_args_.push_back(args[i]); } return Status::OK(); @@ -1178,7 +1178,8 @@ class CommandAnalyze : public Commander { cmd->SetArgs(command_args_); int arity = cmd->GetAttributes()->arity; - if ((arity > 0 && command_args_.size() != arity) || (arity < 0 && command_args_.size() < -arity)) { + if ((arity > 0 && static_cast(command_args_.size()) != arity) || + (arity < 0 && static_cast(command_args_.size()) < -arity)) { *output = redis::Error("ERR wrong number of arguments"); return {Status::RedisExecErr, errWrongNumOfArguments}; } diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index ae80e950434..d6e0b5f6749 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -261,6 +261,45 @@ void Connection::PUnsubscribeAll(const UnsubscribeCallback &reply) { int Connection::PSubscriptionsCount() { return static_cast(subscribe_patterns_.size()); } +void Connection::SSubscribeChannel(const std::string &channel, uint16_t slot) { + for (const auto &chan : subscribe_shard_channels_) { + if (channel == chan) return; + } + + subscribe_shard_channels_.emplace_back(channel); + owner_->srv->SSubscribeChannel(channel, this, slot); +} + +void Connection::SUnsubscribeChannel(const std::string &channel, uint16_t slot) { + for (auto iter = subscribe_shard_channels_.begin(); iter != subscribe_shard_channels_.end(); iter++) { + if (*iter == channel) { + subscribe_shard_channels_.erase(iter); + owner_->srv->SUnsubscribeChannel(channel, this, slot); + return; + } + } +} + +void Connection::SUnsubscribeAll(const UnsubscribeCallback &reply) { + if (subscribe_shard_channels_.empty()) { + if (reply) reply("", 0); + return; + } + + int removed = 0; + for (const auto &chan : subscribe_shard_channels_) { + owner_->srv->SUnsubscribeChannel(chan, this, + owner_->srv->GetConfig()->cluster_enabled ? GetSlotIdFromKey(chan) : 0); + removed++; + if (reply) { + reply(chan, static_cast(subscribe_shard_channels_.size() - removed)); + } + } + subscribe_shard_channels_.clear(); +} + +int Connection::SSubscriptionsCount() { return static_cast(subscribe_shard_channels_.size()); } + bool Connection::IsProfilingEnabled(const std::string &cmd) { auto config = srv_->GetConfig(); if (config->profiling_sample_ratio == 0) return false; diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index 25b522d848a..34fbcbae9fa 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -74,6 +74,10 @@ class Connection : public EvbufCallbackBase { void PUnsubscribeChannel(const std::string &pattern); void PUnsubscribeAll(const UnsubscribeCallback &reply = nullptr); int PSubscriptionsCount(); + void SSubscribeChannel(const std::string &channel, uint16_t slot); + void SUnsubscribeChannel(const std::string &channel, uint16_t slot); + void SUnsubscribeAll(const UnsubscribeCallback &reply = nullptr); + int SSubscriptionsCount(); uint64_t GetAge() const; uint64_t GetIdleTime() const; @@ -159,6 +163,7 @@ class Connection : public EvbufCallbackBase { std::vector subscribe_channels_; std::vector subscribe_patterns_; + std::vector subscribe_shard_channels_; Server *srv_; bool in_exec_ = false; diff --git a/src/server/server.cc b/src/server/server.cc index f8f2fb94c22..efe721b27ba 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -78,6 +78,9 @@ Server::Server(engine::Storage *storage, Config *config) // Init cluster cluster = std::make_unique(this, config_->binds, config_->port); + // init shard pub/sub channels + pubsub_shard_channels_.resize(config->cluster_enabled ? HASH_SLOTS_SIZE : 1); + for (int i = 0; i < config->workers; i++) { auto worker = std::make_unique(this, config); // multiple workers can't listen to the same unix socket, so @@ -497,6 +500,64 @@ void Server::PUnsubscribeChannel(const std::string &pattern, redis::Connection * } } +void Server::SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) { + assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0); + std::lock_guard guard(pubsub_shard_channels_mu_); + + auto conn_ctx = ConnContext(conn->Owner(), conn->GetFD()); + if (auto iter = pubsub_shard_channels_[slot].find(channel); iter == pubsub_shard_channels_[slot].end()) { + pubsub_shard_channels_[slot].emplace(channel, std::list{conn_ctx}); + } else { + iter->second.emplace_back(conn_ctx); + } +} + +void Server::SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot) { + assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0); + std::lock_guard guard(pubsub_shard_channels_mu_); + + auto iter = pubsub_shard_channels_[slot].find(channel); + if (iter == pubsub_shard_channels_[slot].end()) { + return; + } + + for (const auto &conn_ctx : iter->second) { + if (conn->GetFD() == conn_ctx.fd && conn->Owner() == conn_ctx.owner) { + iter->second.remove(conn_ctx); + if (iter->second.empty()) { + pubsub_shard_channels_[slot].erase(iter); + } + break; + } + } +} + +void Server::GetSChannelsByPattern(const std::string &pattern, std::vector *channels) { + std::lock_guard guard(pubsub_shard_channels_mu_); + + for (const auto &shard_channels : pubsub_shard_channels_) { + for (const auto &iter : shard_channels) { + if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) { + channels->emplace_back(iter.first); + } + } + } +} + +void Server::ListSChannelSubscribeNum(const std::vector &channels, + std::vector *channel_subscribe_nums) { + std::lock_guard guard(pubsub_shard_channels_mu_); + + for (const auto &chan : channels) { + uint16_t slot = config_->cluster_enabled ? GetSlotIdFromKey(chan) : 0; + if (auto iter = pubsub_shard_channels_[slot].find(chan); iter != pubsub_shard_channels_[slot].end()) { + channel_subscribe_nums->emplace_back(ChannelSubscribeNum{iter->first, iter->second.size()}); + } else { + channel_subscribe_nums->emplace_back(ChannelSubscribeNum{chan, 0}); + } + } +} + void Server::BlockOnKey(const std::string &key, redis::Connection *conn) { std::lock_guard guard(blocking_keys_mu_); diff --git a/src/server/server.h b/src/server/server.h index 2acd0f5dbf1..a86eedf1cd8 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -201,6 +201,11 @@ class Server { void PSubscribeChannel(const std::string &pattern, redis::Connection *conn); void PUnsubscribeChannel(const std::string &pattern, redis::Connection *conn); size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); } + void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); + void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); + void GetSChannelsByPattern(const std::string &pattern, std::vector *channels); + void ListSChannelSubscribeNum(const std::vector &channels, + std::vector *channel_subscribe_nums); void BlockOnKey(const std::string &key, redis::Connection *conn); void UnblockOnKey(const std::string &key, redis::Connection *conn); @@ -351,6 +356,8 @@ class Server { std::map> pubsub_channels_; std::map> pubsub_patterns_; std::mutex pubsub_channels_mu_; + std::vector>> pubsub_shard_channels_; + std::mutex pubsub_shard_channels_mu_; std::map> blocking_keys_; std::mutex blocking_keys_mu_; diff --git a/tests/gocase/unit/pubsub/pubsubshard_test.go b/tests/gocase/unit/pubsub/pubsubshard_test.go new file mode 100644 index 00000000000..9e8b04cf79d --- /dev/null +++ b/tests/gocase/unit/pubsub/pubsubshard_test.go @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package pubsub + +import ( + "context" + "fmt" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestPubSubShard(t *testing.T) { + ctx := context.Background() + + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer csrv.Close() + crdb := csrv.NewClient() + defer func() { require.NoError(t, crdb.Close()) }() + + nodeID := "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY" + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err()) + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID, csrv.Host(), csrv.Port()) + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + rdbs := []*redis.Client{rdb, crdb} + + t.Run("SSUBSCRIBE PING", func(t *testing.T) { + pubsub := rdb.SSubscribe(ctx, "somechannel") + receiveType(t, pubsub, &redis.Subscription{}) + require.NoError(t, pubsub.Ping(ctx)) + require.NoError(t, pubsub.Ping(ctx)) + require.NoError(t, pubsub.SUnsubscribe(ctx, "somechannel")) + require.Equal(t, "PONG", rdb.Ping(ctx).Val()) + receiveType(t, pubsub, &redis.Pong{}) + receiveType(t, pubsub, &redis.Pong{}) + }) + + t.Run("SSUBSCRIBE/SUNSUBSCRIBE basic", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx, "singlechannel") + defer pubsub.Close() + + msg := receiveType(t, pubsub, &redis.Subscription{}) + require.EqualValues(t, 1, msg.Count) + require.EqualValues(t, "singlechannel", msg.Channel) + require.EqualValues(t, "ssubscribe", msg.Kind) + + err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel1{tag1}") + require.Nil(t, err) + require.EqualValues(t, 2, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + + err = pubsub.SSubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}") + require.Nil(t, err) + if c == rdb { + require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 5, receiveType(t, pubsub, &redis.Subscription{}).Count) + } else { + // note: when cluster enabled, shard channels in single command must belong to the same slot + // reference: https://redis.io/commands/ssubscribe + _, err = pubsub.Receive(ctx) + require.EqualError(t, err, "ERR CROSSSLOT Keys in request don't hash to the same slot") + } + + err = pubsub.SUnsubscribe(ctx, "multichannel3{tag1}", "multichannel4{tag2}", "multichannel5{tag2}") + require.Nil(t, err) + if c == rdb { + require.EqualValues(t, 4, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + } else { + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 3, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + + err = pubsub.SUnsubscribe(ctx) + require.Nil(t, err) + msg = receiveType(t, pubsub, &redis.Subscription{}) + require.EqualValues(t, 2, msg.Count) + require.EqualValues(t, "sunsubscribe", msg.Kind) + require.EqualValues(t, 1, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + }) + + t.Run("SSUBSCRIBE/SUNSUBSCRIBE with empty channel", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx) + defer pubsub.Close() + + err := pubsub.SUnsubscribe(ctx, "foo", "bar") + require.Nil(t, err) + require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) + require.EqualValues(t, 0, receiveType(t, pubsub, &redis.Subscription{}).Count) + } + }) + + t.Run("SHARDNUMSUB returns numbers, not strings", func(t *testing.T) { + require.EqualValues(t, map[string]int64{ + "abc": 0, + "def": 0, + }, rdb.PubSubShardNumSub(ctx, "abc", "def").Val()) + }) + + t.Run("PUBSUB SHARDNUMSUB/SHARDCHANNELS", func(t *testing.T) { + for _, c := range rdbs { + pubsub := c.SSubscribe(ctx, "singlechannel") + defer pubsub.Close() + receiveType(t, pubsub, &redis.Subscription{}) + + err := pubsub.SSubscribe(ctx, "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}") + require.Nil(t, err) + receiveType(t, pubsub, &redis.Subscription{}) + receiveType(t, pubsub, &redis.Subscription{}) + receiveType(t, pubsub, &redis.Subscription{}) + + pubsub1 := c.SSubscribe(ctx, "multichannel1{tag1}") + defer pubsub1.Close() + + sc := c.PubSubShardChannels(ctx, "") + require.EqualValues(t, len(sc.Val()), 4) + sc = c.PubSubShardChannels(ctx, "multi*") + require.EqualValues(t, len(sc.Val()), 3) + + sn := c.PubSubShardNumSub(ctx) + require.EqualValues(t, len(sn.Val()), 0) + sn = c.PubSubShardNumSub(ctx, "singlechannel", "multichannel1{tag1}", "multichannel2{tag1}", "multichannel3{tag1}") + for i, k := range sn.Val() { + if i == "multichannel1{tag1}" { + require.EqualValues(t, k, 2) + } else { + require.EqualValues(t, k, 1) + } + } + } + }) +}