Skip to content

Commit

Permalink
Merge pull request #6 from theImmortalCoders/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
marcinbator authored Nov 6, 2024
2 parents 710fe35 + 86499cb commit 50ec33a
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 7 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TOKEN_VERIFY_URL=http://localhost:5172/api/user/auth/verify
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ RUN pip install -r requirements.txt

COPY "/" .

CMD ["python", "main.py"]
CMD ["python","-u","main.py"]
18 changes: 15 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.bot import PongBot
from src.enjoy import enjoy
from src.handlers import AiHandler
from src.routes import RoutesHandler
from src.socket import run_socket
from src.wrapper import StateStack

Expand All @@ -21,14 +22,25 @@ def main():
'WebsocketPong_200000_steps.zip'
)
model = DQN.load(path=path, env=env)

launch_socket(env)

env.connection_event.wait()
enjoy(model=model, env=env)


def launch_socket(env):
routes = [
(r"/ws/pong/", AiHandler, dict(env=env)),
(r"/ws/pong-bot/", PongBot)
(r"/ws/pong-bot/", PongBot),
]

print("Routes: ", routes)

routes.append((r"/ws/routes/", RoutesHandler, dict(routes=routes)))

socket_thread = threading.Thread(target=run_socket, args=(8001, routes))
socket_thread.start()
env.connection_event.wait()
enjoy(model=model, env=env)


if __name__ == '__main__':
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ gym==0.21.0
stable-baselines3[extra]==1.5.0
protobuf==3.20.*
numpy==1.26.4
tornado
tornado~=6.4.1
requests~=2.32.3
python-dotenv~=1.0.1
27 changes: 26 additions & 1 deletion src/handlers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
import json
import os
import socket
from urllib.parse import parse_qs, urlparse

import requests
from dotenv import load_dotenv
from tornado.websocket import WebSocketHandler


def verify_jwt(token):
load_dotenv()
TOKEN_VERIFY_URL = os.getenv('TOKEN_VERIFY_URL')
headers = {"Authorization": "Bearer " + token}

try:
response = requests.get(TOKEN_VERIFY_URL, headers=headers, timeout=2)
print(response)
return response.status_code == 200
except requests.RequestException as e:
print(f"JWT verification failed: {e}")
return False


class BaseHandler(WebSocketHandler):

def check_origin(self, origin):
return True

def open(self):
print("WebSocket connection opened")
query_params = parse_qs(urlparse(self.request.uri).query)
token = query_params.get("jwt", [None])[0]

if not token or not verify_jwt(token):
self.close(code=401, reason="Unauthorized")
return
print("WebSocket connection opened and authenticated")

def on_close(self):
print("WebSocket connection closed")
Expand Down
12 changes: 12 additions & 0 deletions src/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import json

from tornado.web import RequestHandler


class RoutesHandler(RequestHandler):
def initialize(self, routes):
self.routes = routes

def get(self):
routes_info = [{"path": route[0], "name": route[1].__name__} for route in self.routes]
self.write(json.dumps(routes_info))
2 changes: 1 addition & 1 deletion src/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def make_app(routes: List[Tuple[str, Type, dict]]) -> Application:
return Application(routes)
return Application(routes, websocket_ping_interval=10, websocket_ping_timeout=30, )


def run_socket(port: int, routes: List[Tuple[str, Type, dict]]) -> None:
Expand Down

0 comments on commit 50ec33a

Please sign in to comment.