diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 0f5e334c863b..4852043a5a10 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -11,7 +11,10 @@ extern "C" { #include #include +#include +#include #include +#include #include "base/logging.h" #include "redis/util.h" @@ -356,6 +359,92 @@ OpResult SetGeneric(ConnectionContext* cntx, const SetCmd::SetParams& spar return cntx->transaction->ScheduleSingleHop(std::move(cb)); } +OpResult> OpThrottle(const OpArgs& op_args, string_view key, int64_t max_burst, + int64_t count, int64_t period_s, int64_t quantity) { + using namespace chrono_literals; + + auto& db_slice = op_args.shard->db_slice(); + + const int64_t limit = max_burst + 1; + const int64_t emission_interval_ms = period_s * 1000 / count; + const int64_t delay_variation_tolerance_ms = emission_interval_ms * limit; + + if (emission_interval_ms == 0) { + return OpStatus::OUT_OF_RANGE; + } + + int64_t remaining = 0; + int64_t reset_after_ms = -1000; + int64_t retry_after_ms = -1000; + + const int64_t increment_ms = emission_interval_ms * quantity; + + auto [it, expire_it] = db_slice.FindExt(op_args.db_cntx, key); + const int64_t now_ms = op_args.db_cntx.time_now_ms; + + int64_t tat_ms = now_ms; + if (IsValid(it)) { + if (it->second.ObjType() != OBJ_STRING) { + return OpStatus::WRONG_TYPE; + } + + auto opt_prev = it->second.TryGetInt(); + if (!opt_prev) { + return OpStatus::INVALID_VALUE; + } + tat_ms = *opt_prev; + } + const int64_t new_tat_ms = max(tat_ms, now_ms) + increment_ms; + + const int64_t allow_at_ms = new_tat_ms - delay_variation_tolerance_ms; + const int64_t diff_ms = now_ms - allow_at_ms; + + const bool limited = diff_ms < 0; + int64_t ttl_ms; + if (limited) { + if (increment_ms <= delay_variation_tolerance_ms) { + retry_after_ms = -diff_ms; + } + ttl_ms = tat_ms - now_ms; + } else { + ttl_ms = new_tat_ms - now_ms; + + if (IsValid(it)) { + db_slice.PreUpdate(op_args.db_cntx.db_index, it); + it->second.SetInt(new_tat_ms); + db_slice.PostUpdate(op_args.db_cntx.db_index, it, key); + } else { + CompactObj cobj; + cobj.SetInt(new_tat_ms); + + // AddNew calls PostUpdate inside. + try { + it = db_slice.AddNew(op_args.db_cntx, key, std::move(cobj), 0); + } catch (bad_alloc&) { + return OpStatus::OUT_OF_MEMORY; + } + } + } + + const int64_t next_ms = delay_variation_tolerance_ms - ttl_ms; + if (next_ms > -emission_interval_ms) { + remaining = next_ms / emission_interval_ms; + } + reset_after_ms = ttl_ms; + + int64_t retry_after_s = retry_after_ms / 1000; + if (retry_after_ms > 0) { + retry_after_s += 1; + } + + int64_t reset_after_s = reset_after_ms / 1000; + if (reset_after_ms > 0) { + reset_after_s += 1; + } + + return array{limited ? 1 : 0, limit, remaining, retry_after_s, reset_after_s}; +} + } // namespace OpStatus SetCmd::Set(const SetParams& params, string_view key, string_view value) { @@ -1128,6 +1217,63 @@ auto StringFamily::OpMGet(bool fetch_mcflag, bool fetch_mcver, const Transaction return response; } +void StringFamily::ClThrottle(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 1); + + int64_t max_burst; + string_view max_burst_str = ArgS(args, 2); + if (!absl::SimpleAtoi(max_burst_str, &max_burst)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t count; + string_view count_str = ArgS(args, 3); + if (!absl::SimpleAtoi(count_str, &count)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t period; + string_view period_str = ArgS(args, 4); + if (!absl::SimpleAtoi(period_str, &period)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + int64_t quantity = 1; + if (args.size() > 5) { + string_view quantity_str = ArgS(args, 5); + + if (!absl::SimpleAtoi(quantity_str, &quantity)) { + return (*cntx)->SendError(kInvalidIntErr); + } + } + + auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult> { + return OpThrottle(t->GetOpArgs(shard), key, max_burst, count, period, quantity); + }; + + Transaction* trans = cntx->transaction; + OpResult> result = trans->ScheduleSingleHopT(std::move(cb)); + + switch (result.status()) { + case OpStatus::WRONG_TYPE: + (*cntx)->SendError(result.status()); + break; + case OpStatus::INVALID_VALUE: + (*cntx)->SendError(kInvalidIntErr); + break; + case OpStatus::OUT_OF_RANGE: + (*cntx)->SendError(kIncrOverflow); + break; + default: + (*cntx)->StartArray(result->size()); + const auto& array = result.value(); + for (const auto& v : array) { + (*cntx)->SendLong(v); + } + break; + } +} + void StringFamily::Init(util::ProactorPool* pp) { set_qps.Init(pp); get_qps.Init(pp); @@ -1163,7 +1309,8 @@ void StringFamily::Register(CommandRegistry* registry) { << CI{"GETRANGE", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC(GetRange) << CI{"SUBSTR", CO::READONLY | CO::FAST, 4, 1, 1, 1}.HFUNC( GetRange) // Alias for GetRange - << CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange); + << CI{"SETRANGE", CO::WRITE | CO::FAST | CO::DENYOOM, 4, 1, 1, 1}.HFUNC(SetRange) + << CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, 1}.HFUNC(ClThrottle); } } // namespace dfly diff --git a/src/server/string_family.h b/src/server/string_family.h index 4286a2c29ecc..282f4a951d31 100644 --- a/src/server/string_family.h +++ b/src/server/string_family.h @@ -81,6 +81,8 @@ class StringFamily { static void Prepend(CmdArgList args, ConnectionContext* cntx); static void PSetEx(CmdArgList args, ConnectionContext* cntx); + static void ClThrottle(CmdArgList args, ConnectionContext* cntx); + // These functions are used internally, they do not implement any specific command static void IncrByGeneric(std::string_view key, int64_t val, ConnectionContext* cntx); static void ExtendGeneric(CmdArgList args, bool prepend, ConnectionContext* cntx); diff --git a/src/server/string_family_test.cc b/src/server/string_family_test.cc index f914b8ca3dfa..fa392800be93 100644 --- a/src/server/string_family_test.cc +++ b/src/server/string_family_test.cc @@ -557,4 +557,67 @@ TEST_F(StringFamilyTest, GetEx) { EXPECT_THAT(Run({"getex", "foo"}), ArgType(RespExpr::NIL)); } +TEST_F(StringFamilyTest, ClThrottle) { + const int64_t limit = 5; + const char* const key = "foo"; + const char* const max_burst = "4"; // limit - 1 + const char* const count = "1"; + const char* const period = "10"; + + // You can never make a request larger than the maximum. + auto resp = Run({"cl.throttle", key, max_burst, count, period, "6"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(5), IntArg(-1), IntArg(0))); + + // Rate limit normal requests appropriately. + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(3), IntArg(-1), IntArg(21))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(0), IntArg(-1), IntArg(51))); + + resp = Run({"cl.throttle", key, max_burst, count, period}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(0), IntArg(11), IntArg(51))); + + AdvanceTime(30000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + AdvanceTime(1000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(40))); + + AdvanceTime(9000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(1), IntArg(-1), IntArg(41))); + + AdvanceTime(40000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + AdvanceTime(15000); + resp = Run({"cl.throttle", key, max_burst, count, period, "1"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + // Zero-volume request just peeks at the state. + resp = Run({"cl.throttle", key, max_burst, count, period, "0"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(4), IntArg(-1), IntArg(11))); + + // High-volume request uses up more of the limit. + resp = Run({"cl.throttle", key, max_burst, count, period, "2"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(0), IntArg(limit), IntArg(2), IntArg(-1), IntArg(31))); + + // Large requests cannot exceed limits + resp = Run({"cl.throttle", key, max_burst, count, period, "5"}); + ASSERT_THAT(resp.GetVec(), ElementsAre(IntArg(1), IntArg(limit), IntArg(2), IntArg(31), IntArg(31))); +} + } // namespace dfly