Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add simple ratelimit control #217

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions khl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .cert import Cert
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver
from .requester import HTTPRequester
from .ratelimiter import RateLimiter
from .gateway import Gateway, Requestable
from .client import Client

Expand Down
12 changes: 7 additions & 5 deletions khl/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Callable, List, Optional, Union, Coroutine, IO

from .. import AsyncRunnable # interfaces
from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
from ..command import CommandManager
Expand Down Expand Up @@ -49,7 +49,8 @@ def __init__(self,
out: HTTPRequester = None,
compress: bool = True,
port=5000,
route='/khl-wh'):
route='/khl-wh',
ratelimiter: Optional[RateLimiter] = RateLimiter(start=80)):
"""
The most common usage: ``Bot(token='xxxxxx')``
Expand All @@ -66,7 +67,7 @@ def __init__(self,
if not token and not cert:
raise ValueError('require token or cert')

self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route)
self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route, ratelimiter)
self._register_client_handler()

self.command = CommandManager()
Expand All @@ -78,7 +79,8 @@ def __init__(self,
self._startup_index = []
self._shutdown_index = []

def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route):
def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route,
ratelimiter):
"""
construct self.client from args.
Expand All @@ -102,7 +104,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
return

# client and gate not in args, build them
_out = out if out else HTTPRequester(cert)
_out = out if out else HTTPRequester(cert, ratelimiter)
if cert.type == Cert.Types.WEBSOCKET:
_in = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
Expand Down
106 changes: 106 additions & 0 deletions khl/ratelimiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import asyncio
import logging
from typing import Dict

log = logging.getLogger(__name__)


class RateLimiter:
"""rate limit control
@param start: when the remain reach this number, start ratelimit
"""

def __init__(self, start: int = 120):
self._ratelimit_info: Dict[str, RateLimiter.RateLimitData] = {}
self._api_bucket_mapping: Dict[str, str] = {}
self._lock = asyncio.Lock()
self._start = start

async def wait_for_rate(self, route):
"""get and wait delay"""

bucket = await self.get_bucket(route)
delay = await self.get_delay(bucket)
log.debug(f'ratelimiter: {route} req bucket: {bucket} delay: {delay: .3f}s')
await asyncio.sleep(delay)

async def update(self, route, headers):
"""get values and update ratelimit information"""

if 'X-Rate-Limit-Limit' in headers:
bucket, remaining, reset = self.extract_xrate_header(headers)
await self.push_api_bucket_mapping(route, bucket)
await self.update_ratelimit(bucket, remaining, reset)
log.debug(f'ratelimiter: {route} rsp ratelimit: bucket: {bucket} remaining: {remaining} reset: {reset}s')

async def push_api_bucket_mapping(self, api: str, bucket: str):
"""
when finished request, associate bucket that api returned with api route
to avoid that bucket and api router are not the same
"""

api = api.lower()
bucket = bucket.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
self._api_bucket_mapping[api] = bucket

async def get_bucket(self, api: str):
"""get bucket name by api route"""

api = api.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
return api

return self._api_bucket_mapping[api]

async def update_ratelimit(self, bucket: str, remaining: int, reset: int):
"""update rate limit info"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
self._ratelimit_info[bucket] = self.RateLimitData(remaining, reset)
else:
self._ratelimit_info[bucket].remaining = remaining
self._ratelimit_info[bucket].reset = reset

async def get_delay(self, bucket: str) -> float:
"""get request delay time, seconds"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
return 0

if self._ratelimit_info[bucket].reset == 0:
return 0

if self._ratelimit_info[bucket].remaining == 0:
return self._ratelimit_info[bucket].reset

if self._ratelimit_info[bucket].remaining > self._start:
return 0

delay = self._ratelimit_info[bucket].reset / self._ratelimit_info[bucket].remaining

return delay

@staticmethod
def extract_xrate_header(headers):
"""get bucket, remaining, reset values from headers"""

bucket = headers['X-Rate-Limit-Bucket']
remaining = int(headers['X-Rate-Limit-Remaining'])
reset = int(headers['X-Rate-Limit-Reset'])
return bucket, remaining, reset

class RateLimitData:
"""to save single bucket rate limit"""

def __init__(self, remaining: int = 120, reset: int = 0):
self.remaining = remaining
self.reset = reset
14 changes: 12 additions & 2 deletions khl/requester.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import logging
from typing import Union, List
from typing import Union, List, Optional

from aiohttp import ClientSession

from .ratelimiter import RateLimiter
from .api import _Req
from .cert import Cert

Expand All @@ -15,9 +16,10 @@
class HTTPRequester:
"""wrap raw requests, handle boilerplate param filling works"""

def __init__(self, cert: Cert):
def __init__(self, cert: Cert, ratelimiter: Optional[RateLimiter]):
self._cert = cert
self._cs: Union[ClientSession, None] = None
self._ratelimiter = ratelimiter

def __del__(self):
if self._cs is not None:
Expand All @@ -29,6 +31,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
params['headers'] = headers

log.debug(f'{method} {route}: req: {params}') # token is excluded

if self._ratelimiter is not None:
await self._ratelimiter.wait_for_rate(route)

headers['Authorization'] = f'Bot {self._cert.token}'
if self._cs is None: # lazy init
self._cs = ClientSession()
Expand All @@ -40,6 +46,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
rsp = rsp['data']
else:
rsp = await res.read()

if self._ratelimiter is not None:
await self._ratelimiter.update(route, res.headers)

log.debug(f'{method} {route}: rsp: {rsp}')
return rsp

Expand Down
Loading