diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index e97ead7a12e..79c188622b0 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -966,24 +966,55 @@ class CommandZRevRangeByScore : public CommandZRangeGeneric { class CommandZRank : public Commander { public: explicit CommandZRank(bool reversed = false) : reversed_(reversed) {} + + Status Parse(const std::vector &args) override { + if (args.size() > 4) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + + // skip the and parse remaining optional arguments + CommandParser parser(args, 3); + while (parser.Good()) { + if (parser.EatEqICase("withscore") && !with_score_) { + with_score_ = true; + } else { + return parser.InvalidSyntax(); + } + } + + return Commander::Parse(args); + } + Status Execute(Server *svr, Connection *conn, std::string *output) override { int rank = 0; + double score = 0.0; redis::ZSet zset_db(svr->storage, conn->GetNamespace()); - auto s = zset_db.Rank(args_[1], args_[2], reversed_, &rank); + auto s = zset_db.Rank(args_[1], args_[2], reversed_, &rank, &score); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } if (rank == -1) { - *output = redis::NilString(); + if (with_score_) { + output->append(redis::MultiLen(-1)); + } else { + *output = redis::NilString(); + } } else { - *output = redis::Integer(rank); + if (with_score_) { + output->append(redis::MultiLen(2)); + output->append(redis::Integer(rank)); + output->append(redis::BulkString(util::Float2String(score))); + } else { + *output = redis::Integer(rank); + } } return Status::OK(); } private: bool reversed_; + bool with_score_ = false; }; class CommandZRevRank : public CommandZRank { @@ -1361,13 +1392,13 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zrangebylex", -4, "read-only", 1, 1, 1), MakeCmdAttr("zrevrangebylex", -4, "read-only", 1, 1, 1), MakeCmdAttr("zrangebyscore", -4, "read-only", 1, 1, 1), - MakeCmdAttr("zrank", 3, "read-only", 1, 1, 1), + MakeCmdAttr("zrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("zrem", -3, "write", 1, 1, 1), MakeCmdAttr("zremrangebyrank", 4, "write", 1, 1, 1), MakeCmdAttr("zremrangebyscore", 4, "write", 1, 1, 1), MakeCmdAttr("zremrangebylex", 4, "write", 1, 1, 1), MakeCmdAttr("zrevrangebyscore", -4, "read-only", 1, 1, 1), - MakeCmdAttr("zrevrank", 3, "read-only", 1, 1, 1), + MakeCmdAttr("zrevrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("zscore", 3, "read-only", 1, 1, 1), MakeCmdAttr("zmscore", -3, "read-only", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index bc0515ef7a6..e7a1585d7e2 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -566,8 +566,10 @@ rocksdb::Status ZSet::Remove(const Slice &user_key, const std::vector &me return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -rocksdb::Status ZSet::Rank(const Slice &user_key, const Slice &member, bool reversed, int *member_rank) { +rocksdb::Status ZSet::Rank(const Slice &user_key, const Slice &member, bool reversed, int *member_rank, + double *member_score) { *member_rank = -1; + *member_score = 0.0; std::string ns_key; AppendNamespacePrefix(user_key, &ns_key); @@ -613,6 +615,7 @@ rocksdb::Status ZSet::Rank(const Slice &user_key, const Slice &member, bool reve } *member_rank = rank; + *member_score = target_score; return rocksdb::Status::OK(); } diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 8722b35186f..1ea0b192f06 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -98,7 +98,8 @@ class ZSet : public SubKeyScanner { rocksdb::Status Add(const Slice &user_key, ZAddFlags flags, MemberScores *mscores, uint64_t *added_cnt); rocksdb::Status Card(const Slice &user_key, uint64_t *size); rocksdb::Status IncrBy(const Slice &user_key, const Slice &member, double increment, double *score); - rocksdb::Status Rank(const Slice &user_key, const Slice &member, bool reversed, int *member_rank); + rocksdb::Status Rank(const Slice &user_key, const Slice &member, bool reversed, int *member_rank, + double *member_score); rocksdb::Status Remove(const Slice &user_key, const std::vector &members, uint64_t *removed_cnt); rocksdb::Status Pop(const Slice &user_key, int count, bool min, MemberScores *mscores); rocksdb::Status Score(const Slice &user_key, const Slice &member, double *score); diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc index f7e6e27ac53..1db9aec6f5a 100644 --- a/tests/cppunit/types/zset_test.cc +++ b/tests/cppunit/types/zset_test.cc @@ -382,19 +382,25 @@ TEST_F(RedisZSetTest, Rank) { for (size_t i = 0; i < fields_.size(); i++) { int rank = 0; - zset_->Rank(key_, fields_[i], false, &rank); + double score = 0.0; + zset_->Rank(key_, fields_[i], false, &rank, &score); EXPECT_EQ(i, rank); + EXPECT_EQ(scores_[i], score); } for (size_t i = 0; i < fields_.size(); i++) { int rank = 0; - zset_->Rank(key_, fields_[i], true, &rank); + double score = 0.0; + zset_->Rank(key_, fields_[i], true, &rank, &score); EXPECT_EQ(i, static_cast(fields_.size() - rank - 1)); + EXPECT_EQ(scores_[i], score); } std::vector no_exist_members = {"a", "b"}; for (const auto &member : no_exist_members) { int rank = 0; - zset_->Rank(key_, member, true, &rank); + double score = 0.0; + zset_->Rank(key_, member, true, &rank, &score); EXPECT_EQ(-1, rank); + EXPECT_EQ(0.0, score); } zset_->Del(key_); } diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 1d7413647ad..d5822221e6a 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -674,20 +674,42 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s rdb.ZAdd(ctx, "zranktmp", redis.Z{Score: 10, Member: "x"}) rdb.ZAdd(ctx, "zranktmp", redis.Z{Score: 20, Member: "y"}) rdb.ZAdd(ctx, "zranktmp", redis.Z{Score: 30, Member: "z"}) + require.Equal(t, int64(0), rdb.ZRank(ctx, "zranktmp", "x").Val()) require.Equal(t, int64(1), rdb.ZRank(ctx, "zranktmp", "y").Val()) require.Equal(t, int64(2), rdb.ZRank(ctx, "zranktmp", "z").Val()) - require.Equal(t, int64(0), rdb.ZRank(ctx, "zranktmp", "foo").Val()) + require.Equal(t, redis.Nil, rdb.ZRank(ctx, "zranktmp", "foo").Err()) require.Equal(t, int64(2), rdb.ZRevRank(ctx, "zranktmp", "x").Val()) require.Equal(t, int64(1), rdb.ZRevRank(ctx, "zranktmp", "y").Val()) require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", "z").Val()) - require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", "foo").Val()) + require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", "foo").Err()) + + require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, "zrank", "zranktmp", "x", "withscore").Val()) + require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, "zrank", "zranktmp", "y", "withscore").Val()) + require.Equal(t, []interface{}{int64(2), "30"}, rdb.Do(ctx, "zrank", "zranktmp", "z", "withscore").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", "foo", "withscore").Err()) + require.Equal(t, []interface{}{int64(2), "10"}, rdb.Do(ctx, "zrevrank", "zranktmp", "x", "withscore").Val()) + require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, "zrevrank", "zranktmp", "y", "withscore").Val()) + require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, "zrevrank", "zranktmp", "z", "withscore").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", "foo", "withscore").Err()) }) - t.Run(fmt.Sprintf("ZRANK - after deletion -%s", encoding), func(t *testing.T) { + t.Run(fmt.Sprintf("ZRANK/ZREVRANK - after deletion -%s", encoding), func(t *testing.T) { rdb.ZRem(ctx, "zranktmp", "y") + require.Equal(t, int64(0), rdb.ZRank(ctx, "zranktmp", "x").Val()) require.Equal(t, int64(1), rdb.ZRank(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRank(ctx, "zranktmp", "foo").Err()) + require.Equal(t, int64(1), rdb.ZRevRank(ctx, "zranktmp", "x").Val()) + require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", "foo").Err()) + + require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, "zrank", "zranktmp", "x", "withscore").Val()) + require.Equal(t, []interface{}{int64(1), "30"}, rdb.Do(ctx, "zrank", "zranktmp", "z", "withscore").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", "foo", "withscore").Err()) + require.Equal(t, []interface{}{int64(1), "10"}, rdb.Do(ctx, "zrevrank", "zranktmp", "x", "withscore").Val()) + require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, "zrevrank", "zranktmp", "z", "withscore").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", "foo", "withscore").Err()) }) t.Run(fmt.Sprintf("ZINCRBY - can create a new sorted set - %s", encoding), func(t *testing.T) {