Skip to content

Commit

Permalink
chore: pass SinkReplyBuilder and Transaction explicitly. Part6 (#3987)
Browse files Browse the repository at this point in the history
  • Loading branch information
romange authored Oct 24, 2024
1 parent 16f59d3 commit 7035606
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 396 deletions.
120 changes: 66 additions & 54 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ AclFamily::AclFamily(UserRegistry* registry, util::ProactorPool* pool)
: registry_(registry), pool_(pool) {
}

void AclFamily::Acl(CmdArgList args, ConnectionContext* cntx) {
cntx->SendError("Wrong number of arguments for acl command");
void AclFamily::Acl(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
builder->SendError("Wrong number of arguments for acl command");
}

void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::List(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(registry.size());

for (const auto& [username, user] : registry) {
Expand All @@ -91,7 +91,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password,
acl_keys, maybe_space_com, acl_pub_sub, " ", acl_cat_and_commands);

cntx->SendSimpleString(buffer);
builder->SendSimpleString(buffer);
}
}

Expand All @@ -116,17 +116,17 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,

using facade::ErrorReply;

void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::SetUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view username = facade::ToSV(args[0]);
auto reg = registry_->GetRegistryWithWriteLock();
const bool exists = reg.registry.contains(username);
const bool has_all_keys = exists ? reg.registry.find(username)->second.Keys().all_keys : false;

auto req = ParseAclSetUser(args.subspan(1), false, has_all_keys);

auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(error); };
auto error_case = [builder](ErrorReply&& error) { builder->SendError(error); };

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto update_case = [username, &reg, builder, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
Expand All @@ -137,7 +137,7 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
const bool reset_channels = req.reset_channels;
user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex());
// Send ok first because the connection might get evicted
cntx->SendOk();
builder->SendOk();
if (exists) {
if (!reset_channels) {
StreamUpdatesToAllProactorConnections(string(username), user.AclCommands(), user.Keys(),
Expand Down Expand Up @@ -184,7 +184,7 @@ void AclFamily::EvictOpenConnectionsOnAllProactorsWithRegistry(
}
}

void AclFamily::DelUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::DelUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto& registry = *registry_;
absl::flat_hash_set<string_view> users;

Expand All @@ -199,17 +199,18 @@ void AclFamily::DelUser(CmdArgList args, ConnectionContext* cntx) {
}

if (users.empty()) {
cntx->SendLong(0);
builder->SendLong(0);
return;
}
VLOG(1) << "Evicting open acl connections";
EvictOpenConnectionsOnAllProactors(users);
VLOG(1) << "Done evicting open acl connections";
cntx->SendLong(users.size());
builder->SendLong(users.size());
}

void AclFamily::WhoAmI(CmdArgList args, ConnectionContext* cntx) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
void AclFamily::WhoAmI(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
ConnectionContext* cntx) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendBulkString(absl::StrCat("User is ", cntx->authed_username));
}

Expand Down Expand Up @@ -239,18 +240,18 @@ string AclFamily::RegistryToString() const {
return result;
}

void AclFamily::Save(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Save(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto acl_file_path = absl::GetFlag(FLAGS_aclfile);
if (acl_file_path.empty()) {
cntx->SendError("Dragonfly is not configured to use an ACL file.");
builder->SendError("Dragonfly is not configured to use an ACL file.");
return;
}

auto res = io::OpenWrite(acl_file_path);
if (!res) {
std::string error = absl::StrCat("Failed to open the aclfile: ", res.error().message());
LOG(ERROR) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

Expand All @@ -261,23 +262,23 @@ void AclFamily::Save(CmdArgList args, ConnectionContext* cntx) {
if (ec) {
std::string error = absl::StrCat("Failed to write to the aclfile: ", ec.message());
LOG(ERROR) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

ec = file->Close();
if (ec) {
std::string error = absl::StrCat("Failed to close the aclfile ", ec.message());
LOG(WARNING) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

cntx->SendOk();
builder->SendOk();
}

GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
ConnectionContext* cntx) {
SinkReplyBuilder* builder) {
auto is_file_read = io::ReadFileToString(full_path);
if (!is_file_read) {
auto error = absl::StrCat("Dragonfly could not load ACL file ", full_path, " with error ",
Expand Down Expand Up @@ -316,8 +317,8 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,

auto registry_with_wlock = registry_->GetRegistryWithWriteLock();
auto& registry = registry_with_wlock.registry;
if (cntx) {
cntx->SendOk();
if (builder) {
builder->SendOk();
// Evict open connections for old users
EvictOpenConnectionsOnAllProactorsWithRegistry(registry);
registry.clear();
Expand Down Expand Up @@ -347,23 +348,23 @@ bool AclFamily::Load() {
return !LoadToRegistryFromFile(acl_file, nullptr);
}

void AclFamily::Load(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Load(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto acl_file = absl::GetFlag(FLAGS_aclfile);
if (acl_file.empty()) {
cntx->SendError("Dragonfly is not configured to use an ACL file.");
builder->SendError("Dragonfly is not configured to use an ACL file.");
return;
}

const auto load_error = LoadToRegistryFromFile(acl_file, cntx);
const auto load_error = LoadToRegistryFromFile(acl_file, builder);

if (load_error) {
cntx->SendError(absl::StrCat("Error loading: ", acl_file, " ", load_error.Format()));
builder->SendError(absl::StrCat("Error loading: ", acl_file, " ", load_error.Format()));
}
}

void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Log(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.size() > 1) {
cntx->SendError(facade::OpStatus::OUT_OF_RANGE);
builder->SendError(facade::OpStatus::OUT_OF_RANGE);
}

size_t max_output = 10;
Expand All @@ -372,12 +373,12 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
if (absl::EqualsIgnoreCase(option, "RESET")) {
pool_->AwaitFiberOnAll(
[](auto index, auto* context) { ServerState::tlocal()->acl_log.Reset(); });
cntx->SendOk();
builder->SendOk();
return;
}

if (!absl::SimpleAtoi(facade::ToSV(args[0]), &max_output)) {
cntx->SendError("Invalid count");
builder->SendError("Invalid count");
return;
}
}
Expand All @@ -392,7 +393,7 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
total_entries += log.size();
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
if (total_entries == 0) {
rb->SendEmptyArray();
return;
Expand Down Expand Up @@ -453,19 +454,19 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
}
}

void AclFamily::Users(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Users(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(registry.size());
for (const auto& [username, _] : registry) {
rb->SendSimpleString(username);
}
}

void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Cat(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.size() > 1) {
cntx->SendError(facade::OpStatus::SYNTAX_ERR);
builder->SendError(facade::OpStatus::SYNTAX_ERR);
return;
}

Expand All @@ -474,7 +475,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {

if (!cat_table_.contains(category)) {
auto error = absl::StrCat("Unkown category: ", category);
cntx->SendError(error);
builder->SendError(error);
return;
}

Expand All @@ -487,7 +488,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
};

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
cmd_registry_->Traverse(cb);
rb->StartArray(results.size());
for (const auto& command : results) {
Expand All @@ -504,7 +505,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(total_categories);
for (auto& elem : reverse_cat_table_) {
if (elem != "_RESERVED") {
Expand All @@ -513,12 +514,12 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
}

void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::GetUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto username = facade::ToSV(args[0]);
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
if (!registry.contains(username)) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendNull();
return;
}
Expand All @@ -529,7 +530,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
pass.pop_back();
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(10);

rb->SendSimpleString("flags");
Expand Down Expand Up @@ -566,17 +567,17 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
rb->SendSimpleString(pub_sub);
}

void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::GenPass(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.length() > 1) {
cntx->SendError(facade::UnknownSubCmd("GENPASS", "ACL"));
builder->SendError(facade::UnknownSubCmd("GENPASS", "ACL"));
return;
}
uint32_t random_bits = 256;
if (args.length() == 1) {
auto requested_bits = facade::ArgS(args, 0);

if (!absl::SimpleAtoi(requested_bits, &random_bits) || random_bits == 0 || random_bits > 4096) {
return cntx->SendError(
return builder->SendError(
"ACL GENPASS argument must be the number of bits for the output password, a positive "
"number up to 4096");
}
Expand All @@ -591,44 +592,55 @@ void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) {

response.resize(result_length);

cntx->SendSimpleString(response);
builder->SendSimpleString(response);
}

void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::DryRun(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto username = facade::ArgS(args, 0);
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
if (!registry.contains(username)) {
auto error = absl::StrCat("User '", username, "' not found");
cntx->SendError(error);
builder->SendError(error);
return;
}

string command = absl::AsciiStrToUpper(ArgS(args, 1));
auto* cid = cmd_registry_->Find(command);
if (!cid) {
auto error = absl::StrCat("Command '", command, "' not found");
cntx->SendError(error);
builder->SendError(error);
return;
}

const auto& user = registry.find(username)->second;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
builder->SendOk();
return;
}

auto msg = absl::StrCat("This user has no permissions to run the '", command, "' command");
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendBulkString(msg);
}

using MemberFunc = void (AclFamily::*)(CmdArgList args, ConnectionContext* cntx);
using MemberFunc2 = void (AclFamily::*)(CmdArgList args, Transaction* tx,
facade::SinkReplyBuilder* builder);

CommandId::Handler HandlerFunc(AclFamily* acl, MemberFunc f) {
return [=](CmdArgList args, ConnectionContext* cntx) { return (acl->*f)(args, cntx); };
using MemberFunc3 = void (AclFamily::*)(CmdArgList args, Transaction* tx,
facade::SinkReplyBuilder* builder, ConnectionContext* cntx);

CommandId::Handler2 HandlerFunc(AclFamily* acl, MemberFunc2 f) {
return [=](CmdArgList args, Transaction* tx, facade::SinkReplyBuilder* builder) {
return (acl->*f)(args, tx, builder);
};
}

CommandId::Handler3 HandlerFunc(AclFamily* acl, MemberFunc3 f) {
return [=](CmdArgList args, Transaction* tx, facade::SinkReplyBuilder* builder,
ConnectionContext* cntx) { return (acl->*f)(args, tx, builder, cntx); };
}

#define HFUNC(x) SetHandler(HandlerFunc(this, &AclFamily::x))
Expand Down
Loading

0 comments on commit 7035606

Please sign in to comment.