From 6747b8da7a7280a39b4f85e498249dc51f6322d4 Mon Sep 17 00:00:00 2001 From: Twice Date: Thu, 14 Nov 2024 16:43:15 +0800 Subject: [PATCH] fix(cmd): args should be parsed before retrieving keys in COMMAND GETKEYS (#2661) --- src/cluster/cluster.cc | 8 +++++--- src/commands/commander.cc | 5 +++++ src/commands/commander.h | 2 ++ src/storage/scripting.cc | 14 +++++++------- tests/gocase/unit/command/command_test.go | 16 ++++++++-------- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index 3850e12fc29..b04fb5890ec 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -835,9 +835,11 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector key_indexes; - auto s = redis::CommandTable::GetKeysFromCommand(attributes, cmd_tokens); - if (!s) return Status::OK(); - key_indexes = *s; + attributes->ForEachKeyRange( + [&](const std::vector &, redis::CommandKeyRange key_range) { + key_range.ForEachKeyIndex([&](int i) { key_indexes.push_back(i); }, cmd_tokens.size()); + }, + cmd_tokens); if (key_indexes.empty()) return Status::OK(); diff --git a/src/commands/commander.cc b/src/commands/commander.cc index 063b1c469d0..da32800be94 100644 --- a/src/commands/commander.cc +++ b/src/commands/commander.cc @@ -88,6 +88,11 @@ StatusOr> CommandTable::GetKeysFromCommand(const CommandAttribu return {Status::NotOK, "Invalid number of arguments specified for command"}; } + auto cmd = attributes->factory(); + if (auto s = cmd->Parse(cmd_tokens); !s) { + return {Status::NotOK, "Invalid syntax found in this command arguments: " + s.Msg()}; + } + Status status; std::vector key_indexes; diff --git a/src/commands/commander.h b/src/commands/commander.h index b1d0a04975f..ac3d3aa9939 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -254,6 +254,8 @@ struct CommandAttributes { return {Status::NotOK, "key range is unavailable without command arguments"}; } + // the command arguments must be parsed and in valid syntax + // before this method is called, otherwise the behavior is UNDEFINED template void ForEachKeyRange(F &&f, const std::vector &args, G &&g) const { if (key_range_.first_key > 0) { diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 5768aee8169..9676b439952 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -778,6 +778,13 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { auto *srv = conn->GetServer(); Config *config = srv->GetConfig(); + cmd->SetArgs(args); + auto s = cmd->Parse(); + if (!s) { + PushError(lua, s.Msg().data()); + return raise_error ? RaiseError(lua) : 1; + } + if (config->cluster_enabled) { if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) { PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set"); @@ -807,13 +814,6 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { return raise_error ? RaiseError(lua) : 1; } - cmd->SetArgs(args); - auto s = cmd->Parse(); - if (!s) { - PushError(lua, s.Msg().data()); - return raise_error ? RaiseError(lua) : 1; - } - std::string output; // TODO: make it possible for multiple redis commands in lua script to use the same txn context. { diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go index 18fb2d3df1e..d6f233fa520 100644 --- a/tests/gocase/unit/command/command_test.go +++ b/tests/gocase/unit/command/command_test.go @@ -180,7 +180,7 @@ func TestCommand(t *testing.T) { }) t.Run("COMMAND GETKEYS ZMPOP", func(t *testing.T) { - r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZMPOP", "2", "key1", "key2") + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZMPOP", "2", "key1", "key2", "min") vs, err := r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -189,7 +189,7 @@ func TestCommand(t *testing.T) { }) t.Run("COMMAND GETKEYS BZMPOP", func(t *testing.T) { - r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BZMPOP", "0", "2", "key1", "key2") + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BZMPOP", "0", "2", "key1", "key2", "min") vs, err := r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -198,7 +198,7 @@ func TestCommand(t *testing.T) { }) t.Run("COMMAND GETKEYS LMPOP", func(t *testing.T) { - r := rdb.Do(ctx, "COMMAND", "GETKEYS", "LMPOP", "2", "key1", "key2") + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "LMPOP", "2", "key1", "key2", "left") vs, err := r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -207,7 +207,7 @@ func TestCommand(t *testing.T) { }) t.Run("COMMAND GETKEYS BLMPOP", func(t *testing.T) { - r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BLMPOP", "0", "2", "key1", "key2") + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BLMPOP", "0", "2", "key1", "key2", "left") vs, err := r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -250,14 +250,14 @@ func TestCommand(t *testing.T) { t.Run("COMMAND GETKEYS GEORADIUSBYMEMBER", func(t *testing.T) { // non-store - r := rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m") + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "100", "m") vs, err := r.Slice() require.NoError(t, err) require.Len(t, vs, 1) require.Equal(t, "src", vs[0]) // store - r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", "store", "dst") + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "100", "m", "store", "dst") vs, err = r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -265,7 +265,7 @@ func TestCommand(t *testing.T) { require.Equal(t, "dst", vs[1]) // storedist - r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", "storedist", "dst") + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "100", "m", "storedist", "dst") vs, err = r.Slice() require.NoError(t, err) require.Len(t, vs, 2) @@ -273,7 +273,7 @@ func TestCommand(t *testing.T) { require.Equal(t, "dst", vs[1]) // store + storedist - r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", "store", "dst1", "storedist", "dst2") + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "100", "m", "store", "dst1", "storedist", "dst2") vs, err = r.Slice() require.NoError(t, err) require.Len(t, vs, 2)