FastAPI-Limiter is a rate-limiting tool for FastAPI routes, implemented with a Lua script.
Install it from PyPI:
> pip install fastapi-limiter
FastAPI-Limiter is simple to use: it provides a single dependency, RateLimiter. The following example allows 2 requests per 5 seconds on the route /.
import aioredis
import uvicorn
from fastapi import Depends, FastAPI

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

app = FastAPI()


@app.on_event("startup")
async def startup():
    # aioredis 1.x connection-pool API
    redis = await aioredis.create_redis_pool("redis://localhost")
    await FastAPILimiter.init(redis)


# Allow at most 2 requests per 5 seconds on this route
@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))])
async def index():
    return {"msg": "Hello World"}


if __name__ == "__main__":
    uvicorn.run("main:app", debug=True, reload=True)
FastAPILimiter.init accepts the following configuration:
redis: The aioredis Redis instance.
prefix: The prefix for Redis keys.
identifier: The function that identifies which client a limit applies to. The default is based on the client IP; you can override it to use, for example, a user ID.
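from fastapi import Request


async def default_identifier(request: Request):
    # Behind a proxy, use the first (left-most) address in X-Forwarded-For
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        return forwarded.split(",")[0]
    # Otherwise key on client IP plus path, so each route is limited separately
    return request.client.host + ":" + request.scope["path"]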
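As a sketch of overriding the identifier, the function below keys the limit on a user ID instead of the client IP. It assumes the user ID arrives in a hypothetical `X-User-Id` header (in a real app you would more likely take it from your authentication layer) and falls back to the client host. Since it only reads `request.headers`, `request.client.host` and `request.scope`, it is exercised here with stand-in request objects rather than a running server.

```python
import asyncio
from types import SimpleNamespace


async def user_identifier(request) -> str:
    """Identify the caller by user ID, falling back to the client IP.

    Assumption: "X-User-Id" is a hypothetical header set by an upstream
    auth proxy; it is not something FastAPI-Limiter provides.
    """
    user_id = request.headers.get("X-User-Id")
    if user_id:
        # Keep the path in the key so each route gets its own counter,
        # mirroring the default identifier's fallback branch.
        return f"user:{user_id}:{request.scope['path']}"
    return f"ip:{request.client.host}:{request.scope['path']}"


# Stand-in objects exposing only the attributes the identifier reads.
fake_user = SimpleNamespace(
    headers={"X-User-Id": "42"},
    client=SimpleNamespace(host="10.0.0.1"),
    scope={"path": "/"},
)
fake_anon = SimpleNamespace(
    headers={},
    client=SimpleNamespace(host="10.0.0.1"),
    scope={"path": "/"},
)

print(asyncio.run(user_identifier(fake_user)))  # user:42:/
print(asyncio.run(user_identifier(fake_anon)))  # ip:10.0.0.1:/
```

Assuming FastAPILimiter.init takes the function through its `identifier` keyword, it would be registered as `await FastAPILimiter.init(redis, identifier=user_identifier)`.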
callback: Called when access is forbidden. The default raises an HTTPException with status code 429.
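from math import ceil

from fastapi import HTTPException, Request, Response
from starlette.status import HTTP_429_TOO_MANY_REQUESTS


async def default_callback(request: Request, response: Response, pexpire: int):
    """
    Default callback when there are too many requests.

    :param request: The incoming request
    :param response: The outgoing response
    :param pexpire: The remaining milliseconds until the limit resets
    :return:
    """
    # Round up so the client never retries before the window resets
    expire = ceil(pexpire / 1000)
    raise HTTPException(
        HTTP_429_TOO_MANY_REQUESTS, "Too Many Requests", headers={"Retry-After": str(expire)}
    )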
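The default callback converts the remaining lifetime of the limit key (in milliseconds, as returned by the Lua script's PTTL call) into whole seconds for the Retry-After header, rounding up. The hypothetical helper `retry_after_header` below isolates that arithmetic so the rounding is easy to check; it is not part of FastAPI-Limiter.

```python
from math import ceil


def retry_after_header(pexpire: int) -> dict:
    """Build the Retry-After header the way default_callback does.

    pexpire is the key's remaining lifetime in milliseconds. Rounding up
    means a client that waits Retry-After seconds does not retry early.
    """
    return {"Retry-After": str(ceil(pexpire / 1000))}


for remaining_ms in (1, 999, 1000, 1001, 4500):
    print(remaining_ms, retry_after_header(remaining_ms))
# 1 {'Retry-After': '1'}
# 999 {'Retry-After': '1'}
# 1000 {'Retry-After': '1'}
# 1001 {'Retry-After': '2'}
# 4500 {'Retry-After': '5'}
```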
You can apply multiple limiters to one route.
@app.get(
    "/multiple",
    dependencies=[
        Depends(RateLimiter(times=1, seconds=5)),
        Depends(RateLimiter(times=2, seconds=15)),
    ],
)
async def multiple():
    return {"msg": "Hello World"}
Note that the order of the dependencies matters: put the limiter with the lowest seconds/times ratio first.
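The ordering advice can be illustrated with a small in-memory simulation. It assumes that route dependencies run in list order, that each limiter only increments its counter when it lets a request through (as the Lua script does), and that a rejection stops the remaining limiters from running. `FixedWindow` and `allowed_times` are hypothetical names for this sketch, not FastAPI-Limiter classes. With one request per second for 15 seconds, putting the 1-per-5-seconds limiter first lets requests through at 0 s and 5 s, while the reverse order lets only the first one through: the second request passes the 15-second limiter (using up its quota) and is then rejected by the 5-second one.

```python
from typing import Iterable, List, Tuple


class FixedWindow:
    """In-memory stand-in for one RateLimiter, following the Lua script's rules."""

    def __init__(self, times: int, seconds: int) -> None:
        self.limit = times
        self.window_ms = seconds * 1000
        self.count = 0
        self.expires_at = 0  # ms timestamp at which the counter key expires

    def hit(self, now_ms: int) -> bool:
        if now_ms >= self.expires_at:
            self.count = 0  # key expired: treat it as missing
        if self.count == 0:
            # First request in a window: SET key 1 PX window
            self.count = 1
            self.expires_at = now_ms + self.window_ms
            return True
        if self.count + 1 > self.limit:
            return False  # rejected; the counter is not incremented
        self.count += 1  # INCR, keeping the existing expiry
        return True


def allowed_times(order: List[Tuple[int, int]], request_times: Iterable[int]) -> List[int]:
    """Return the request times (in seconds) that pass every limiter.

    `order` lists (times, seconds) pairs in dependency order. all() stops at
    the first False, modelling how a 429 stops later dependencies from running.
    """
    limiters = [FixedWindow(times, seconds) for times, seconds in order]
    return [t for t in request_times if all(lim.hit(t * 1000) for lim in limiters)]


ticks = range(15)  # one request per second for 15 seconds
print(allowed_times([(1, 5), (2, 15)], ticks))  # [0, 5]: order from the example above
print(allowed_times([(2, 15), (1, 5)], ticks))  # [0]: reversed order
```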
The Lua script used:
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]

local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
    if current + 1 > limit then
        return redis.call("PTTL", key)
    else
        redis.call("INCR", key)
        return 0
    end
else
    redis.call("SET", key, 1, "px", expire_time)
    return 0
end
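To make the script's logic easier to follow, here is a rough Python mirror of it that keeps counters in a dict with an explicit millisecond clock instead of Redis. It is an illustration under those assumptions, not code from FastAPI-Limiter; `rate_limit` and `Store` are hypothetical names, and the real script runs atomically inside Redis.

```python
from typing import Dict, Tuple

# key -> (count, expires_at_ms): a stand-in for Redis keys with a PX expiry
Store = Dict[str, Tuple[int, int]]


def rate_limit(store: Store, key: str, limit: int, expire_ms: int, now_ms: int) -> int:
    """Python mirror of the Lua script, for illustration only.

    Returns 0 if the request is allowed, otherwise the milliseconds left
    until the key expires (the value PTTL would return).
    """
    count, expires_at = store.get(key, (0, 0))
    if now_ms >= expires_at:
        count = 0  # the key has expired, so GET would return nil -> "0"
    if count > 0:
        if count + 1 > limit:
            return expires_at - now_ms  # PTTL
        store[key] = (count + 1, expires_at)  # INCR keeps the existing TTL
        return 0
    store[key] = (1, now_ms + expire_ms)  # SET key 1 PX expire_time
    return 0


# times=2, seconds=5, as on the route "/" in the first example
store: Store = {}
for t in (0, 1000, 2000, 4999, 5000):
    print(t, rate_limit(store, "127.0.0.1:/", 2, 5000, t))
# 0 0
# 1000 0
# 2000 3000
# 4999 1
# 5000 0
```

Because the expiry is set only by the first request and INCR leaves it unchanged, this behaves as a fixed-window counter: the window starts at the first request for a key and resets when the key expires, rather than sliding with each request.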
This project is licensed under the Apache-2.0 License.