diff --git a/khl/__init__.py b/khl/__init__.py index c0ad453..a28dd4c 100644 --- a/khl/__init__.py +++ b/khl/__init__.py @@ -19,6 +19,7 @@ from .cert import Cert from .receiver import Receiver, WebhookReceiver, WebsocketReceiver from .requester import HTTPRequester +from .ratelimiter import RateLimiter from .gateway import Gateway, Requestable from .client import Client diff --git a/khl/bot/bot.py b/khl/bot/bot.py index b5e382f..cc76b1d 100644 --- a/khl/bot/bot.py +++ b/khl/bot/bot.py @@ -6,7 +6,7 @@ from typing import Dict, Callable, List, Optional, Union, Coroutine, IO from .. import AsyncRunnable # interfaces -from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related +from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts from ..command import CommandManager @@ -49,7 +49,8 @@ def __init__(self, out: HTTPRequester = None, compress: bool = True, port=5000, - route='/khl-wh'): + route='/khl-wh', + ratelimiter: Optional[RateLimiter] = RateLimiter(start=80)): """ The most common usage: ``Bot(token='xxxxxx')`` @@ -66,7 +67,7 @@ def __init__(self, if not token and not cert: raise ValueError('require token or cert') - self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route) + self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route, ratelimiter) self._register_client_handler() self.command = CommandManager() @@ -78,7 +79,8 @@ def __init__(self, self._startup_index = [] self._shutdown_index = [] - def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route): + def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route, + ratelimiter): """ construct 
import asyncio
import logging
import time
from typing import Dict

log = logging.getLogger(__name__)


class RateLimiter:
    """rate limit control

    Remembers, per bucket, the last ``X-Rate-Limit-Remaining`` / ``X-Rate-Limit-Reset`` values seen in a
    response, and turns them into a delay to apply before the next request on a route.

    No lock is used: none of the state-touching sections below contains an ``await``, so coroutines on
    one event loop cannot interleave inside them. Not creating an ``asyncio.Lock`` in ``__init__`` also
    keeps the object safe to construct before any event loop exists (e.g. as a default argument).

    @param start: when the remain reach this number, start ratelimit
    """

    def __init__(self, start: int = 120):
        # bucket name (lower-cased) -> last known rate limit state of that bucket
        self._ratelimit_info: Dict[str, 'RateLimiter.RateLimitData'] = {}
        # api route (lower-cased) -> bucket name (lower-cased) reported by the server for that route
        self._api_bucket_mapping: Dict[str, str] = {}
        self._start = start

    async def wait_for_rate(self, route):
        """get and wait delay

        @param route: api route about to be requested
        """

        bucket = await self.get_bucket(route)
        delay = await self.get_delay(bucket)
        log.debug('ratelimiter: %s req bucket: %s delay: % .3fs', route, bucket, delay)
        # always awaited, even for a zero delay, so the caller yields to the event loop once
        await asyncio.sleep(delay)

    async def update(self, route, headers):
        """get values and update ratelimit information

        Does nothing when the response carries no ``X-Rate-Limit-Limit`` header. When the rate limit
        headers are incomplete or not numeric, a warning is logged and the update is skipped, so that
        an already received response is not turned into an error.

        @param route: api route that was requested
        @param headers: response headers, a mapping of header name to value
        """

        if 'X-Rate-Limit-Limit' not in headers:
            return

        try:
            bucket, remaining, reset = self.extract_xrate_header(headers)
        except (KeyError, ValueError) as e:
            log.warning('ratelimiter: %s rsp has malformed ratelimit headers, ignored: %r', route, e)
            return

        await self.push_api_bucket_mapping(route, bucket)
        await self.update_ratelimit(bucket, remaining, reset)
        log.debug('ratelimiter: %s rsp ratelimit: bucket: %s remaining: %s reset: %ss', route, bucket, remaining,
                  reset)

    async def push_api_bucket_mapping(self, api: str, bucket: str):
        """
        when finished request, associate bucket that api returned with api route
        to avoid that bucket and api router are not the same

        The first bucket recorded for a route is kept; later calls for the same route do not replace it.
        """

        self._api_bucket_mapping.setdefault(api.lower(), bucket.lower())

    async def get_bucket(self, api: str):
        """get bucket name by api route, falls back to the lower-cased route itself when unknown"""

        api = api.lower()
        return self._api_bucket_mapping.get(api, api)

    async def update_ratelimit(self, bucket: str, remaining: int, reset: int):
        """update rate limit info

        @param bucket: bucket name, compared case-insensitively
        @param remaining: requests left in the current window
        @param reset: seconds until the window resets, counted from now
        """

        bucket = bucket.lower()
        info = self._ratelimit_info.get(bucket)
        if info is None:
            self._ratelimit_info[bucket] = self.RateLimitData(remaining, reset)
            return

        info.remaining = remaining
        info.reset = reset
        # ``reset`` is relative to the moment it was reported, so remember that moment
        info.updated_at = time.monotonic()

    async def get_delay(self, bucket: str) -> float:
        """get request delay time, seconds

        Returns 0 when nothing is known about the bucket, when the window reported for it has already
        passed, or when more than ``start`` requests remain. Returns the whole rest of the window when
        no request remains. Otherwise spreads the remaining requests evenly over the rest of the window.
        """

        info = self._ratelimit_info.get(bucket.lower())
        if info is None or info.reset == 0:
            return 0.0

        # ``reset`` was "seconds from the update", take off the time that has passed since then
        left = info.reset - (time.monotonic() - info.updated_at)
        if left <= 0:
            return 0.0

        if info.remaining <= 0:
            return left

        if info.remaining > self._start:
            return 0.0

        return left / info.remaining

    @staticmethod
    def extract_xrate_header(headers):
        """get bucket, remaining, reset values from headers

        @raise KeyError: a needed header is missing
        @raise ValueError: remaining or reset is not an integer
        """

        bucket = headers['X-Rate-Limit-Bucket']
        remaining = int(headers['X-Rate-Limit-Remaining'])
        reset = int(headers['X-Rate-Limit-Reset'])
        return bucket, remaining, reset

    class RateLimitData:
        """to save single bucket rate limit"""

        def __init__(self, remaining: int = 120, reset: int = 0):
            self.remaining = remaining
            self.reset = reset
            # monotonic timestamp of the moment ``remaining`` / ``reset`` were recorded
            self.updated_at = time.monotonic()
Cert @@ -15,9 +16,10 @@ class HTTPRequester: """wrap raw requests, handle boilerplate param filling works""" - def __init__(self, cert: Cert): + def __init__(self, cert: Cert, ratelimiter: Optional[RateLimiter]): self._cert = cert self._cs: Union[ClientSession, None] = None + self._ratelimiter = ratelimiter def __del__(self): if self._cs is not None: @@ -29,6 +31,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list, params['headers'] = headers log.debug(f'{method} {route}: req: {params}') # token is excluded + + if self._ratelimiter is not None: + await self._ratelimiter.wait_for_rate(route) + headers['Authorization'] = f'Bot {self._cert.token}' if self._cs is None: # lazy init self._cs = ClientSession() @@ -40,6 +46,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list, rsp = rsp['data'] else: rsp = await res.read() + + if self._ratelimiter is not None: + await self._ratelimiter.update(route, res.headers) + log.debug(f'{method} {route}: rsp: {rsp}') return rsp