Skip to content

Commit

Permalink
SQL cleanup thread or per requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lxstr committed Feb 4, 2024
1 parent 7d7d58c commit 0152993
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
8 changes: 8 additions & 0 deletions src/flask_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def _get_interface(self, app):
SESSION_SQLALCHEMY_BIND_KEY = config.get(
"SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY
)
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)
SESSION_CLEANUP_N_SECONDS = config.get(
"SESSION_CLEANUP_N_SECONDS", Defaults.SESSION_CLEANUP_N_SECONDS
)

common_params = {
"app": app,
Expand Down Expand Up @@ -147,6 +153,8 @@ def _get_interface(self, app):
sequence=SESSION_SQLALCHEMY_SEQUENCE,
schema=SESSION_SQLALCHEMY_SCHEMA,
bind_key=SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds=SESSION_CLEANUP_N_SECONDS,
)
else:
raise RuntimeError(f"Unrecognized value for SESSION_TYPE: {SESSION_TYPE}")
Expand Down
2 changes: 2 additions & 0 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ class Defaults:
SESSION_SQLALCHEMY_SEQUENCE = None
SESSION_SQLALCHEMY_SCHEMA = None
SESSION_SQLALCHEMY_BIND_KEY = None
SESSION_CLEANUP_N_REQUESTS = None
SESSION_CLEANUP_N_SECONDS = None
74 changes: 64 additions & 10 deletions src/flask_session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
except ImportError:
import pickle

import random
from datetime import datetime
from datetime import timedelta as TimeDelta
from threading import Thread
from typing import Any, Optional

from flask import Flask, Request, Response
Expand All @@ -32,9 +34,9 @@ def __bool__(self) -> bool:

def __init__(
self,
initial: dict[str, Any] | None = None,
sid: str | None = None,
permanent: bool | None = None,
initial: Optional[dict[str, Any]] = None,
sid: Optional[str] = None,
permanent: Optional[bool] = None,
):
def on_update(self) -> None:
self.modified = True
Expand Down Expand Up @@ -177,7 +179,7 @@ def open_session(self, app: Flask, request: Request) -> ServerSideSession:
sid = self._generate_sid(self.sid_length)
return self.session_class(sid=sid, permanent=self.permanent)

def _retrieve_session_data(self, store_id: str) -> dict | None:
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
raise NotImplementedError()

def _delete_session(self, store_id: str) -> None:
Expand Down Expand Up @@ -224,7 +226,7 @@ def __init__(
self.redis = redis
super().__init__(app, key_prefix, use_signer, permanent, sid_length)

def _retrieve_session_data(self, store_id: str) -> dict | None:
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (value) from the database
serialized_session_data = self.redis.get(store_id)
if serialized_session_data:
Expand Down Expand Up @@ -315,7 +317,7 @@ def _get_memcache_timeout(self, timeout: int) -> int:
timeout += int(time.time())
return timeout

def _retrieve_session_data(self, store_id: str) -> dict | None:
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (item) from the database
serialized_session_data = self.client.get(store_id)
if serialized_session_data:
Expand Down Expand Up @@ -382,7 +384,7 @@ def __init__(
self.cache = FileSystemCache(cache_dir, threshold=threshold, mode=mode)
super().__init__(app, key_prefix, use_signer, permanent, sid_length)

def _retrieve_session_data(self, store_id: str) -> dict | None:
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (item) from the database
return self.cache.get(store_id)

Expand Down Expand Up @@ -451,7 +453,7 @@ def __init__(

super().__init__(app, key_prefix, use_signer, permanent, sid_length)

def _retrieve_session_data(self, store_id: str) -> dict | None:
def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (document) from the database
document = self.store.find_one({"id": store_id})
if document:
Expand Down Expand Up @@ -515,6 +517,11 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface):
:param sequence: The sequence to use for the primary key if needed.
:param schema: The db schema to use
:param bind_key: The db bind key to use
:param cleanup_n_requests: Delete expired sessions approximately every N requests.
:param cleanup_n_seconds: Delete expired sessions approximately every N seconds.
.. versionadded:: 0.7
The `cleanup_n_requests` and `cleanup_n_seconds` parameters were added.
.. versionadded:: 0.6
The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added.
Expand All @@ -538,16 +545,19 @@ def __init__(
sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE,
schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA,
bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds: Optional[int] = Defaults.SESSION_CLEANUP_N_SECONDS,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy(app)

self.db = db
self.sequence = sequence
self.schema = schema
self.bind_key = bind_key
self.cleanup_n_requests = cleanup_n_requests
self.cleanup_n_seconds = cleanup_n_seconds
super().__init__(app, key_prefix, use_signer, permanent, sid_length)

# Create the Session database model
Expand Down Expand Up @@ -586,7 +596,51 @@ def __repr__(self):

self.sql_session_model = Session

def _retrieve_session_data(self, store_id: str) -> dict | None:
# Start the cleanup thread
if self.cleanup_n_seconds:
self._start_cleanup_thread(cleanup_n_seconds)

def _clean_up(self) -> None:
# Delete expired sessions approximately every N requests
if self.cleanup_n_seconds or (
self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0
):
self.app.logger.info("Deleting expired sessions")
try:
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
).delete(synchronize_session=False)
self.db.session.commit()
except Exception as e:
self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)

def _start_cleanup_thread(self, cleanup_n_seconds: int) -> None:
def cleanup():
with self.app.app_context():
while True:
try:
self.app.logger.info("Deleting expired sessions")
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
).delete(synchronize_session=False)
self.db.session.commit()
except Exception as e:
self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)
# Wait for a specified interval (e.g., 3600 seconds = 1 hour) before the next cleanup
time.sleep(cleanup_n_seconds)

# Create and start the cleanup thread
thread = Thread(target=cleanup, daemon=True)
thread.start()

def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
if self.cleanup_n_requests:
self._clean_up()

# Get the saved session (record) from the database
record = self.sql_session_model.query.filter_by(session_id=store_id).first()

Expand Down

0 comments on commit 0152993

Please sign in to comment.