Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api-gateway): add common HTTP service errors #506

Merged
merged 8 commits into from
Jul 6, 2021
73 changes: 66 additions & 7 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,51 @@
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)
APPLICATION_JSON = "application/json"
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved


class ServiceError(Exception):
"""Service Error"""

def __init__(self, status_code: int, msg: str):
"""
Parameters
----------
status_code: int
Http status code
msg: str
Error message
"""
self.status_code = status_code
self.msg = msg


class BadRequestError(ServiceError):
"""Bad Request Error"""

def __init__(self, msg: str):
super().__init__(400, msg)


class UnauthorizedError(ServiceError):
"""Unauthorized Error"""

def __init__(self, msg: str):
super().__init__(401, msg)


class NotFoundError(ServiceError):
"""Not Found Error"""

def __init__(self, msg: str = "Not found"):
super().__init__(404, msg)
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved


class InternalServerError(ServiceError):
"""Internal Server Error"""

def __init__(self, message: str):
super().__init__(500, message)


class ProxyEventType(Enum):
Expand Down Expand Up @@ -467,18 +512,27 @@ def _not_found(self, method: str) -> ResponseBuilder:
return ResponseBuilder(
Response(
status_code=404,
content_type="application/json",
content_type=APPLICATION_JSON,
headers=headers,
body=json.dumps({"message": "Not found"}),
body=self._json_dump({"code": 404, "message": "Not found"}),
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved
)
)

def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
return ResponseBuilder(self._to_response(route.func(**args)), route)
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=APPLICATION_JSON,
body=self._json_dump({"code": e.status_code, "message": e.msg}),
michaelbrewer marked this conversation as resolved.
Show resolved Hide resolved
),
route,
)

@staticmethod
def _to_response(result: Union[Dict, Response]) -> Response:
def _to_response(self, result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response

2 main result types are supported:
Expand All @@ -493,6 +547,11 @@ def _to_response(result: Union[Dict, Response]) -> Response:
logger.debug("Simple response detected, serializing return before constructing final response")
return Response(
status_code=200,
content_type="application/json",
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
content_type=APPLICATION_JSON,
body=self._json_dump(result),
)

@staticmethod
def _json_dump(obj: Any) -> str:
"""Does a concise json serialization"""
return json.dumps(obj, separators=(",", ":"), cls=Encoder)
92 changes: 91 additions & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
from typing import Dict

from aws_lambda_powertools.event_handler.api_gateway import (
APPLICATION_JSON,
ApiGatewayResolver,
BadRequestError,
CORSConfig,
InternalServerError,
NotFoundError,
ProxyEventType,
Response,
ResponseBuilder,
ServiceError,
UnauthorizedError,
)
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
Expand All @@ -24,7 +30,6 @@ def read_media(file_name: str) -> bytes:

LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
TEXT_HTML = "text/html"
APPLICATION_JSON = "application/json"


def test_alb_event():
Expand Down Expand Up @@ -429,6 +434,7 @@ def test_no_matches_with_cors():
# AND cors headers are returned
assert result["statusCode"] == 404
assert "Access-Control-Allow-Origin" in result["headers"]
assert "Not found" in result["body"]


def test_cors_preflight():
Expand Down Expand Up @@ -490,3 +496,87 @@ def custom_method():
assert headers["Content-Type"] == TEXT_HTML
assert "Access-Control-Allow-Origin" in result["headers"]
assert headers["Access-Control-Allow-Methods"] == "CUSTOM"


def test_service_error_responses():
# SCENARIO handling different kind of service errors being raised
app = ApiGatewayResolver(cors=CORSConfig())

def json_dump(obj):
return json.dumps(obj, separators=(",", ":"))

# GIVEN an BadRequestError
@app.get(rule="/bad-request-error", cors=False)
def bad_request_error():
raise BadRequestError("Missing required parameter")

# WHEN calling the handler
# AND path is /bad-request-error
result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None)
# THEN return the bad request error response
# AND status code equals 400
assert result["statusCode"] == 400
assert result["headers"]["Content-Type"] == APPLICATION_JSON
expected = {"code": 400, "message": "Missing required parameter"}
assert result["body"] == json_dump(expected)

# GIVEN an UnauthorizedError
@app.get(rule="/unauthorized-error", cors=False)
def unauthorized_error():
raise UnauthorizedError("Unauthorized")

# WHEN calling the handler
# AND path is /unauthorized-error
result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None)
# THEN return the unauthorized error response
# AND status code equals 401
assert result["statusCode"] == 401
assert result["headers"]["Content-Type"] == APPLICATION_JSON
expected = {"code": 401, "message": "Unauthorized"}
assert result["body"] == json_dump(expected)

# GIVEN an NotFoundError
@app.get(rule="/not-found-error", cors=False)
def not_found_error():
raise NotFoundError

# WHEN calling the handler
# AND path is /not-found-error
result = app({"path": "/not-found-error", "httpMethod": "GET"}, None)
# THEN return the not found error response
# AND status code equals 404
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == APPLICATION_JSON
expected = {"code": 404, "message": "Not found"}
assert result["body"] == json_dump(expected)

# GIVEN an InternalServerError
@app.get(rule="/internal-server-error", cors=False)
def internal_server_error():
raise InternalServerError("Internal server error")

# WHEN calling the handler
# AND path is /internal-server-error
result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None)
# THEN return the internal server error response
# AND status code equals 500
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == APPLICATION_JSON
expected = {"code": 500, "message": "Internal server error"}
assert result["body"] == json_dump(expected)

# GIVEN an ServiceError with a custom status code
@app.get(rule="/service-error", cors=True)
def service_error():
raise ServiceError(502, "Something went wrong!")

# WHEN calling the handler
# AND path is /service-error
result = app({"path": "/service-error", "httpMethod": "GET"}, None)
# THEN return the service error response
# AND status code equals 502
assert result["statusCode"] == 502
assert result["headers"]["Content-Type"] == APPLICATION_JSON
assert "Access-Control-Allow-Origin" in result["headers"]
expected = {"code": 502, "message": "Something went wrong!"}
assert result["body"] == json_dump(expected)