Skip to content

Commit

Permalink
fix(script): avoid SetCurrentConnection on read-only scriting
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice committed Nov 2, 2024
1 parent 7a3cc8c commit 9b0705e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 38 deletions.
12 changes: 0 additions & 12 deletions src/server/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,22 +417,10 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
// No lock guard, because 'exec' command has acquired 'WorkExclusivityGuard'
} else if (cmd_flags & kCmdExclusive) {
exclusivity = srv_->WorkExclusivityGuard();

// When executing lua script commands that have "exclusive" attribute, we need to know current connection,
// but we should set current connection after acquiring the WorkExclusivityGuard to make it thread-safe
srv_->SetCurrentConnection(this);
} else {
concurrency = srv_->WorkConcurrencyGuard();
}

auto category = attributes->category;
if ((category == CommandCategory::Function || category == CommandCategory::Script) && (cmd_flags & kCmdReadOnly)) {
// FIXME: since read-only script commands are not exclusive,
// SetCurrentConnection here is weird and can cause many issues,
// we should pass the Connection directly to the lua context instead
srv_->SetCurrentConnection(this);
}

if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) {
Reply(redis::Error({Status::RedisLoading, errRestoringBackup}));
if (is_multi_exec) multi_error_ = true;
Expand Down
4 changes: 2 additions & 2 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Server::Server(engine::Storage *storage, Config *config)
AdjustOpenFilesLimit();
slow_log_.SetMaxEntries(config->slowlog_max_len);
perf_log_.SetMaxEntries(config->profiling_sample_record_max_len);
lua_ = lua::CreateState(this);
lua_ = lua::CreateState();
}

Server::~Server() {
Expand Down Expand Up @@ -1764,7 +1764,7 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c
}

void Server::ScriptReset() {
auto lua = lua_.exchange(lua::CreateState(this));
auto lua = lua_.exchange(lua::CreateState());
lua::DestroyState(lua);
}

Expand Down
3 changes: 0 additions & 3 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,6 @@ class Server {
Status ExecPropagatedCommand(const std::vector<std::string> &tokens);
Status ExecPropagateScriptCommand(const std::vector<std::string> &tokens);

void SetCurrentConnection(redis::Connection *conn) { curr_connection_ = conn; }
redis::Connection *GetCurrentConnection() { return curr_connection_; }

LogCollector<PerfEntry> *GetPerfLog() { return &perf_log_; }
LogCollector<SlowEntry> *GetSlowLog() { return &slow_log_; }
void SlowlogPushEntryIfNeeded(const std::vector<std::string> *args, uint64_t duration, const redis::Connection *conn);
Expand Down
2 changes: 1 addition & 1 deletion src/server/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), base_(event_base_new())
}
}
}
lua_ = lua::CreateState(srv);
lua_ = lua::CreateState();
}

Worker::~Worker() {
Expand Down
30 changes: 13 additions & 17 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,12 @@ enum {

namespace lua {

lua_State *CreateState(Server *srv) {
lua_State *CreateState() {
lua_State *lua = lua_open();
LoadLibraries(lua);
RemoveUnsupportedFunctions(lua);
LoadFuncs(lua);

lua_pushlightuserdata(lua, srv);
lua_setglobal(lua, REDIS_LUA_SERVER_PTR);

EnableGlobalsProtection(lua);
return lua;
}
Expand Down Expand Up @@ -273,7 +270,10 @@ int RedisRegisterFunction(lua_State *lua) {
}

// store the map from function name to library name
auto s = GetServer(lua)->FunctionSetLib(name, libname);
auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
CHECK_NOTNULL(script_run_ctx);

auto s = script_run_ctx->conn->GetServer()->FunctionSetLib(name, libname);
if (!s) {
lua_pushstring(lua, "redis.register_function() failed to store informantion.");
return lua_error(lua);
Expand Down Expand Up @@ -396,6 +396,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std:
}

ScriptRunCtx script_run_ctx;
script_run_ctx.conn = conn;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
if (!lua_isnil(lua, -1)) {
Expand Down Expand Up @@ -642,6 +643,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
}

ScriptRunCtx current_script_run_ctx;
current_script_run_ctx.conn = conn;
current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str());
if (!lua_isnil(lua, -1)) {
Expand Down Expand Up @@ -709,14 +711,6 @@ int RedisCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 1); }

int RedisPCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 0); }

Server *GetServer(lua_State *lua) {
lua_getglobal(lua, REDIS_LUA_SERVER_PTR);
auto srv = reinterpret_cast<Server *>(lua_touserdata(lua, -1));
lua_pop(lua, 1);

return srv;
}

// TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
// so the function need to be refactored
int RedisGenericCommand(lua_State *lua, int raise_error) {
Expand Down Expand Up @@ -772,10 +766,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {

std::string cmd_name = attributes->name;

auto srv = GetServer(lua);
auto *conn = script_run_ctx->conn;
auto *srv = conn->GetServer();
Config *config = srv->GetConfig();

redis::Connection *conn = srv->GetCurrentConnection();
if (config->cluster_enabled) {
if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set");
Expand Down Expand Up @@ -901,8 +895,10 @@ int RedisReturnSingleFieldTable(lua_State *lua, const char *field) {
}

int RedisSetResp(lua_State *lua) {
auto srv = GetServer(lua);
auto conn = srv->GetCurrentConnection();
auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
CHECK_NOTNULL(script_run_ctx);
auto *conn = script_run_ctx->conn;
auto *srv = conn->GetServer();

if (lua_gettop(lua) != 1) {
PushError(lua, "redis.setresp() requires one argument.");
Expand Down
6 changes: 3 additions & 3 deletions src/storage/scripting.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,15 @@ inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_";
inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_";
inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_";
inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_";
inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr";
inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME";
inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE";
inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES";
inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX";

namespace lua {

lua_State *CreateState(Server *srv);
lua_State *CreateState();
void DestroyState(lua_State *lua);
Server *GetServer(lua_State *lua);

void LoadFuncs(lua_State *lua);
void LoadLibraries(lua_State *lua);
Expand Down Expand Up @@ -150,6 +148,8 @@ struct ScriptRunCtx {
// and is used to detect whether there is cross-slot access
// between multiple commands in a script or function.
int current_slot = -1;
// the current connection
redis::Connection *conn = nullptr;
};

/// SaveOnRegistry saves user-defined data to lua REGISTRY
Expand Down

0 comments on commit 9b0705e

Please sign in to comment.