Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: pass SinkReplyBuilder and Transaction explicitly. Part6 #3987

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 66 additions & 54 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ AclFamily::AclFamily(UserRegistry* registry, util::ProactorPool* pool)
: registry_(registry), pool_(pool) {
}

void AclFamily::Acl(CmdArgList args, ConnectionContext* cntx) {
cntx->SendError("Wrong number of arguments for acl command");
void AclFamily::Acl(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
builder->SendError("Wrong number of arguments for acl command");
}

void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::List(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(registry.size());

for (const auto& [username, user] : registry) {
Expand All @@ -91,7 +91,7 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password,
acl_keys, maybe_space_com, acl_pub_sub, " ", acl_cat_and_commands);

cntx->SendSimpleString(buffer);
builder->SendSimpleString(buffer);
}
}

Expand All @@ -116,17 +116,17 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,

using facade::ErrorReply;

void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::SetUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view username = facade::ToSV(args[0]);
auto reg = registry_->GetRegistryWithWriteLock();
const bool exists = reg.registry.contains(username);
const bool has_all_keys = exists ? reg.registry.find(username)->second.Keys().all_keys : false;

auto req = ParseAclSetUser(args.subspan(1), false, has_all_keys);

auto error_case = [cntx](ErrorReply&& error) { cntx->SendError(error); };
auto error_case = [builder](ErrorReply&& error) { builder->SendError(error); };

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto update_case = [username, &reg, builder, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
Expand All @@ -137,7 +137,7 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
const bool reset_channels = req.reset_channels;
user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex());
// Send ok first because the connection might get evicted
cntx->SendOk();
builder->SendOk();
if (exists) {
if (!reset_channels) {
StreamUpdatesToAllProactorConnections(string(username), user.AclCommands(), user.Keys(),
Expand Down Expand Up @@ -184,7 +184,7 @@ void AclFamily::EvictOpenConnectionsOnAllProactorsWithRegistry(
}
}

void AclFamily::DelUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::DelUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto& registry = *registry_;
absl::flat_hash_set<string_view> users;

Expand All @@ -199,17 +199,18 @@ void AclFamily::DelUser(CmdArgList args, ConnectionContext* cntx) {
}

if (users.empty()) {
cntx->SendLong(0);
builder->SendLong(0);
return;
}
VLOG(1) << "Evicting open acl connections";
EvictOpenConnectionsOnAllProactors(users);
VLOG(1) << "Done evicting open acl connections";
cntx->SendLong(users.size());
builder->SendLong(users.size());
}

void AclFamily::WhoAmI(CmdArgList args, ConnectionContext* cntx) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
void AclFamily::WhoAmI(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
ConnectionContext* cntx) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendBulkString(absl::StrCat("User is ", cntx->authed_username));
}

Expand Down Expand Up @@ -239,18 +240,18 @@ string AclFamily::RegistryToString() const {
return result;
}

void AclFamily::Save(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Save(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto acl_file_path = absl::GetFlag(FLAGS_aclfile);
if (acl_file_path.empty()) {
cntx->SendError("Dragonfly is not configured to use an ACL file.");
builder->SendError("Dragonfly is not configured to use an ACL file.");
return;
}

auto res = io::OpenWrite(acl_file_path);
if (!res) {
std::string error = absl::StrCat("Failed to open the aclfile: ", res.error().message());
LOG(ERROR) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

Expand All @@ -261,23 +262,23 @@ void AclFamily::Save(CmdArgList args, ConnectionContext* cntx) {
if (ec) {
std::string error = absl::StrCat("Failed to write to the aclfile: ", ec.message());
LOG(ERROR) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

ec = file->Close();
if (ec) {
std::string error = absl::StrCat("Failed to close the aclfile ", ec.message());
LOG(WARNING) << error;
cntx->SendError(error);
builder->SendError(error);
return;
}

cntx->SendOk();
builder->SendOk();
}

GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
ConnectionContext* cntx) {
SinkReplyBuilder* builder) {
auto is_file_read = io::ReadFileToString(full_path);
if (!is_file_read) {
auto error = absl::StrCat("Dragonfly could not load ACL file ", full_path, " with error ",
Expand Down Expand Up @@ -316,8 +317,8 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,

auto registry_with_wlock = registry_->GetRegistryWithWriteLock();
auto& registry = registry_with_wlock.registry;
if (cntx) {
cntx->SendOk();
if (builder) {
builder->SendOk();
// Evict open connections for old users
EvictOpenConnectionsOnAllProactorsWithRegistry(registry);
registry.clear();
Expand Down Expand Up @@ -347,23 +348,23 @@ bool AclFamily::Load() {
return !LoadToRegistryFromFile(acl_file, nullptr);
}

void AclFamily::Load(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Load(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto acl_file = absl::GetFlag(FLAGS_aclfile);
if (acl_file.empty()) {
cntx->SendError("Dragonfly is not configured to use an ACL file.");
builder->SendError("Dragonfly is not configured to use an ACL file.");
return;
}

const auto load_error = LoadToRegistryFromFile(acl_file, cntx);
const auto load_error = LoadToRegistryFromFile(acl_file, builder);

if (load_error) {
cntx->SendError(absl::StrCat("Error loading: ", acl_file, " ", load_error.Format()));
builder->SendError(absl::StrCat("Error loading: ", acl_file, " ", load_error.Format()));
}
}

void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Log(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.size() > 1) {
cntx->SendError(facade::OpStatus::OUT_OF_RANGE);
builder->SendError(facade::OpStatus::OUT_OF_RANGE);
}

size_t max_output = 10;
Expand All @@ -372,12 +373,12 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
if (absl::EqualsIgnoreCase(option, "RESET")) {
pool_->AwaitFiberOnAll(
[](auto index, auto* context) { ServerState::tlocal()->acl_log.Reset(); });
cntx->SendOk();
builder->SendOk();
return;
}

if (!absl::SimpleAtoi(facade::ToSV(args[0]), &max_output)) {
cntx->SendError("Invalid count");
builder->SendError("Invalid count");
return;
}
}
Expand All @@ -392,7 +393,7 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
total_entries += log.size();
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
if (total_entries == 0) {
rb->SendEmptyArray();
return;
Expand Down Expand Up @@ -453,19 +454,19 @@ void AclFamily::Log(CmdArgList args, ConnectionContext* cntx) {
}
}

void AclFamily::Users(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Users(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(registry.size());
for (const auto& [username, _] : registry) {
rb->SendSimpleString(username);
}
}

void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::Cat(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.size() > 1) {
cntx->SendError(facade::OpStatus::SYNTAX_ERR);
builder->SendError(facade::OpStatus::SYNTAX_ERR);
return;
}

Expand All @@ -474,7 +475,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {

if (!cat_table_.contains(category)) {
auto error = absl::StrCat("Unkown category: ", category);
cntx->SendError(error);
builder->SendError(error);
return;
}

Expand All @@ -487,7 +488,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
};

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
cmd_registry_->Traverse(cb);
rb->StartArray(results.size());
for (const auto& command : results) {
Expand All @@ -504,7 +505,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(total_categories);
for (auto& elem : reverse_cat_table_) {
if (elem != "_RESERVED") {
Expand All @@ -513,12 +514,12 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {
}
}

void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::GetUser(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto username = facade::ToSV(args[0]);
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
if (!registry.contains(username)) {
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendNull();
return;
}
Expand All @@ -529,7 +530,7 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
pass.pop_back();
}

auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->StartArray(10);

rb->SendSimpleString("flags");
Expand Down Expand Up @@ -566,17 +567,17 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
rb->SendSimpleString(pub_sub);
}

void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::GenPass(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
if (args.length() > 1) {
cntx->SendError(facade::UnknownSubCmd("GENPASS", "ACL"));
builder->SendError(facade::UnknownSubCmd("GENPASS", "ACL"));
return;
}
uint32_t random_bits = 256;
if (args.length() == 1) {
auto requested_bits = facade::ArgS(args, 0);

if (!absl::SimpleAtoi(requested_bits, &random_bits) || random_bits == 0 || random_bits > 4096) {
return cntx->SendError(
return builder->SendError(
"ACL GENPASS argument must be the number of bits for the output password, a positive "
"number up to 4096");
}
Expand All @@ -591,44 +592,55 @@ void AclFamily::GenPass(CmdArgList args, ConnectionContext* cntx) {

response.resize(result_length);

cntx->SendSimpleString(response);
builder->SendSimpleString(response);
}

void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
void AclFamily::DryRun(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
auto username = facade::ArgS(args, 0);
const auto registry_with_lock = registry_->GetRegistryWithLock();
const auto& registry = registry_with_lock.registry;
if (!registry.contains(username)) {
auto error = absl::StrCat("User '", username, "' not found");
cntx->SendError(error);
builder->SendError(error);
return;
}

string command = absl::AsciiStrToUpper(ArgS(args, 1));
auto* cid = cmd_registry_->Find(command);
if (!cid) {
auto error = absl::StrCat("Command '", command, "' not found");
cntx->SendError(error);
builder->SendError(error);
return;
}

const auto& user = registry.find(username)->second;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
builder->SendOk();
return;
}

auto msg = absl::StrCat("This user has no permissions to run the '", command, "' command");
auto* rb = static_cast<facade::RedisReplyBuilder*>(cntx->reply_builder());
auto* rb = static_cast<facade::RedisReplyBuilder*>(builder);
rb->SendBulkString(msg);
}

using MemberFunc = void (AclFamily::*)(CmdArgList args, ConnectionContext* cntx);
using MemberFunc2 = void (AclFamily::*)(CmdArgList args, Transaction* tx,
facade::SinkReplyBuilder* builder);

CommandId::Handler HandlerFunc(AclFamily* acl, MemberFunc f) {
return [=](CmdArgList args, ConnectionContext* cntx) { return (acl->*f)(args, cntx); };
using MemberFunc3 = void (AclFamily::*)(CmdArgList args, Transaction* tx,
facade::SinkReplyBuilder* builder, ConnectionContext* cntx);

CommandId::Handler2 HandlerFunc(AclFamily* acl, MemberFunc2 f) {
return [=](CmdArgList args, Transaction* tx, facade::SinkReplyBuilder* builder) {
return (acl->*f)(args, tx, builder);
};
}

CommandId::Handler3 HandlerFunc(AclFamily* acl, MemberFunc3 f) {
return [=](CmdArgList args, Transaction* tx, facade::SinkReplyBuilder* builder,
ConnectionContext* cntx) { return (acl->*f)(args, tx, builder, cntx); };
}

#define HFUNC(x) SetHandler(HandlerFunc(this, &AclFamily::x))
Expand Down
Loading
Loading