Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor StringFamily::Set to use CmdArgParser #2800

Merged
merged 3 commits into from
Mar 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "server/string_family.h"

#include <absl/container/inlined_vector.h>
#include <absl/strings/match.h>

#include <algorithm>
#include <array>
Expand All @@ -14,6 +15,8 @@

#include "base/flags.h"
#include "base/logging.h"
#include "base/stl_util.h"
#include "facade/cmd_arg_parser.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
Expand Down Expand Up @@ -693,49 +696,48 @@ void SetCmd::RecordJournal(const SetParams& params, string_view key, string_view
}

void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
string_view value = ArgS(args, 1);
facade::CmdArgParser parser{args};

auto [key, value] = parser.Next<string_view, string_view>();

SetCmd::SetParams sparams;
sparams.memcache_flags = cntx->conn_state.memcache_flag;

int64_t int_arg;
SinkReplyBuilder* builder = cntx->reply_builder();
facade::SinkReplyBuilder* builder = cntx->reply_builder();

for (size_t i = 2; i < args.size(); ++i) {
ToUpper(&args[i]);
while (parser.HasNext()) {
parser.ToUpper();
if (base::_in(parser.Peek(), {"EX", "PX", "EXAT", "PXAT"})) {
auto [opt, int_arg] = parser.Next<string_view, int64_t>();

string_view cur_arg = ArgS(args, i);
if (auto err = parser.Error(); err) {
return builder->SendError(err->MakeReply());
}

if ((cur_arg == "EX" || cur_arg == "PX" || cur_arg == "EXAT" || cur_arg == "PXAT") &&
!(sparams.flags & SetCmd::SET_KEEP_EXPIRE) &&
!(sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS)) {
sparams.flags |= SetCmd::SET_EXPIRE_AFTER_MS;
bool is_ms = (cur_arg == "PX" || cur_arg == "PXAT");
++i;
if (i == args.size()) {
// We can set expiry only once.
if (sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS)
return builder->SendError(kSyntaxErr);
}

string_view ex = ArgS(args, i);
if (!absl::SimpleAtoi(ex, &int_arg)) {
return builder->SendError(kInvalidIntErr);
}
sparams.flags |= SetCmd::SET_EXPIRE_AFTER_MS;

// Since PXAT/EXAT can change this, we need to check this ahead
if (int_arg <= 0) {
return builder->SendError(InvalidExpireTime("set"));
}

bool is_ms = (opt[0] == 'P');

// for []AT we need to take expiration time as absolute from the value given
// check here and if the time is in the past, return OK but don't set it
// Note that the time pass here for PXAT is in milliseconds, we must not change it!
if (cur_arg == "EXAT" || cur_arg == "PXAT") {
if (absl::EndsWith(opt, "AT")) {
int_arg = AbsExpiryToTtl(int_arg, is_ms);
if (int_arg < 0) {
// this happened in the past, just return, for some reason Redis reports OK in this case
return builder->SendStored();
}
}

if (is_ms) {
if (int_arg > kMaxExpireDeadlineMs) {
int_arg = kMaxExpireDeadlineMs;
Expand All @@ -747,22 +749,26 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
int_arg *= 1000;
}
sparams.expire_after_ms = int_arg;
} else if (cur_arg == "NX" && !(sparams.flags & SetCmd::SET_IF_EXISTS)) {
sparams.flags |= SetCmd::SET_IF_NOTEXIST;
} else if (cur_arg == "XX" && !(sparams.flags & SetCmd::SET_IF_NOTEXIST)) {
sparams.flags |= SetCmd::SET_IF_EXISTS;
} else if (cur_arg == "KEEPTTL" && !(sparams.flags & SetCmd::SET_EXPIRE_AFTER_MS)) {
sparams.flags |= SetCmd::SET_KEEP_EXPIRE;
} else if (cur_arg == "GET") {
sparams.flags |= SetCmd::SET_GET;
} else if (cur_arg == "STICK") {
sparams.flags |= SetCmd::SET_STICK;
} else {
return builder->SendError(kSyntaxErr);
uint16_t flag = parser.Switch( //
"GET", SetCmd::SET_GET, "STICK", SetCmd::SET_STICK, "KEEPTTL", SetCmd::SET_KEEP_EXPIRE,
"XX", SetCmd::SET_IF_EXISTS, "NX", SetCmd::SET_IF_NOTEXIST);
sparams.flags |= flag;
}
}

const auto result{SetGeneric(cntx, sparams, key, value, true)};
if (auto err = parser.Error(); err) {
return builder->SendError(err->MakeReply());
}

auto has_mask = [&](uint16_t m) { return (sparams.flags & m) == m; };

if (has_mask(SetCmd::SET_IF_EXISTS | SetCmd::SET_IF_NOTEXIST) ||
has_mask(SetCmd::SET_KEEP_EXPIRE | SetCmd::SET_EXPIRE_AFTER_MS)) {
return builder->SendError(kSyntaxErr);
}

OpResult result{SetGeneric(cntx, sparams, key, value, true)};

if (sparams.flags & SetCmd::SET_GET) {
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
Expand All @@ -783,7 +789,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
return builder->SendError(kOutOfMemory);
}

CHECK_EQ(result, OpStatus::SKIPPED); // in case of NX option
DCHECK_EQ(result, OpStatus::SKIPPED); // in case of NX option

builder->SendSetSkipped();
}
Expand Down
Loading