Skip to content

Commit

Permalink
fix: refactor Flask init code
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Jul 6, 2023
1 parent 7f30c17 commit b211131
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 50 deletions.
16 changes: 0 additions & 16 deletions aw_server/rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import json
import traceback
from datetime import datetime, timedelta
from functools import wraps
from threading import Lock
from typing import Dict

import flask.json.provider
import iso8601
from aw_core import schema
from aw_core.models import Event
Expand Down Expand Up @@ -53,20 +51,6 @@ def decorator(*args, **kwargs):
api = Api(blueprint, doc="/", decorators=[host_header_check])


# TODO: Clean up JSONEncoder code?
# Move to server.py
class CustomJSONProvider(flask.json.provider.DefaultJSONProvider):
def default(self, obj, *args, **kwargs):
try:
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, timedelta):
return obj.total_seconds()
except TypeError:
pass
return super().default(obj)


# Loads event and bucket schema from JSONSchema in aw_core
event = api.schema_model("Event", schema.get_json_schema("event"))
bucket = api.schema_model("Bucket", schema.get_json_schema("bucket"))
Expand Down
76 changes: 44 additions & 32 deletions aw_server/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, List

import aw_datastore
import flask.json.provider
from aw_datastore import Datastore
from flask import (
Blueprint,
Expand All @@ -26,40 +28,50 @@


class AWFlask(Flask):
def __init__(self, name, testing: bool, *args, **kwargs):
self.json_provider_class = rest.CustomJSONProvider

# Only pretty-print JSON if in testing mode (because of performance)
def __init__(
self,
host: str,
testing: bool,
storage_method=None,
cors_origins=[],
custom_static=dict(),
*args,
**kwargs
):
name = "aw-server"
self.json_provider_class = CustomJSONProvider
# only prettyprint JSON if testing (due to perf)
self.json_provider_class.compact = not testing

# Initialize Flask
Flask.__init__(self, name, *args, **kwargs)

# Is set on later initialization
self.api: ServerAPI = None # type: ignore


def create_app(
host: str, testing=True, storage_method=None, cors_origins=[], custom_static=dict()
) -> AWFlask:
app = AWFlask("aw-server", testing, static_folder=static_folder, static_url_path="")

with app.app_context():
_config_cors(cors_origins, testing)

app.register_blueprint(root)
app.register_blueprint(rest.blueprint)
app.register_blueprint(get_custom_static_blueprint(custom_static))

if storage_method is None:
storage_method = aw_datastore.get_storage_methods()["memory"]
db = Datastore(storage_method, testing=testing)
app.api = ServerAPI(db=db, testing=testing)

# needed for host-header check
app.config["HOST"] = host

return app
self.config["HOST"] = host # needed for host-header check
with self.app_context():
_config_cors(cors_origins, testing)

# Initialize datastore and API
if storage_method is None:
storage_method = aw_datastore.get_storage_methods()["memory"]
db = Datastore(storage_method, testing=testing)
self.api = ServerAPI(db=db, testing=testing)

self.register_blueprint(root)
self.register_blueprint(rest.blueprint)
self.register_blueprint(get_custom_static_blueprint(custom_static))


class CustomJSONProvider(flask.json.provider.DefaultJSONProvider):
# encoding/decoding of datetime as iso8601 strings
# encoding of timedelta as second floats
def default(self, obj, *args, **kwargs):
try:
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, timedelta):
return obj.total_seconds()
except TypeError:
pass
return super().default(obj)


@root.route("/")
Expand Down Expand Up @@ -105,10 +117,10 @@ def _start(
cors_origins: List[str] = [],
custom_static: Dict[str, str] = dict(),
):
app = create_app(
app = AWFlask(
host,
storage_method=storage_method,
testing=testing,
storage_method=storage_method,
cors_origins=cors_origins,
custom_static=custom_static,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging

import pytest
from aw_server.server import create_app
from aw_server.server import AWFlask

logging.basicConfig(level=logging.WARN)


@pytest.fixture(scope="session")
def app():
return create_app("127.0.0.1", testing=True)
return AWFlask("127.0.0.1", testing=True)


@pytest.fixture(scope="session")
Expand Down

0 comments on commit b211131

Please sign in to comment.