Skip to content

Commit

Permalink
feat(api-gateway): add common service errors (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Brewer authored Jul 6, 2021
1 parent 2473480 commit 33f80fd
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 18 deletions.
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Event handler decorators for common Lambda events
"""

from .api_gateway import ApiGatewayResolver
from .appsync import AppSyncResolver

__all__ = ["AppSyncResolver"]
__all__ = ["AppSyncResolver", "ApiGatewayResolver"]
33 changes: 25 additions & 8 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import re
import zlib
from enum import Enum
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
Expand Down Expand Up @@ -466,19 +469,28 @@ def _not_found(self, method: str) -> ResponseBuilder:

return ResponseBuilder(
Response(
status_code=404,
content_type="application/json",
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
body=json.dumps({"message": "Not found"}),
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
)
)

def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
return ResponseBuilder(self._to_response(route.func(**args)), route)
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
),
route,
)

@staticmethod
def _to_response(result: Union[Dict, Response]) -> Response:
def _to_response(self, result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response
2 main result types are supported:
Expand All @@ -493,6 +505,11 @@ def _to_response(result: Union[Dict, Response]) -> Response:
logger.debug("Simple response detected, serializing return before constructing final response")
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
content_type=content_types.APPLICATION_JSON,
body=self._json_dump(result),
)

@staticmethod
def _json_dump(obj: Any) -> str:
"""Does a concise json serialization"""
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/content_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
APPLICATION_JSON = "application/json"
PLAIN_TEXT = "plain/text"
45 changes: 45 additions & 0 deletions aws_lambda_powertools/event_handler/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from http import HTTPStatus


class ServiceError(Exception):
"""Service Error"""

def __init__(self, status_code: int, msg: str):
"""
Parameters
----------
status_code: int
Http status code
msg: str
Error message
"""
self.status_code = status_code
self.msg = msg


class BadRequestError(ServiceError):
"""Bad Request Error"""

def __init__(self, msg: str):
super().__init__(HTTPStatus.BAD_REQUEST, msg)


class UnauthorizedError(ServiceError):
"""Unauthorized Error"""

def __init__(self, msg: str):
super().__init__(HTTPStatus.UNAUTHORIZED, msg)


class NotFoundError(ServiceError):
"""Not Found Error"""

def __init__(self, msg: str = "Not found"):
super().__init__(HTTPStatus.NOT_FOUND, msg)


class InternalServerError(ServiceError):
"""Internal Server Error"""

def __init__(self, message: str):
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message)
110 changes: 101 additions & 9 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
from pathlib import Path
from typing import Dict

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
CORSConfig,
ProxyEventType,
Response,
ResponseBuilder,
)
from aws_lambda_powertools.event_handler.exceptions import (
BadRequestError,
InternalServerError,
NotFoundError,
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
from tests.functional.utils import load_event
Expand All @@ -24,7 +32,6 @@ def read_media(file_name: str) -> bytes:

LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
TEXT_HTML = "text/html"
APPLICATION_JSON = "application/json"


def test_alb_event():
Expand Down Expand Up @@ -55,15 +62,15 @@ def test_api_gateway_v1():
def get_lambda() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEvent)
assert app.lambda_context == {}
return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"}))
return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"}))

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})

# THEN process event correctly
# AND set the current_event type as APIGatewayProxyEvent
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON


def test_api_gateway():
Expand Down Expand Up @@ -93,15 +100,15 @@ def test_api_gateway_v2():
def my_path() -> Response:
assert isinstance(app.current_event, APIGatewayProxyEventV2)
post_data = app.current_event.json_body
return Response(200, "plain/text", post_data["username"])
return Response(200, content_types.PLAIN_TEXT, post_data["username"])

# WHEN calling the event handler
result = app(load_event("apiGatewayProxyV2Event.json"), {})

# THEN process event correctly
# AND set the current_event type as APIGatewayProxyEventV2
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == "plain/text"
assert result["headers"]["Content-Type"] == content_types.PLAIN_TEXT
assert result["body"] == "tom"


Expand Down Expand Up @@ -215,7 +222,7 @@ def test_compress():

@app.get("/my/request", compress=True)
def with_compression() -> Response:
return Response(200, APPLICATION_JSON, expected_value)
return Response(200, content_types.APPLICATION_JSON, expected_value)

def handler(event, context):
return app.resolve(event, context)
Expand Down Expand Up @@ -261,7 +268,7 @@ def test_compress_no_accept_encoding():

@app.get("/my/path", compress=True)
def return_text() -> Response:
return Response(200, "text/plain", expected_value)
return Response(200, content_types.PLAIN_TEXT, expected_value)

# WHEN calling the event handler
result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)
Expand Down Expand Up @@ -327,7 +334,7 @@ def rest_func() -> Dict:

# THEN automatically process this as a json rest api response
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder)
assert result["body"] == expected_str

Expand Down Expand Up @@ -382,7 +389,7 @@ def another_one():
# THEN routes by default return the custom cors headers
assert "headers" in result
headers = result["headers"]
assert headers["Content-Type"] == APPLICATION_JSON
assert headers["Content-Type"] == content_types.APPLICATION_JSON
assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin
expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))
assert headers["Access-Control-Allow-Headers"] == expected_allows_headers
Expand Down Expand Up @@ -429,6 +436,7 @@ def test_no_matches_with_cors():
# AND cors headers are returned
assert result["statusCode"] == 404
assert "Access-Control-Allow-Origin" in result["headers"]
assert "Not found" in result["body"]


def test_cors_preflight():
Expand Down Expand Up @@ -490,3 +498,87 @@ def custom_method():
assert headers["Content-Type"] == TEXT_HTML
assert "Access-Control-Allow-Origin" in result["headers"]
assert headers["Access-Control-Allow-Methods"] == "CUSTOM"


def test_service_error_responses():
# SCENARIO handling different kind of service errors being raised
app = ApiGatewayResolver(cors=CORSConfig())

def json_dump(obj):
return json.dumps(obj, separators=(",", ":"))

# GIVEN an BadRequestError
@app.get(rule="/bad-request-error", cors=False)
def bad_request_error():
raise BadRequestError("Missing required parameter")

# WHEN calling the handler
# AND path is /bad-request-error
result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None)
# THEN return the bad request error response
# AND status code equals 400
assert result["statusCode"] == 400
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 400, "message": "Missing required parameter"}
assert result["body"] == json_dump(expected)

# GIVEN an UnauthorizedError
@app.get(rule="/unauthorized-error", cors=False)
def unauthorized_error():
raise UnauthorizedError("Unauthorized")

# WHEN calling the handler
# AND path is /unauthorized-error
result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None)
# THEN return the unauthorized error response
# AND status code equals 401
assert result["statusCode"] == 401
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 401, "message": "Unauthorized"}
assert result["body"] == json_dump(expected)

# GIVEN an NotFoundError
@app.get(rule="/not-found-error", cors=False)
def not_found_error():
raise NotFoundError

# WHEN calling the handler
# AND path is /not-found-error
result = app({"path": "/not-found-error", "httpMethod": "GET"}, None)
# THEN return the not found error response
# AND status code equals 404
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 404, "message": "Not found"}
assert result["body"] == json_dump(expected)

# GIVEN an InternalServerError
@app.get(rule="/internal-server-error", cors=False)
def internal_server_error():
raise InternalServerError("Internal server error")

# WHEN calling the handler
# AND path is /internal-server-error
result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None)
# THEN return the internal server error response
# AND status code equals 500
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
expected = {"statusCode": 500, "message": "Internal server error"}
assert result["body"] == json_dump(expected)

# GIVEN an ServiceError with a custom status code
@app.get(rule="/service-error", cors=True)
def service_error():
raise ServiceError(502, "Something went wrong!")

# WHEN calling the handler
# AND path is /service-error
result = app({"path": "/service-error", "httpMethod": "GET"}, None)
# THEN return the service error response
# AND status code equals 502
assert result["statusCode"] == 502
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert "Access-Control-Allow-Origin" in result["headers"]
expected = {"statusCode": 502, "message": "Something went wrong!"}
assert result["body"] == json_dump(expected)

0 comments on commit 33f80fd

Please sign in to comment.