diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 2cc56919647..c7f742656d0 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -1364,7 +1364,7 @@ REDIS_REGISTER_COMMANDS(Server, MakeCmdAttr("auth", 2, "read-only o MakeCmdAttr("slaveof", 3, "read-only exclusive no-script", NO_KEY), MakeCmdAttr("stats", 1, "read-only", NO_KEY), MakeCmdAttr("rdb", -3, "write exclusive", NO_KEY), - MakeCmdAttr("reset", 1, "ok-loading multi no-script", NO_KEY), + MakeCmdAttr("reset", 1, "ok-loading bypass-multi no-script", NO_KEY), MakeCmdAttr("applybatch", -2, "write no-multi", NO_KEY), MakeCmdAttr("dump", 2, "read-only", 1, 1, 1), MakeCmdAttr("pollupdates", -2, "read-only", NO_KEY), ) diff --git a/src/commands/cmd_txn.cc b/src/commands/cmd_txn.cc index 3d88d9a2ec4..5f922ddf5ef 100644 --- a/src/commands/cmd_txn.cc +++ b/src/commands/cmd_txn.cc @@ -98,10 +98,6 @@ class CommandExec : public Commander { class CommandWatch : public Commander { public: Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { - if (conn->IsFlagEnabled(Connection::kMultiExec)) { - return {Status::RedisExecErr, "WATCH inside MULTI is not allowed"}; - } - // If a conn is already marked as watched_keys_modified, we can skip the watch. if (srv->IsWatchedKeysModified(conn)) { *output = redis::RESP_OK; @@ -123,10 +119,10 @@ class CommandUnwatch : public Commander { } }; -REDIS_REGISTER_COMMANDS(Txn, MakeCmdAttr("multi", 1, "multi", NO_KEY), - MakeCmdAttr("discard", 1, "multi", NO_KEY), - MakeCmdAttr("exec", 1, "exclusive multi slow", NO_KEY), - MakeCmdAttr("watch", -2, "multi", 1, -1, 1), - MakeCmdAttr("unwatch", 1, "multi", NO_KEY), ) +REDIS_REGISTER_COMMANDS(Txn, MakeCmdAttr("multi", 1, "bypass-multi", NO_KEY), + MakeCmdAttr("discard", 1, "bypass-multi", NO_KEY), + MakeCmdAttr("exec", 1, "exclusive bypass-multi slow", NO_KEY), + MakeCmdAttr("watch", -2, "no-multi", 1, -1, 1), + MakeCmdAttr("unwatch", 1, "no-multi", NO_KEY), ) } // namespace redis diff --git a/src/commands/commander.h b/src/commands/commander.h index b2a6ae3d330..f0589c92c42 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -64,8 +64,9 @@ enum CommandFlags : uint64_t { // "ok-loading" flag, for any command that can be executed while // the db is in loading phase kCmdLoading = 1ULL << 5, - // "multi" flag, for commands that can end a MULTI scope - kCmdEndMulti = 1ULL << 6, + // "bypass-multi" flag, for commands that can be executed in a MULTI scope, + // but these commands will NOT be queued and will be executed immediately + kCmdBypassMulti = 1ULL << 6, // "exclusive" flag, for commands that should be executed execlusive globally kCmdExclusive = 1ULL << 7, // "no-multi" flag, for commands that cannot be executed in MULTI scope @@ -320,8 +321,8 @@ inline uint64_t ParseCommandFlags(const std::string &description, const std::str flags |= kCmdLoading; else if (flag == "exclusive") flags |= kCmdExclusive; - else if (flag == "multi") - flags |= kCmdEndMulti; + else if (flag == "bypass-multi") + flags |= kCmdBypassMulti; else if (flag == "no-multi") flags |= kCmdNoMulti; else if (flag == "no-script") diff --git a/src/common/string_util.cc b/src/common/string_util.cc index cce6440227a..c1e6f7e3ffe 100644 --- a/src/common/string_util.cc +++ b/src/common/string_util.cc @@ -42,6 +42,11 @@ std::string ToLower(std::string in) { return in; } +std::string ToUpper(std::string in) { + std::transform(in.begin(), in.end(), in.begin(), [](char c) -> char { return static_cast(std::toupper(c)); }); + return in; +} + bool EqualICase(std::string_view lhs, std::string_view rhs) { return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char l, char r) { return std::tolower(l) == std::tolower(r); }); diff --git a/src/common/string_util.h b/src/common/string_util.h index f86590ad046..619d95d5d09 100644 --- a/src/common/string_util.h +++ b/src/common/string_util.h @@ -32,6 +32,7 @@ namespace util { std::string Float2String(double d); std::string ToLower(std::string in); +std::string ToUpper(std::string in); bool EqualICase(std::string_view lhs, std::string_view rhs); std::string BytesToHuman(uint64_t n); std::string Trim(std::string in, std::string_view chars); diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index ca00de1684b..2e4ed68c552 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -413,7 +413,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { // that can guarantee other threads can't come into critical zone, such as DEBUG, // CLUSTER subcommand, CONFIG SET, MULTI, LUA (in the immediate future). // Otherwise, we just use 'ConcurrencyGuard' to allow all workers to execute commands at the same time. - if (is_multi_exec && cmd_name != "exec") { + if (is_multi_exec && !(cmd_flags & kCmdBypassMulti)) { // No lock guard, because 'exec' command has acquired 'WorkExclusivityGuard' } else if (cmd_flags & kCmdExclusive) { exclusivity = srv_->WorkExclusivityGuard(); @@ -443,7 +443,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { } if (is_multi_exec && (cmd_flags & kCmdNoMulti)) { - Reply(redis::Error({Status::NotOK, "Can't execute " + cmd_name + " in MULTI"})); + Reply(redis::Error({Status::NotOK, fmt::format("{} inside MULTI is not allowed", util::ToUpper(cmd_name))})); multi_error_ = true; continue; } @@ -463,7 +463,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { } // We don't execute commands, but queue them, and then execute in EXEC command - if (is_multi_exec && !in_exec_ && !(cmd_flags & kCmdEndMulti)) { + if (is_multi_exec && !in_exec_ && !(cmd_flags & kCmdBypassMulti)) { multi_cmds_.emplace_back(std::move(cmd_tokens)); Reply(redis::SimpleString("QUEUED")); continue; diff --git a/tests/gocase/unit/multi/multi_test.go b/tests/gocase/unit/multi/multi_test.go index c2a96917b41..084eb805ba7 100644 --- a/tests/gocase/unit/multi/multi_test.go +++ b/tests/gocase/unit/multi/multi_test.go @@ -174,7 +174,7 @@ func TestMulti(t *testing.T) { t.Run("WATCH inside MULTI is not allowed", func(t *testing.T) { require.NoError(t, rdb.Do(ctx, "MULTI").Err()) require.EqualError(t, rdb.Do(ctx, "WATCH", "x").Err(), "ERR WATCH inside MULTI is not allowed") - require.NoError(t, rdb.Do(ctx, "EXEC").Err()) + require.NoError(t, rdb.Do(ctx, "DISCARD").Err()) }) t.Run("EXEC without MULTI is not allowed", func(t *testing.T) {