Skip to content

Commit

Permalink
feat(api-gateway): add support for custom serializer (#568)
Browse files Browse the repository at this point in the history
Co-authored-by: Heitor Lessa <heitor.lessa@hotmail.com>
  • Loading branch information
Michael Brewer and heitorlessa authored Jul 30, 2021
1 parent dfe42b1 commit 8fcb23b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
15 changes: 10 additions & 5 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
import zlib
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Union

Expand Down Expand Up @@ -263,6 +264,7 @@ def __init__(
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
):
"""
Parameters
Expand All @@ -284,6 +286,13 @@ def __init__(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

if self._debug:
# Always does a pretty print when in debug mode
self._serializer = partial(json.dumps, indent=4, cls=Encoder)

def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
"""Get route decorator with GET `method`
Expand Down Expand Up @@ -592,8 +601,4 @@ def _to_response(self, result: Union[Dict, Response]) -> Response:
)

def _json_dump(self, obj: Any) -> str:
"""Does a concise json serialization or pretty print when in debug mode"""
if self._debug:
return json.dumps(obj, indent=4, cls=Encoder)
else:
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
return self._serializer(obj)
41 changes: 41 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import zlib
from copy import deepcopy
from decimal import Decimal
from enum import Enum
from json import JSONEncoder
from pathlib import Path
from typing import Dict

Expand Down Expand Up @@ -728,3 +730,42 @@ def get_account(account_id: str):

ret = app.resolve(event, None)
assert ret["statusCode"] == 200


def test_custom_serializer():
# GIVEN a custom serializer to handle enums and sets
class CustomEncoder(JSONEncoder):
def default(self, data):
if isinstance(data, Enum):
return data.value
try:
iterable = iter(data)
except TypeError:
pass
else:
return sorted(iterable)
return JSONEncoder.default(self, data)

def custom_serializer(data) -> str:
return json.dumps(data, cls=CustomEncoder)

app = ApiGatewayResolver(serializer=custom_serializer)

class Color(Enum):
RED = 1
BLUE = 2

@app.get("/colors")
def get_color() -> Dict:
return {
"color": Color.RED,
"variations": {"light", "dark"},
}

# WHEN calling handler
response = app({"httpMethod": "GET", "path": "/colors"}, None)

# THEN then use the custom serializer
body = response["body"]
expected = '{"color": 1, "variations": ["dark", "light"]}'
assert expected == body

0 comments on commit 8fcb23b

Please sign in to comment.