diff --git a/connexion/apps/flask_app.py b/connexion/apps/flask_app.py index be465aeee..5e9daf38f 100644 --- a/connexion/apps/flask_app.py +++ b/connexion/apps/flask_app.py @@ -2,16 +2,16 @@ This module defines a FlaskApp, a Connexion application to wrap a Flask application. """ -import datetime import logging import pathlib -from decimal import Decimal from types import FunctionType # NOQA import a2wsgi import flask import werkzeug.exceptions -from flask import json, signals +from flask import signals + +from connexion import jsonifier from ..apis.flask_api import FlaskApi from ..exceptions import ProblemException @@ -36,7 +36,7 @@ def __init__(self, import_name, server="flask", extra_files=None, **kwargs): def create_app(self): app = flask.Flask(self.import_name, **self.server_args) - app.json_encoder = FlaskJSONEncoder + app.json = FlaskJSONProvider(app) app.url_map.converters["float"] = NumberConverter app.url_map.converters["int"] = IntegerConverter return app @@ -183,24 +183,12 @@ def __call__(self, scope, receive, send): # pragma: no cover return self.middleware(scope, receive, send) -class FlaskJSONEncoder(json.JSONEncoder): +class FlaskJSONProvider(flask.json.provider.DefaultJSONProvider): + """Custom JSONProvider which adds connexion defaults on top of Flask's""" + + @jsonifier.wrap_default def default(self, o): - if isinstance(o, datetime.datetime): - if o.tzinfo: - # eg: '2015-09-25T23:14:42.588601+00:00' - return o.isoformat("T") - else: - # No timezone present - assume UTC. - # eg: '2015-09-25T23:14:42.588601Z' - return o.isoformat("T") + "Z" - - if isinstance(o, datetime.date): - return o.isoformat() - - if isinstance(o, Decimal): - return float(o) - - return json.JSONEncoder.default(self, o) + return super().default(o) class NumberConverter(werkzeug.routing.BaseConverter): diff --git a/connexion/jsonifier.py b/connexion/jsonifier.py index 1407d4892..266a060df 100644 --- a/connexion/jsonifier.py +++ b/connexion/jsonifier.py @@ -3,21 +3,26 @@ """ import datetime +import functools import json +import typing as t import uuid +from decimal import Decimal -class JSONEncoder(json.JSONEncoder): - """The default Connexion JSON encoder. Handles extra types compared to the +def wrap_default(default_fn: t.Callable) -> t.Callable: + """The Connexion defaults for JSON encoding. Handles extra types compared to the built-in :class:`json.JSONEncoder`. - :class:`datetime.datetime` and :class:`datetime.date` are serialized to :rfc:`822` strings. This is the same as the HTTP date format. + - :class:`decimal.Decimal` is serialized to a float. - :class:`uuid.UUID` is serialized to a string. """ - def default(self, o): + @functools.wraps(default_fn) + def wrapped_default(self, o): if isinstance(o, datetime.datetime): if o.tzinfo: # eg: '2015-09-25T23:14:42.588601+00:00' @@ -30,10 +35,25 @@ def default(self, o): if isinstance(o, datetime.date): return o.isoformat() + if isinstance(o, Decimal): + return float(o) + if isinstance(o, uuid.UUID): return str(o) - return json.JSONEncoder.default(self, o) + return default_fn(o) + + return wrapped_default + + +class JSONEncoder(json.JSONEncoder): + """The default Connexion JSON encoder. Handles extra types compared to the + built-in :class:`json.JSONEncoder`. + """ + + @wrap_default + def default(self, o): + return super().default(o) class Jsonifier: @@ -48,6 +68,7 @@ def __init__(self, json_=json, **kwargs): """ self.json = json_ self.dumps_args = kwargs + self.dumps_args.setdefault("cls", JSONEncoder) def dumps(self, data, **kwargs): """Central point where JSON serialization happens inside diff --git a/connexion/middleware/swagger_ui.py b/connexion/middleware/swagger_ui.py index 73c991416..71fb5f75e 100644 --- a/connexion/middleware/swagger_ui.py +++ b/connexion/middleware/swagger_ui.py @@ -12,7 +12,6 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.apis import AbstractSwaggerUIAPI -from connexion.jsonifier import JSONEncoder, Jsonifier from connexion.middleware import AppMiddleware from connexion.utils import yamldumper @@ -207,7 +206,3 @@ async def _get_swagger_ui_config(self, request): media_type="application/json", content=self.jsonifier.dumps(self.options.openapi_console_ui_config), ) - - @classmethod - def _set_jsonifier(cls): - cls.jsonifier = Jsonifier(cls=JSONEncoder) diff --git a/setup.py b/setup.py index 0453feac0..9319835d2 100755 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def read_version(package): 'PyYAML>=5.1,<7', 'requests>=2.27,<3', 'inflection>=0.3.1,<0.6', - 'werkzeug>=2,<3', + 'werkzeug>=2.2.1,<3', 'starlette>=0.15,<1', 'httpx>=0.15,<1', ] @@ -33,7 +33,7 @@ def read_version(package): swagger_ui_require = 'swagger-ui-bundle>=0.0.2,<0.1' flask_require = [ - 'flask>=2,<3', + 'flask>=2.2,<3', 'a2wsgi>=1.4,<2', ] diff --git a/tests/api/test_responses.py b/tests/api/test_responses.py index 485f7d8e1..c4d4fd4c3 100644 --- a/tests/api/test_responses.py +++ b/tests/api/test_responses.py @@ -2,7 +2,7 @@ from struct import unpack import yaml -from connexion.apps.flask_app import FlaskJSONEncoder +from connexion.apps.flask_app import FlaskJSONProvider from werkzeug.test import Client, EnvironBuilder @@ -279,15 +279,15 @@ def test_nested_additional_properties(simple_openapi_app): assert response == {"nested": {"object": True}} -def test_custom_encoder(simple_app): - class CustomEncoder(FlaskJSONEncoder): +def test_custom_provider(simple_app): + class CustomProvider(FlaskJSONProvider): def default(self, o): if o.__class__.__name__ == "DummyClass": return "cool result" - return FlaskJSONEncoder.default(self, o) + return super().default(o) flask_app = simple_app.app - flask_app.json_encoder = CustomEncoder + flask_app.json = CustomProvider(flask_app) app_client = flask_app.test_client() resp = app_client.get("/v1.0/custom-json-response") diff --git a/tests/test_flask_encoder.py b/tests/test_flask_encoder.py index 08e3cc332..bf457ee6a 100644 --- a/tests/test_flask_encoder.py +++ b/tests/test_flask_encoder.py @@ -4,31 +4,33 @@ from decimal import Decimal import pytest -from connexion.apps.flask_app import FlaskJSONEncoder +from connexion.apps.flask_app import FlaskJSONProvider from conftest import build_app_from_fixture SPECS = ["swagger.yaml", "openapi.yaml"] -def test_json_encoder(): - s = json.dumps({1: 2}, cls=FlaskJSONEncoder) +def test_json_encoder(simple_app): + flask_app = simple_app.app + + s = FlaskJSONProvider(flask_app).dumps({1: 2}) assert '{"1": 2}' == s - s = json.dumps(datetime.date.today(), cls=FlaskJSONEncoder) + s = FlaskJSONProvider(flask_app).dumps(datetime.date.today()) assert len(s) == 12 - s = json.dumps(datetime.datetime.utcnow(), cls=FlaskJSONEncoder) + s = FlaskJSONProvider(flask_app).dumps(datetime.datetime.utcnow()) assert s.endswith('Z"') - s = json.dumps(Decimal(1.01), cls=FlaskJSONEncoder) + s = FlaskJSONProvider(flask_app).dumps(Decimal(1.01)) assert s == "1.01" - s = json.dumps(math.expm1(1e-10), cls=FlaskJSONEncoder) + s = FlaskJSONProvider(flask_app).dumps(math.expm1(1e-10)) assert s == "1.00000000005e-10" -def test_json_encoder_datetime_with_timezone(): +def test_json_encoder_datetime_with_timezone(simple_app): class DummyTimezone(datetime.tzinfo): def utcoffset(self, dt): return datetime.timedelta(0) @@ -36,7 +38,8 @@ def utcoffset(self, dt): def dst(self, dt): return datetime.timedelta(0) - s = json.dumps(datetime.datetime.now(DummyTimezone()), cls=FlaskJSONEncoder) + flask_app = simple_app.app + s = FlaskJSONProvider(flask_app).dumps(datetime.datetime.now(DummyTimezone())) assert s.endswith('+00:00"')