-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
120 lines (101 loc) · 4.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import asyncio
import base64
import io
import uuid

from fastapi import (
    Cookie,
    FastAPI,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
    WebSocketException,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import FileResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK

from src.server import Server
# Application object; the auto-generated /docs and /redoc pages are disabled.
app = FastAPI(docs_url=None, redoc_url=None)

# slowapi rate limiter keyed on the caller's remote address. It is published on
# app.state and paired with slowapi's handler so that a RateLimitExceeded raised
# by any @limiter.limit(...) decorator below is turned into an HTTP response.
limiter = Limiter(key_func=get_remote_address, default_limits=["3/second"])
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared game backend used by every route in this module.
# NOTE(review): 15 * 60 presumably means 15 minutes expressed in seconds —
# confirm against src.server.Server.
server = Server(time_per_prompt=15 * 60)

# Static directories, mounted in this exact order:
# (URL prefix, directory on disk, route name).
for _mount_prefix, _mount_dir, _mount_name in (
    ("/static", "./static/", "static"),
    ("/data", "data", "data"),
    ("/media", "media", "media"),
):
    app.mount(_mount_prefix, StaticFiles(directory=_mount_dir), name=_mount_name)

# NOTE(review): a wildcard origin together with allow_credentials=True may let
# any third-party site issue cookie-bearing GET/POST requests against this API.
# If the front end is only ever served from this app ("/"), consider listing
# explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Initialise the game server and launch its background timer.

    Awaits ``server.startup()`` and then schedules ``server.global_timer()``
    as a task on the running event loop.
    """
    await server.startup()
    # The event loop keeps only weak references to tasks, so a task nobody
    # references can be garbage-collected before it finishes. Hold a strong
    # reference on app.state for the lifetime of the application.
    app.state.global_timer_task = asyncio.create_task(server.global_timer())
@app.get("/")
@limiter.limit("3/second")
async def read_root(request: Request):
    """Serve the single-page front end from the static directory.

    ``request`` is not used directly; slowapi's ``limiter.limit`` decorator
    expects the endpoint to accept it so the rate limit can be applied.
    """
    index_page = "./static/index.html"
    return FileResponse(index_page)
@app.get("/init")
@limiter.limit("2/second")
async def initialize_session(request: Request, response: Response):
    """Create a new session, store its id in a cookie and register it.

    A random UUID4 string becomes the ``session_id`` cookie, the server is
    asked to initialise state for it, and the id is echoed in the JSON body.
    """
    new_session = f"{uuid.uuid4()}"
    response.set_cookie(key="session_id", value=new_session)
    await server.init_client(new_session)
    body = {"message": "Session initialized", "session_id": new_session}
    return body
@app.websocket("/clock")
# @limiter.limit("2/second")
async def connect_clock(websocket: WebSocket, session_id: str = Cookie(None)):
    """Push clock, reset flag and player count to a client once per second.

    On each iteration the client is (re-)registered via ``server.add_client``,
    then ``{"time", "reset", "conns"}`` is sent as JSON. When the client goes
    away the loop ends and ``server.remove_connection`` is always called.

    NOTE(review): ``session_id`` is None when no cookie was sent, and that
    None is passed straight to the server — confirm clients always call /init
    before opening this socket.
    """
    await websocket.accept()
    print(f'[INFO] Client {session_id} Connected.')
    try:
        while True:
            await server.add_client(session_id)
            await asyncio.sleep(1)
            clock = await server.fetch_clock()
            conns = await server.player_count()
            reset = bool(await server.redis_conn.exists('reset'))
            await websocket.send_json({"time": clock, "reset": reset, "conns": conns})
    except (
        # Starlette signals a client disconnect with WebSocketDisconnect; the
        # websockets-library exceptions can surface from the server layer.
        WebSocketDisconnect,
        WebSocketException,
        ConnectionClosedOK,
        ConnectionClosedError,
    ):
        print('[INFO] Client Disconnected.')
    finally:
        await server.remove_connection(session_id)
@app.get("/client/status")
@limiter.limit("2/second")
async def check_status(request: Request, session_id: str = Cookie(None)):
    """Tell the client whether it must (re-)initialise its session.

    Returns ``{'needInitialization': True}`` when the ``session_id`` cookie is
    missing or ``server.redis_conn.exists`` does not find it. Otherwise returns
    the session's ``won`` score as an int alongside
    ``'needInitialization': False``.
    """
    # Short-circuits: the existence lookup only runs for a non-empty cookie.
    session_is_live = bool(session_id) and bool(
        await server.redis_conn.exists(session_id)
    )
    if session_is_live:
        client_scores = await server.fetch_client_scores(session_id)
        return JSONResponse(
            content={'won': int(client_scores['won']), 'needInitialization': False}
        )
    return JSONResponse(content={'needInitialization': True})
@app.get("/fetch/contents")
@limiter.limit("2/second")
async def fetch_contents(request: Request, session_id: str = Cookie(None)):
    """Return the caller's masked image, its prompt and the shared story.

    The session is initialised on the fly if ``server.redis_conn`` does not
    know it yet. The image is JPEG-encoded and sent base64-encoded under
    ``"image"``.

    Raises:
        HTTPException: 401 when the request carries no ``session_id`` cookie.
    """
    # Without a cookie, session_id is None and would be handed to the server as
    # a session key, so all cookie-less callers would collide on one key.
    if not session_id:
        raise HTTPException(
            status_code=401, detail="Session not initialized; call /init first"
        )
    if not await server.redis_conn.exists(session_id):
        await server.init_client(session_id)
    image = await server.fetch_masked_image(session_id)
    img_io = io.BytesIO()
    # NOTE(review): JPEG cannot store an alpha channel; if fetch_masked_image
    # can return an RGBA image this save will raise — confirm the image mode.
    image.save(img_io, 'JPEG')
    prompt = await server.fetch_prompt_json(session_id)
    story = await server.fetch_story()
    content = {
        # getvalue() returns the whole buffer regardless of position,
        # so no seek(0) is needed.
        "image": base64.b64encode(img_io.getvalue()).decode(),
        "prompt": prompt,
        "story": story,
    }
    return JSONResponse(content=content)
@app.post("/compute_score")
@limiter.limit("2/second")
async def compute_score(request: Request, session_id: str = Cookie(None)):
    """Score the caller's submitted guesses and return the result as JSON.

    Expects a JSON object body with an ``inputs`` field. That field is passed
    unchanged to ``server.compute_client_scores``. The session is initialised
    on the fly if ``server.redis_conn`` does not know it yet.

    Raises:
        HTTPException: 401 when the request carries no ``session_id`` cookie;
            400 when the body is not valid JSON or has no ``inputs`` field.
    """
    # Without a cookie, session_id is None and would be handed to the server as
    # a session key, so all cookie-less callers would collide on one key.
    if not session_id:
        raise HTTPException(
            status_code=401, detail="Session not initialized; call /init first"
        )
    # Validate the body before touching server state so that a malformed request
    # gets a clear 400 instead of an unhandled 500.
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(data, dict) or 'inputs' not in data:
        raise HTTPException(
            status_code=400, detail="Request body must contain an 'inputs' field"
        )
    if not await server.redis_conn.exists(session_id):
        await server.init_client(session_id)
    scores = await server.compute_client_scores(session_id, data['inputs'])
    return JSONResponse(scores)