Skip to content

Commit

Permalink
feat: add simple ratelimit control
Browse files Browse the repository at this point in the history
  • Loading branch information
hank9999 committed Nov 3, 2023
1 parent 25d26b4 commit faf2c3a
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 3 deletions.
1 change: 1 addition & 0 deletions khl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .cert import Cert
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver
from .requester import HTTPRequester
from .ratelimiter import RateLimiter
from .gateway import Gateway, Requestable
from .client import Client

Expand Down
4 changes: 2 additions & 2 deletions khl/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Callable, List, Optional, Union, Coroutine, IO

from .. import AsyncRunnable # interfaces
from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
from ..command import CommandManager
Expand Down Expand Up @@ -102,7 +102,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
return

# client and gate not in args, build them
_out = out if out else HTTPRequester(cert)
_out = out if out else HTTPRequester(cert, RateLimiter())
if cert.type == Cert.Types.WEBSOCKET:
_in = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
Expand Down
100 changes: 100 additions & 0 deletions khl/ratelimiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import asyncio
import logging
from typing import Dict

log = logging.getLogger(__name__)


class RateLimiter:
"""rate limit control"""

def __init__(self):
self._ratelimit_info: Dict[str, RateLimiter.RateLimitData] = {}
self._api_bucket_mapping: Dict[str, str] = {}
self._lock = asyncio.Lock()

async def push_api_bucket_mapping(self, api: str, bucket: str):
"""
when finished request, associate bucket that api returned with api route
to avoid that bucket and api router are not the same
"""

api = api.lower()
bucket = bucket.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
self._api_bucket_mapping[api] = bucket

async def get_bucket(self, api: str):
"""get bucket name by api route"""

api = api.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
return api

return self._api_bucket_mapping[api]

async def update_ratelimit(self, bucket: str, remaining: int, reset: int):
"""update rate limit info"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
self._ratelimit_info[bucket] = self.RateLimitData(remaining, reset)
else:
self._ratelimit_info[bucket].remaining = remaining
self._ratelimit_info[bucket].reset = reset

async def get_delay(self, bucket: str) -> float:
"""get request delay time, seconds"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
return 0

if self._ratelimit_info[bucket].reset == 0:
return 0

if self._ratelimit_info[bucket].remaining == 0:
return self._ratelimit_info[bucket].reset

delay = self._ratelimit_info[bucket].reset / self._ratelimit_info[bucket].remaining

return delay

async def wait_for_rate(self, route):
"""get and wait delay"""

bucket = await self.get_bucket(route)
delay = await self.get_delay(bucket)
log.debug(f'ratelimiter: {route} req bucket: {bucket} delay: {delay: .3f}s')
await asyncio.sleep(delay)

@staticmethod
def extract_xrate_header(headers):
"""get bucket, remaining, reset values from headers"""

bucket = headers['X-Rate-Limit-Bucket']
remaining = int(headers['X-Rate-Limit-Remaining'])
reset = int(headers['X-Rate-Limit-Reset'])
return bucket, remaining, reset

async def update(self, route, headers):
"""get values and update ratelimit information"""

if 'X-Rate-Limit-Limit' in headers:
bucket, remaining, reset = self.extract_xrate_header(headers)
await self.push_api_bucket_mapping(route, bucket)
await self.update_ratelimit(bucket, remaining, reset)
log.debug(f'ratelimiter: {route} rsp ratelimit: bucket: {bucket} remaining: {remaining} reset: {reset}s')

class RateLimitData:
"""to save single bucket rate limit"""

def __init__(self, remaining: int = 120, reset: int = 0):
self.remaining = remaining
self.reset = reset
12 changes: 11 additions & 1 deletion khl/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from aiohttp import ClientSession

from .ratelimiter import RateLimiter
from .api import _Req
from .cert import Cert

Expand All @@ -15,9 +16,10 @@
class HTTPRequester:
"""wrap raw requests, handle boilerplate param filling works"""

def __init__(self, cert: Cert):
def __init__(self, cert: Cert, ratelimiter: Union[RateLimiter, None]):
self._cert = cert
self._cs: Union[ClientSession, None] = None
self._ratelimiter = ratelimiter

def __del__(self):
if self._cs is not None:
Expand All @@ -29,6 +31,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
params['headers'] = headers

log.debug(f'{method} {route}: req: {params}') # token is excluded

if self._ratelimiter is not None:
await self._ratelimiter.wait_for_rate(route)

headers['Authorization'] = f'Bot {self._cert.token}'
if self._cs is None: # lazy init
self._cs = ClientSession()
Expand All @@ -40,6 +46,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
rsp = rsp['data']
else:
rsp = await res.read()

if self._ratelimiter is not None:
await self._ratelimiter.update(route, res.headers)

log.debug(f'{method} {route}: rsp: {rsp}')
return rsp

Expand Down

0 comments on commit faf2c3a

Please sign in to comment.