diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index 35a6a6ca840..a226d8702c0 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -470,11 +470,11 @@ class CommandZMPop : public Commander { } while (parser.Good()) { - if (parser.EatEqICase("min")) { + if (flag_ == ZSET_NONE && parser.EatEqICase("min")) { flag_ = ZSET_MIN; - } else if (parser.EatEqICase("max")) { + } else if (flag_ == ZSET_NONE && parser.EatEqICase("max")) { flag_ = ZSET_MAX; - } else if (parser.EatEqICase("count")) { + } else if (count_ == 0 && parser.EatEqICase("count")) { count_ = GET_OR_RET(parser.TakeInt(NumericRange{1, std::numeric_limits::max()})); } else { return parser.InvalidSyntax(); @@ -483,6 +483,7 @@ class CommandZMPop : public Commander { if (flag_ == ZSET_NONE) { return parser.InvalidSyntax(); } + if (count_ == 0) count_ = 1; return Commander::Parse(args); } @@ -514,7 +515,7 @@ class CommandZMPop : public Commander { int numkeys_; std::vector keys_; enum { ZSET_MIN, ZSET_MAX, ZSET_NONE } flag_ = ZSET_NONE; - int count_ = 1; + int count_ = 0; }; class CommandBZMPop : public Commander, @@ -535,11 +536,11 @@ class CommandBZMPop : public Commander, } while (parser.Good()) { - if (parser.EatEqICase("min")) { + if (flag_ == ZSET_NONE && parser.EatEqICase("min")) { flag_ = ZSET_MIN; - } else if (parser.EatEqICase("max")) { + } else if (flag_ == ZSET_NONE && parser.EatEqICase("max")) { flag_ = ZSET_MAX; - } else if (parser.EatEqICase("count")) { + } else if (count_ == 0 && parser.EatEqICase("count")) { count_ = GET_OR_RET(parser.TakeInt(NumericRange{1, std::numeric_limits::max()})); } else { return parser.InvalidSyntax(); @@ -549,6 +550,7 @@ class CommandBZMPop : public Commander, if (flag_ == ZSET_NONE) { return parser.InvalidSyntax(); } + if (count_ == 0) count_ = 1; return Commander::Parse(args); } @@ -659,7 +661,7 @@ class CommandBZMPop : public Commander, int num_keys_; std::vector keys_; enum { ZSET_MIN, ZSET_MAX, ZSET_NONE } flag_ = ZSET_NONE; - int count_ = 1; + int count_ = 0; Server *svr_ = nullptr; Connection *conn_ = nullptr; UniqueEvent timer_; diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index a6b150c770b..1b5a2980956 100644 --- a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -392,6 +392,19 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.EqualValues(t, 0, rdb.Exists(ctx, "zseta", "zsetb").Val()) }) + t.Run(fmt.Sprintf("ZMPOP error - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 1, "zseta").Err(), ".*wrong number of arguments.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", "wrong_numkeys", "zseta", "zsetb").Err(), ".*not started as an integer.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 2, "zseta", "min").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 2, "zseta", "zsetb", "min", "min").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 1, "zseta", "min", "max").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 1, "zseta", "min", "count", "wrong_count").Err(), ".*not started as an integer.*") + util.ErrorRegexp(t, rdb.Do(ctx, "zmpop", 1, "zseta", "min", "count", 1, "count", 10).Err(), ".*syntax error.*") + }) + t.Run(fmt.Sprintf("BZMPOP basics - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "zseta") rdb.Del(ctx, "zsetb") @@ -427,6 +440,20 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.Equal(t, []redis.Z{{Score: 1, Member: "a"}, {Score: 2, Member: "b"}}, zset) }) + t.Run(fmt.Sprintf("BZMPOP error - %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 1, "zseta").Err(), ".*wrong number of arguments.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", "wrong_timeout", 1, "zseta", "min").Err(), ".*not started as a number.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, "wrong_numkeys", "zseta", "min").Err(), ".*not started as an integer.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 2, "zseta", "min").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 1, "zseta", "min", "max").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 2, "zseta", "zsetb", "min", "min").Err(), ".*syntax error.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 1, "zseta", "min", "count", "wrong_count").Err(), ".*not started as an integer.*") + util.ErrorRegexp(t, rdb.Do(ctx, "bzmpop", 0.1, 1, "zseta", "min", "count", 1, "count", 10).Err(), ".*syntax error.*") + }) + t.Run(fmt.Sprintf("ZRANGESTORE basics - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "zsrc") rdb.Del(ctx, "zdst")