Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(string family): implement cl.throttle #714

Merged
merged 10 commits into from
Jan 24, 2023
223 changes: 222 additions & 1 deletion src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ extern "C" {
#include <absl/container/inlined_vector.h>
#include <double-conversion/string-to-double.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <tuple>
romange marked this conversation as resolved.
Show resolved Hide resolved

#include "base/logging.h"
#include "redis/util.h"
Expand Down Expand Up @@ -356,6 +360,116 @@ OpResult<void> SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar
return cntx->transaction->ScheduleSingleHop(std::move(cb));
}

// emission_interval_ms assumed to be positive
// limit is assumed to be positive
OpResult<array<int64_t, 5>> OpThrottle(const OpArgs& op_args, const string_view key,
const int64_t limit, const int64_t emission_interval_ms,
const uint64_t quantity) {
auto& db_slice = op_args.shard->db_slice();

if (emission_interval_ms > INT64_MAX / limit) {
return OpStatus::INVALID_INT;
}
const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; // should be positive

int64_t remaining = 0;
int64_t reset_after_ms = -1000;
int64_t retry_after_ms = -1000;

if (quantity != 0 && static_cast<uint64_t>(emission_interval_ms) > INT64_MAX / quantity) {
return OpStatus::INVALID_INT;
}
const int64_t increment_ms = emission_interval_ms * quantity; // should be nonnegative

auto [it, e_it] = db_slice.FindExt(op_args.db_cntx, key);
const int64_t now_ms = op_args.db_cntx.time_now_ms;

int64_t tat_ms = now_ms;
if (IsValid(it)) {
if (it->second.ObjType() != OBJ_STRING) {
return OpStatus::WRONG_TYPE;
}

auto opt_prev = it->second.TryGetInt();
if (!opt_prev) {
return OpStatus::INVALID_VALUE;
}
tat_ms = *opt_prev;
}

int64_t new_tat_ms = max(tat_ms, now_ms);
if (new_tat_ms > INT64_MAX - increment_ms) {
return OpStatus::INVALID_INT;
}
new_tat_ms += increment_ms;

if (new_tat_ms < INT64_MIN + delay_variation_tolerance_ms) {
return OpStatus::INVALID_INT;
}
const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms;

if (allow_at_ms >= 0 ? now_ms < INT64_MIN + allow_at_ms : now_ms > INT64_MAX + allow_at_ms) {
return OpStatus::INVALID_INT;
}
const int64_t diff_ms = now_ms - allow_at_ms;

const bool limited = diff_ms < 0;
int64_t ttl_ms;
if (limited) {
if (increment_ms <= delay_variation_tolerance_ms) {
if (diff_ms == INT64_MIN) {
return OpStatus::INVALID_INT;
}
retry_after_ms = -diff_ms;
}

if (now_ms >= 0 ? tat_ms < INT64_MIN + now_ms : tat_ms > INT64_MAX + now_ms) {
return OpStatus::INVALID_INT;
}
ttl_ms = tat_ms - now_ms;
} else {
if (now_ms >= 0 ? new_tat_ms < INT64_MIN + now_ms : new_tat_ms > INT64_MAX + now_ms) {
return OpStatus::INVALID_INT;
}
ttl_ms = new_tat_ms - now_ms;
}

if (ttl_ms < delay_variation_tolerance_ms - INT64_MAX) {
return OpStatus::INVALID_INT;
}
const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms;
if (next_ms > -emission_interval_ms) {
remaining = next_ms / emission_interval_ms;
}
reset_after_ms = ttl_ms;

if (!limited) {
if (IsValid(it)) {
if (IsValid(e_it)) {
e_it->second = db_slice.FromAbsoluteTime(new_tat_ms);
} else {
db_slice.AddExpire(op_args.db_cntx.db_index, it, new_tat_ms);
}

db_slice.PreUpdate(op_args.db_cntx.db_index, it);
it->second.SetInt(new_tat_ms);
db_slice.PostUpdate(op_args.db_cntx.db_index, it, key);
} else {
CompactObj cobj;
cobj.SetInt(new_tat_ms);

// AddNew calls PostUpdate inside.
try {
it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), new_tat_ms);
} catch (bad_alloc&) {
return OpStatus::OUT_OF_MEMORY;
}
}
}

return array<int64_t, 5>{limited ? 1 : 0, limit, remaining, retry_after_ms, reset_after_ms};
}

} // namespace

OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) {
Expand Down Expand Up @@ -1170,6 +1284,112 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction
return response;
}

/* CL.THROTTLE <key> <max_burst> <count per period> <period> [<quantity>] */
/* Response is array of 5 integers. The meaning of each array item is:
* 1. Whether the action was limited:
* - 0 indicates the action is allowed.
* - 1 indicates that the action was limited/blocked.
* 2. The total limit of the key (max_burst + 1). This is equivalent to the common
* X-RateLimit-Limit HTTP header.
* 3. The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
* 4. The number of seconds until the user should retry, and always -1 if the action was allowed.
* Equivalent to Retry-After.
* 5. The number of seconds until the limit will reset to its maximum capacity. Equivalent to
* X-RateLimit-Reset.
*/
void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
const string_view key = ArgS(args, 1);

// Allow max burst in number of tokens
uint64_t max_burst;
const string_view max_burst_str = ArgS(args, 2);
if (!absl::SimpleAtoi(max_burst_str, &max_burst)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Emit count of tokens per period
uint64_t count;
const string_view count_str = ArgS(args, 3);
if (!absl::SimpleAtoi(count_str, &count)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Period of emitting count of tokens
uint64_t period;
const string_view period_str = ArgS(args, 4);
if (!absl::SimpleAtoi(period_str, &period)) {
zetanumbers marked this conversation as resolved.
Show resolved Hide resolved
return (*cntx)->SendError(kInvalidIntErr);
}

// Apply quantity of tokens now
uint64_t quantity = 1;
if (args.size() > 5) {
const string_view quantity_str = ArgS(args, 5);

if (!absl::SimpleAtoi(quantity_str, &quantity)) {
return (*cntx)->SendError(kInvalidIntErr);
}
}

if (max_burst > INT64_MAX - 1) {
return (*cntx)->SendError(kInvalidIntErr);
}
const int64_t limit = max_burst + 1;

if (period > UINT64_MAX / 1000 || count == 0 || period * 1000 / count > INT64_MAX) {
return (*cntx)->SendError(kInvalidIntErr);
}
const int64_t emission_interval_ms = period * 1000 / count;
romange marked this conversation as resolved.
Show resolved Hide resolved

if (emission_interval_ms == 0) {
return (*cntx)->SendError("zero rates are not supported");
}

auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<array<int64_t, 5>> {
return OpThrottle(t->GetOpArgs(shard), key, limit, emission_interval_ms, quantity);
};

Transaction* trans = cntx->transaction;
OpResult<array<int64_t, 5>> result = trans->ScheduleSingleHopT(std::move(cb));

if (result) {
(*cntx)->StartArray(result->size());
auto& array = result.value();

int64_t retry_after_s = array[3] / 1000;
if (array[3] > 0) {
retry_after_s += 1;
}
array[3] = retry_after_s;

int64_t reset_after_s = array[4] / 1000;
if (array[4] > 0) {
reset_after_s += 1;
}
array[4] = reset_after_s;

for (const auto& v : array) {
(*cntx)->SendLong(v);
}
} else {
switch (result.status()) {
case OpStatus::WRONG_TYPE:
(*cntx)->SendError(kWrongTypeErr);
break;
case OpStatus::INVALID_INT:
case OpStatus::INVALID_VALUE:
(*cntx)->SendError(kInvalidIntErr);
break;
case OpStatus::OUT_OF_MEMORY:
(*cntx)->SendError(kOutOfMemory);
break;
default:
(*cntx)->SendError(result.status());
break;
}
}
}

void StringFamily::Init(util::ProactorPool* pp) {
set_qps.Init(pp);
get_qps.Init(pp);
Expand Down Expand Up @@ -1206,7 +1426,8 @@ void StringFamily::Register(CommandRegistry* registry) {
<< CI{"STRLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(StrLen)
<< CI{"GETRANGE", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange)
<< CI{"SUBSTR", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange) // Alias for GetRange
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange);
<< CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange)
<< CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1}.HFUNC(ClThrottle);
}

} // namespace dfly
2 changes: 2 additions & 0 deletions src/server/string_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class StringFamily {
static void Prepend(CmdArgList args, ConnectionContext* cntx);
static void PSetEx(CmdArgList args, ConnectionContext* cntx);

static void ClThrottle(CmdArgList args, ConnectionContext* cntx);

// These functions are used internally, they do not implement any specific command
static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx);
static void ExtendGeneric(CmdArgList args, bool prepend, ConnectionContext* cntx);
Expand Down
103 changes: 103 additions & 0 deletions src/server/string_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,107 @@ TEST_F(StringFamilyTest, GetEx) {
EXPECT_THAT(Run({"getex", "foo"}), ArgType(RespExpr::NIL));
}

TEST_F(StringFamilyTest, ClThrottle) {
const int64_t limit = 5;
const char* const key = "foo";
const char* const max_burst = "4"; // limit - 1
const char* const count = "1";
const char* const period = "10";

// You can never make a request larger than the maximum.
auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0)));

// Rate limit normal requests appropriately.
resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51)));

resp = Run({"cl.throttle", key, max_burst, count, period});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51)));

AdvanceTime(30000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

AdvanceTime(1000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40)));

AdvanceTime(9000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41)));

AdvanceTime(40000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

AdvanceTime(15000);
resp = Run({"cl.throttle", key, max_burst, count, period, "1"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// Zero-volume request just peeks at the state.
resp = Run({"cl.throttle", key, max_burst, count, period, "0"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11)));

// High-volume request uses up more of the limit.
resp = Run({"cl.throttle", key, max_burst, count, period, "2"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31)));

// Large requests cannot exceed limits
resp = Run({"cl.throttle", key, max_burst, count, period, "5"});
ASSERT_EQ(RespExpr::ARRAY, resp.type);
ASSERT_THAT(resp.GetVec(),
ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31)));

// Zero rates aren't supported
resp = Run({"cl.throttle", "bar", "10", "1", "0"});
ASSERT_EQ(RespExpr::ERROR, resp.type);
EXPECT_THAT(resp, ErrArg("zero rates are not supported"));

// count == 0
resp = Run({"cl.throttle", "bar", "10", "0", "1"});
ASSERT_EQ(RespExpr::ERROR, resp.type);
EXPECT_THAT(resp, ErrArg(kInvalidIntErr));
}

} // namespace dfly