-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor admin for better modularity, separation.
- Loading branch information
Showing
14 changed files
with
277 additions
and
224 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# SPDX-FileCopyrightText: 2024 Mark Liffiton <liffiton@gmail.com> | ||
# | ||
# SPDX-License-Identifier: AGPL-3.0-only | ||
|
||
from .base import ( | ||
bp, | ||
register_admin_link, | ||
) | ||
from .main import ( | ||
ChartData, | ||
register_admin_chart, | ||
) | ||
|
||
__all__ = [ | ||
'ChartData', | ||
'bp', | ||
'register_admin_chart', | ||
'register_admin_link', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# SPDX-FileCopyrightText: 2023 Mark Liffiton <liffiton@gmail.com> | ||
# | ||
# SPDX-License-Identifier: AGPL-3.0-only | ||
|
||
import platform | ||
from collections.abc import Callable | ||
from dataclasses import dataclass, field | ||
from datetime import date | ||
from pathlib import Path | ||
from tempfile import NamedTemporaryFile | ||
from typing import ParamSpec, TypeVar | ||
|
||
from flask import ( | ||
Blueprint, | ||
current_app, | ||
send_file, | ||
) | ||
from werkzeug.wrappers.response import Response | ||
|
||
from gened.auth import admin_required | ||
from gened.db import backup_db | ||
|
||
from .consumers import bp as bp_consumers | ||
|
||
bp = Blueprint('admin', __name__, template_folder='templates') | ||
|
||
bp.register_blueprint(bp_consumers, url_prefix='/consumer') | ||
|
||
@bp.before_request | ||
@admin_required | ||
def before_request() -> None: | ||
""" Apply decorator to protect all admin blueprint endpoints. """ | ||
|
||
|
||
@dataclass(frozen=True) | ||
class AdminLink: | ||
"""Represents a link in the admin interface. | ||
Attributes: | ||
endpoint: The Flask endpoint name | ||
display: The text to show in the navigation UI | ||
""" | ||
endpoint: str | ||
display: str | ||
|
||
# For decorator type hints | ||
P = ParamSpec('P') | ||
R = TypeVar('R') | ||
|
||
@dataclass | ||
class AdminLinks: | ||
"""Container for registering admin navigation links.""" | ||
regular: list[AdminLink] = field(default_factory=list) | ||
right: list[AdminLink] = field(default_factory=list) | ||
|
||
def register(self, display_name: str, *, right: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]: | ||
"""Decorator to register an admin page link. | ||
Args: | ||
display_name: Text to show in the admin interface navigation | ||
right: If True, display this link on the right side of the nav bar | ||
""" | ||
def decorator(route_func: Callable[P, R]) -> Callable[P, R]: | ||
handler_name = f"admin.{route_func.__name__}" | ||
link = AdminLink(handler_name, display_name) | ||
if right: | ||
self.right.append(link) | ||
else: | ||
self.regular.append(link) | ||
return route_func | ||
return decorator | ||
|
||
def get_template_context(self) -> dict[str, list[AdminLink]]: | ||
return { | ||
'admin_links': self.regular, | ||
'admin_links_right': self.right, | ||
} | ||
|
||
# Module-level instance | ||
_admin_links = AdminLinks() | ||
register_admin_link = _admin_links.register # name for the decorator to be imported/used in other modules | ||
|
||
@bp.context_processor | ||
def inject_admin_links() -> dict[str, list[AdminLink]]: | ||
return _admin_links.get_template_context() | ||
|
||
|
||
@dataclass(frozen=True) | ||
class DBDownloadStatus: | ||
"""Status of database download encryption.""" | ||
encrypted: bool | ||
reason: str | None = None # reason provided if not encrypted | ||
|
||
@bp.context_processor | ||
def inject_db_download_status() -> dict[str, DBDownloadStatus]: | ||
if platform.system() == "Windows": | ||
status = DBDownloadStatus(False, "Encryption unavailable on Windows servers.") | ||
elif not current_app.config.get('AGE_PUBLIC_KEY'): | ||
status = DBDownloadStatus(False, "No encryption key configured, AGE_PUBLIC_KEY not set.") | ||
else: | ||
status = DBDownloadStatus(True) | ||
return {'db_download_status': status} | ||
|
||
|
||
@register_admin_link("Download DB", right=True) | ||
@bp.route("/get_db") | ||
def get_db_file() -> Response: | ||
db_name = current_app.config['DATABASE_NAME'] | ||
db_basename = Path(db_name).stem | ||
dl_name = f"{db_basename}_{date.today().strftime('%Y%m%d')}.db" | ||
if current_app.config.get('AGE_PUBLIC_KEY'): | ||
dl_name += '.age' | ||
|
||
if platform.system() == "Windows": | ||
# Slightly unsafe way to do it, because the file may be written while | ||
# send_file is sending it. Temp file issues make it hard to do | ||
# otherwise on Windows, though, and no one should run a production | ||
# server for this on Windows, anyway. | ||
if current_app.config.get('AGE_PUBLIC_KEY'): | ||
current_app.logger.warning("Database download on Windows does not support encryption") | ||
return send_file(current_app.config['DATABASE'], | ||
mimetype='application/vnd.sqlite3', | ||
as_attachment=True, download_name=dl_name) | ||
else: | ||
db_backup_file = NamedTemporaryFile() | ||
backup_db(Path(db_backup_file.name)) | ||
return send_file(db_backup_file, | ||
mimetype='application/vnd.sqlite3', | ||
as_attachment=True, download_name=dl_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# SPDX-FileCopyrightText: 2024 Mark Liffiton <liffiton@gmail.com> | ||
# | ||
# SPDX-License-Identifier: AGPL-3.0-only | ||
|
||
from flask import ( | ||
Blueprint, | ||
flash, | ||
redirect, | ||
render_template, | ||
request, | ||
url_for, | ||
) | ||
from werkzeug.wrappers.response import Response | ||
|
||
from gened.db import get_db | ||
from gened.llm import get_models | ||
from gened.lti import reload_consumers | ||
|
||
bp = Blueprint('admin_consumers', __name__, template_folder='templates/admin') | ||
|
||
|
||
@bp.route("/<int:consumer_id>") | ||
def consumer_form(consumer_id: int | None = None) -> str: | ||
db = get_db() | ||
consumer_row = db.execute("SELECT * FROM consumers WHERE id=?", [consumer_id]).fetchone() | ||
return render_template("admin_consumer_form.html", consumer=consumer_row, models=get_models()) | ||
|
||
|
||
@bp.route("/new") | ||
def consumer_new() -> str: | ||
return render_template("admin_consumer_form.html", models=get_models()) | ||
|
||
|
||
@bp.route("/delete/<int:consumer_id>", methods=['POST']) | ||
def consumer_delete(consumer_id: int) -> Response: | ||
db = get_db() | ||
|
||
# Check for dependencies | ||
classes_count = db.execute("SELECT COUNT(*) FROM classes_lti WHERE lti_consumer_id=?", [consumer_id]).fetchone()[0] | ||
|
||
if classes_count > 0: | ||
flash("Cannot delete consumer: there are related classes.", "warning") | ||
return redirect(url_for(".consumer_form", consumer_id=consumer_id)) | ||
|
||
# No dependencies, proceed with deletion | ||
|
||
# Fetch the consumer's name | ||
consumer_name_row = db.execute("SELECT lti_consumer FROM consumers WHERE id=?", [consumer_id]).fetchone() | ||
if not consumer_name_row: | ||
flash("Invalid id.", "danger") | ||
return redirect(url_for(".consumer_form", consumer_id=consumer_id)) | ||
|
||
consumer_name = consumer_name_row['lti_consumer'] | ||
|
||
# Delete the row | ||
db.execute("DELETE FROM consumers WHERE id=?", [consumer_id]) | ||
db.commit() | ||
reload_consumers() | ||
|
||
flash(f"Consumer '{consumer_name}' deleted.") | ||
|
||
return redirect(url_for("admin.main")) | ||
|
||
|
||
@bp.route("/update", methods=['POST']) | ||
def consumer_update() -> Response: | ||
db = get_db() | ||
|
||
consumer_id = request.form.get("consumer_id", type=int) | ||
|
||
if consumer_id is None: | ||
# Adding a new consumer | ||
cur = db.execute("INSERT INTO consumers (lti_consumer, lti_secret, llm_api_key, model_id) VALUES (?, ?, ?, ?)", | ||
[request.form['lti_consumer'], request.form['lti_secret'], request.form['llm_api_key'], request.form['model_id']]) | ||
consumer_id = cur.lastrowid | ||
db.commit() | ||
flash(f"Consumer {request.form['lti_consumer']} created.") | ||
|
||
elif 'clear_lti_secret' in request.form: | ||
db.execute("UPDATE consumers SET lti_secret='' WHERE id=?", [consumer_id]) | ||
db.commit() | ||
flash("Consumer secret cleared.") | ||
|
||
elif 'clear_llm_api_key' in request.form: | ||
db.execute("UPDATE consumers SET llm_api_key='' WHERE id=?", [consumer_id]) | ||
db.commit() | ||
flash("Consumer API key cleared.") | ||
|
||
else: | ||
# Updating | ||
if request.form.get('lti_secret', ''): | ||
db.execute("UPDATE consumers SET lti_secret=? WHERE id=?", [request.form['lti_secret'], consumer_id]) | ||
if request.form.get('llm_api_key', ''): | ||
db.execute("UPDATE consumers SET llm_api_key=? WHERE id=?", [request.form['llm_api_key'], consumer_id]) | ||
if request.form.get('model_id', ''): | ||
db.execute("UPDATE consumers SET model_id=? WHERE id=?", [request.form['model_id'], consumer_id]) | ||
db.commit() | ||
flash("Consumer updated.") | ||
|
||
# anything might have changed: reload all consumers | ||
reload_consumers() | ||
|
||
return redirect(url_for(".consumer_form", consumer_id=consumer_id)) |
Oops, something went wrong.