diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 44d3f2b07de..7bf364695da 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -6,6 +6,7 @@ import traceback import zlib from enum import Enum +from functools import partial from http import HTTPStatus from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -263,6 +264,7 @@ def __init__( proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, + serializer: Optional[Callable[[Dict], str]] = None, ): """ Parameters @@ -284,6 +286,13 @@ def __init__( env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug ) + # Allow for a custom serializer or a concise json serialization + self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + + if self._debug: + # Always does a pretty print when in debug mode + self._serializer = partial(json.dumps, indent=4, cls=Encoder) + def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): """Get route decorator with GET `method` @@ -592,8 +601,4 @@ def _to_response(self, result: Union[Dict, Response]) -> Response: ) def _json_dump(self, obj: Any) -> str: - """Does a concise json serialization or pretty print when in debug mode""" - if self._debug: - return json.dumps(obj, indent=4, cls=Encoder) - else: - return json.dumps(obj, separators=(",", ":"), cls=Encoder) + return self._serializer(obj) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f16086ba634..1272125da8b 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -3,6 +3,8 @@ import zlib from copy import deepcopy from decimal import Decimal +from enum import Enum +from json import JSONEncoder from pathlib import Path from typing import Dict @@ -728,3 +730,42 @@ def get_account(account_id: str): ret = app.resolve(event, None) assert ret["statusCode"] == 200 + + +def test_custom_serializer(): + # GIVEN a custom serializer to handle enums and sets + class CustomEncoder(JSONEncoder): + def default(self, data): + if isinstance(data, Enum): + return data.value + try: + iterable = iter(data) + except TypeError: + pass + else: + return sorted(iterable) + return JSONEncoder.default(self, data) + + def custom_serializer(data) -> str: + return json.dumps(data, cls=CustomEncoder) + + app = ApiGatewayResolver(serializer=custom_serializer) + + class Color(Enum): + RED = 1 + BLUE = 2 + + @app.get("/colors") + def get_color() -> Dict: + return { + "color": Color.RED, + "variations": {"light", "dark"}, + } + + # WHEN calling handler + response = app({"httpMethod": "GET", "path": "/colors"}, None) + + # THEN then use the custom serializer + body = response["body"] + expected = '{"color": 1, "variations": ["dark", "light"]}' + assert expected == body