Skip to content

Commit

Permalink
feat(string family): implement cl.throttle
Browse files Browse the repository at this point in the history
  • Loading branch information
zetanumbers committed Jan 20, 2023
1 parent b2edf9c commit 2fdf9bc
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 1 deletion.
149 changes: 148 additions & 1 deletion src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ extern "C" {
#include <absl/container/inlined_vector.h>
#include <double-conversion/string-to-double.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <tuple>

#include "base/logging.h"
#include "redis/util.h"
Expand Down Expand Up @@ -356,6 +359,92 @@ OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar
return cntx->transaction->ScheduleSingleHop(std::move(cb));
}

OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, string_view key, int64_t max_burst,
int64_t count, int64_t period_s, int64_t quantity) {
using namespace chrono_literals;

auto& db_slice = op_args.shard->db_slice();

const int64_t limit = max_burst + 1;
const int64_t emission_interval_ms = period_s * 1000 / count;
const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit;

if (emission_interval_ms == 0) {
return OpStatus::OUT_OF_RANGE;
}

int64_t remaining = 0;
int64_t reset_after_ms = -1000;
int64_t retry_after_ms = -1000;

const int64_t increment_ms = emission_interval_ms * quantity;

auto [it, expire_it] = db_slice.FindExt(op_args.db_cntx, key);
const int64_t now_ms = op_args.db_cntx.time_now_ms;

int64_t tat_ms = now_ms;
if (IsValid(it)) {
if (it->second.ObjType() != OBJ_STRING) {
return OpStatus::WRONG_TYPE;
}

auto opt_prev = it->second.TryGetInt();
if (!opt_prev) {
return OpStatus::INVALID_VALUE;
}
tat_ms = *opt_prev;
}
const int64_t new_tat_ms = max(tat_ms, now_ms) + increment_ms;

const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms;
const int64_t diff_ms = now_ms - allow_at_ms;

const bool limited = diff_ms < 0;
int64_t ttl_ms;
if (limited) {
if (increment_ms <= delay_variation_tolerance_ms) {
retry_after_ms = -diff_ms;
}
ttl_ms = tat_ms - now_ms;
} else {
ttl_ms = new_tat_ms - now_ms;

if (IsValid(it)) {
db_slice.PreUpdate(op_args.db_cntx.db_index, it);
it->second.SetInt(new_tat_ms);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
} else {
CompactObj cobj;
cobj.SetInt(new_tat_ms);

// AddNew calls PostUpdate inside.
try {
it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), 0);
} catch (bad_alloc&) {
return OpStatus::OUT_OF_MEMORY;
}
}
}

const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms;
if (next_ms > -emission_interval_ms) {
remaining = next_ms / emission_interval_ms;
}
reset_after_ms = ttl_ms;

int64_t retry_after_s = retry_after_ms / 1000;
if (retry_after_ms > 0) {
retry_after_s += 1;
}

int64_t reset_after_s = reset_after_ms / 1000;
if (reset_after_ms > 0) {
reset_after_s += 1;
}

return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_s, reset_after_s};
}

} // namespace

OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) {
Expand Down Expand Up @@ -1128,6 +1217,63 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction
return response;
}

void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 1);

int64_t max_burst;
string_view max_burst_str = ArgS(args, 2);
if (!absl::SimpleAtoi(max_burst_str, &max_burst)) {
return (*cntx)->SendError(kInvalidIntErr);
}

int64_t count;
string_view count_str = ArgS(args, 3);
if (!absl::SimpleAtoi(count_str, &count)) {
return (*cntx)->SendError(kInvalidIntErr);
}

int64_t period;
string_view period_str = ArgS(args, 4);
if (!absl::SimpleAtoi(period_str, &period)) {
return (*cntx)->SendError(kInvalidIntErr);
}

int64_t quantity = 1;
if (args.size() > 5) {
string_view quantity_str = ArgS(args, 5);

if (!absl::SimpleAtoi(quantity_str, &quantity)) {
return (*cntx)->SendError(kInvalidIntErr);
}
}

auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<array<int64_t, 5>> {
return OpThrottle(t->GetOpArgs(shard), key, max_burst, count, period, quantity);
};

Transaction* trans = cntx->transaction;
OpResult<array<int64_t, 5>> result = trans->ScheduleSingleHopT(std::move(cb));

switch (result.status()) {
case OpStatus::WRONG_TYPE:
(*cntx)->SendError(result.status());
break;
case OpStatus::INVALID_VALUE:
(*cntx)->SendError(kInvalidIntErr);
break;
case OpStatus::OUT_OF_RANGE:
(*cntx)->SendError(kIncrOverflow);
break;
default:
(*cntx)->StartArray(result->size());
const auto& array = result.value();
for (const auto& v : array) {
(*cntx)->SendLong(v);
}
break;
}
}

void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp);
get_qps.Init(pp);
Expand Down Expand Up @@ -1163,7 +1309,8 @@ void StringFamily::Register(CommandRegistry* registry) {
<< CI{"GETRANGE", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange)
<< CI{"SUBSTR", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(
GetRange) // Alias for GetRange
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange);
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange)
<< CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1}.HFUNC(ClThrottle);
}

} // namespace dfly
2 changes: 2 additions & 0 deletions src/server/string_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class StringFamily {
static void Prepend(CmdArgList args, ConnectionContext* cntx);
static void PSetEx(CmdArgList args, ConnectionContext* cntx);

static void ClThrottle(CmdArgList args, ConnectionContext* cntx);

// These functions are used internally, they do not implement any specific command
static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx);
static void ExtendGeneric(CmdArgList args, bool prepend, ConnectionContext* cntx);
Expand Down
63 changes: 63 additions & 0 deletions src/server/string_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,67 @@ TEST_F(StringFamilyTest, GetEx) {
EXPECT_THAT(Run({"getex", "foo"}), ArgType(RespExpr::NIL));
}

TEST_F(StringFamilyTest, ClThrottle) {
const int64_t limit = 5;
const char* const key = "foo";
const char* const max_burst = "4"; // limit - 1
const char* const count = "1";
const char* const period = "10";

// You can never make a request larger than the maximum.
auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0)));

// Rate limit normal requests appropriately.
resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51)));

AdvanceTime(30000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

AdvanceTime(1000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40)));

AdvanceTime(9000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

AdvanceTime(40000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

AdvanceTime(15000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// Zero-volume request just peeks at the state.
resp = Run({"cl.throttle", key, max_burst, count, period, "0"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// High-volume request uses up more of the limit.
resp = Run({"cl.throttle", key, max_burst, count, period, "2"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

// Large requests cannot exceed limits
resp = Run({"cl.throttle", key, max_burst, count, period, "5"});
ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31)));
}

} // namespace dfly

0 comments on commit 2fdf9bc

Please sign in to comment.