From f9dc16b2bff2591e975207d951ed673e966e816a Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 9 Feb 2024 21:00:22 +0100 Subject: [PATCH] wip SessionStatus --- reflex/app.py | 31 ++++++++++++++++++++++++++++++- reflex/state.py | 48 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 5441ce9870d..d9dbc8c8087 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -72,6 +72,7 @@ from reflex.state import ( BaseState, RouterData, + SessionStatus, State, StateManager, StateUpdate, @@ -1117,6 +1118,29 @@ async def process( """ from reflex.utils import telemetry + # Add request data to the state. + router_data = event.router_data + router_data.update( + { + constants.RouteVar.QUERY: format.format_query_params(event.router_data), + constants.RouteVar.CLIENT_TOKEN: event.token, + constants.RouteVar.SESSION_ID: sid, + constants.RouteVar.HEADERS: headers, + constants.RouteVar.CLIENT_IP: client_ip, + } + ) + # Get the state for the session exclusively. + async with app.state_manager.modify_state(event.token) as state: + # re-assign only when the value is different + if state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data + if state.router: + state.router.update(router_data) + else: + state.router = RouterData(router_data) + try: # Add request data to the state. router_data = event.router_data @@ -1324,7 +1348,7 @@ def on_connect(self, sid, environ): """ pass - def on_disconnect(self, sid): + async def on_disconnect(self, sid): """Event for when the websocket disconnects. Args: @@ -1333,6 +1357,11 @@ def on_disconnect(self, sid): disconnect_token = self.sid_to_token.pop(sid, None) if disconnect_token: self.token_to_sid.pop(disconnect_token, None) + else: + return + + async with self.app.state_manager.modify_state(disconnect_token) as state: + state.router.session.status = SessionStatus.DISCONNECTED async def emit_update(self, update: StateUpdate, sid: str) -> None: """Emit an update to the client. diff --git a/reflex/state.py b/reflex/state.py index 56b28f9e802..0cdfb2acd31 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -5,6 +5,8 @@ import asyncio import contextlib import copy +import datetime +import enum import functools import inspect import os @@ -120,24 +122,46 @@ def __init__(self, router_data: Optional[dict] = None): self.params = router_data.get(constants.RouteVar.QUERY, {}) +class SessionStatus(enum.Enum): + """The status of the session.""" + + INITIAL = "initial" + CONNECTED = "connected" + DISCONNECTED = "disconnected" + RECONNECTED = "reconnected" + + class SessionData(Base): """An object containing session data.""" client_token: str = "" client_ip: str = "" session_id: str = "" + status: SessionStatus = SessionStatus.INITIAL + # also represents disconnected_at if status is DISCONNECTED + last_event: datetime.datetime = datetime.datetime.now() - def __init__(self, router_data: Optional[dict] = None): - """Initalize the SessionData object based on router_data. + def update(self, router_data: Optional[dict] = None): + """Update the session data based on the router_data. Args: router_data: the router_data dict. """ - super().__init__() - if router_data: - self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") - self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") - self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + self.last_event = datetime.datetime.now() + if not router_data: + return + self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + new_session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + print( + f"current_session_id: {self.session_id}, new_session_id: {new_session_id}" + ) + if self.session_id and new_session_id and self.session_id != new_session_id: + self.status = SessionStatus.RECONNECTED + print("Reconnected") + else: + self.status = SessionStatus.CONNECTED + self.session_id = new_session_id class RouterData(Base): @@ -154,7 +178,15 @@ def __init__(self, router_data: Optional[dict] = None): router_data: the router_data dict. """ super().__init__() - self.session = SessionData(router_data) + self.update(router_data) + + def update(self, router_data: Optional[dict] = None): + """Update the router data based on the router_data. + + Args: + router_data: the router_data dict. + """ + self.session.update(router_data) self.headers = HeaderData(router_data) self.page = PageData(router_data)