Skip to content

Commit

Permalink
wip SessionStatus
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed Jun 14, 2024
1 parent 69e4bbc commit f9dc16b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
31 changes: 30 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from reflex.state import (
BaseState,
RouterData,
SessionStatus,
State,
StateManager,
StateUpdate,
Expand Down Expand Up @@ -1117,6 +1118,29 @@ async def process(
"""
from reflex.utils import telemetry

# Add request data to the state.
router_data = event.router_data
router_data.update(
{
constants.RouteVar.QUERY: format.format_query_params(event.router_data),
constants.RouteVar.CLIENT_TOKEN: event.token,
constants.RouteVar.SESSION_ID: sid,
constants.RouteVar.HEADERS: headers,
constants.RouteVar.CLIENT_IP: client_ip,
}
)
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.token) as state:
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
# dependent ComputedVar (dynamic route variables)
state.router_data = router_data
if state.router:
state.router.update(router_data)
else:
state.router = RouterData(router_data)

try:
# Add request data to the state.
router_data = event.router_data
Expand Down Expand Up @@ -1324,7 +1348,7 @@ def on_connect(self, sid, environ):
"""
pass

def on_disconnect(self, sid):
async def on_disconnect(self, sid):
"""Event for when the websocket disconnects.
Args:
Expand All @@ -1333,6 +1357,11 @@ def on_disconnect(self, sid):
disconnect_token = self.sid_to_token.pop(sid, None)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
else:
return

async with self.app.state_manager.modify_state(disconnect_token) as state:
state.router.session.status = SessionStatus.DISCONNECTED

async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
Expand Down
48 changes: 40 additions & 8 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import asyncio
import contextlib
import copy
import datetime
import enum
import functools
import inspect
import os
Expand Down Expand Up @@ -120,24 +122,46 @@ def __init__(self, router_data: Optional[dict] = None):
self.params = router_data.get(constants.RouteVar.QUERY, {})


class SessionStatus(enum.Enum):
"""The status of the session."""

INITIAL = "initial"
CONNECTED = "connected"
DISCONNECTED = "disconnected"
RECONNECTED = "reconnected"


class SessionData(Base):
"""An object containing session data."""

client_token: str = ""
client_ip: str = ""
session_id: str = ""
status: SessionStatus = SessionStatus.INITIAL
# also represents disconnected_at if status is DISCONNECTED
last_event: datetime.datetime = datetime.datetime.now()

def __init__(self, router_data: Optional[dict] = None):
"""Initalize the SessionData object based on router_data.
def update(self, router_data: Optional[dict] = None):
"""Update the session data based on the router_data.
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
self.last_event = datetime.datetime.now()
if not router_data:
return
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
new_session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
print(
f"current_session_id: {self.session_id}, new_session_id: {new_session_id}"
)
if self.session_id and new_session_id and self.session_id != new_session_id:
self.status = SessionStatus.RECONNECTED
print("Reconnected")
else:
self.status = SessionStatus.CONNECTED
self.session_id = new_session_id


class RouterData(Base):
Expand All @@ -154,7 +178,15 @@ def __init__(self, router_data: Optional[dict] = None):
router_data: the router_data dict.
"""
super().__init__()
self.session = SessionData(router_data)
self.update(router_data)

def update(self, router_data: Optional[dict] = None):
"""Update the router data based on the router_data.
Args:
router_data: the router_data dict.
"""
self.session.update(router_data)
self.headers = HeaderData(router_data)
self.page = PageData(router_data)

Expand Down

0 comments on commit f9dc16b

Please sign in to comment.