From 33f80fd3db0fc316959768a0729d0a6ff1ee74bb Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 6 Jul 2021 00:24:18 -0700 Subject: [PATCH] feat(api-gateway): add common service errors (#506) --- .../event_handler/__init__.py | 3 +- .../event_handler/api_gateway.py | 33 ++++-- .../event_handler/content_types.py | 2 + .../event_handler/exceptions.py | 45 +++++++ .../event_handler/test_api_gateway.py | 110 ++++++++++++++++-- 5 files changed, 175 insertions(+), 18 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/content_types.py create mode 100644 aws_lambda_powertools/event_handler/exceptions.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 0475982e377..def92f706f9 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -2,6 +2,7 @@ Event handler decorators for common Lambda events """ +from .api_gateway import ApiGatewayResolver from .appsync import AppSyncResolver -__all__ = ["AppSyncResolver"] +__all__ = ["AppSyncResolver", "ApiGatewayResolver"] diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 2b1e1fc0900..391b1e4a2c2 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -4,8 +4,11 @@ import re import zlib from enum import Enum +from http import HTTPStatus from typing import Any, Callable, Dict, List, Optional, Set, Union +from aws_lambda_powertools.event_handler import content_types +from aws_lambda_powertools.event_handler.exceptions import ServiceError from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -466,19 +469,28 @@ def _not_found(self, method: str) -> ResponseBuilder: return ResponseBuilder( Response( - status_code=404, - content_type="application/json", + status_code=HTTPStatus.NOT_FOUND.value, + content_type=content_types.APPLICATION_JSON, headers=headers, - body=json.dumps({"message": "Not found"}), + body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}), ) ) def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" - return ResponseBuilder(self._to_response(route.func(**args)), route) + try: + return ResponseBuilder(self._to_response(route.func(**args)), route) + except ServiceError as e: + return ResponseBuilder( + Response( + status_code=e.status_code, + content_type=content_types.APPLICATION_JSON, + body=self._json_dump({"statusCode": e.status_code, "message": e.msg}), + ), + route, + ) - @staticmethod - def _to_response(result: Union[Dict, Response]) -> Response: + def _to_response(self, result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response 2 main result types are supported: @@ -493,6 +505,11 @@ def _to_response(result: Union[Dict, Response]) -> Response: logger.debug("Simple response detected, serializing return before constructing final response") return Response( status_code=200, - content_type="application/json", - body=json.dumps(result, separators=(",", ":"), cls=Encoder), + content_type=content_types.APPLICATION_JSON, + body=self._json_dump(result), ) + + @staticmethod + def _json_dump(obj: Any) -> str: + """Does a concise json serialization""" + return json.dumps(obj, separators=(",", ":"), cls=Encoder) diff --git a/aws_lambda_powertools/event_handler/content_types.py b/aws_lambda_powertools/event_handler/content_types.py new file mode 100644 index 00000000000..33b4acae7cb --- /dev/null +++ b/aws_lambda_powertools/event_handler/content_types.py @@ -0,0 +1,2 @@ +APPLICATION_JSON = "application/json" +PLAIN_TEXT = "plain/text" diff --git a/aws_lambda_powertools/event_handler/exceptions.py b/aws_lambda_powertools/event_handler/exceptions.py new file mode 100644 index 00000000000..56ea3e764d1 --- /dev/null +++ b/aws_lambda_powertools/event_handler/exceptions.py @@ -0,0 +1,45 @@ +from http import HTTPStatus + + +class ServiceError(Exception): + """Service Error""" + + def __init__(self, status_code: int, msg: str): + """ + Parameters + ---------- + status_code: int + Http status code + msg: str + Error message + """ + self.status_code = status_code + self.msg = msg + + +class BadRequestError(ServiceError): + """Bad Request Error""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.BAD_REQUEST, msg) + + +class UnauthorizedError(ServiceError): + """Unauthorized Error""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.UNAUTHORIZED, msg) + + +class NotFoundError(ServiceError): + """Not Found Error""" + + def __init__(self, msg: str = "Not found"): + super().__init__(HTTPStatus.NOT_FOUND, msg) + + +class InternalServerError(ServiceError): + """Internal Server Error""" + + def __init__(self, message: str): + super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index caaaeb1b97b..e542483d73e 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Dict +from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.api_gateway import ( ApiGatewayResolver, CORSConfig, @@ -12,6 +13,13 @@ Response, ResponseBuilder, ) +from aws_lambda_powertools.event_handler.exceptions import ( + BadRequestError, + InternalServerError, + NotFoundError, + ServiceError, + UnauthorizedError, +) from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from tests.functional.utils import load_event @@ -24,7 +32,6 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") TEXT_HTML = "text/html" -APPLICATION_JSON = "application/json" def test_alb_event(): @@ -55,7 +62,7 @@ def test_api_gateway_v1(): def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} - return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"})) + return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"})) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -63,7 +70,7 @@ def get_lambda() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == APPLICATION_JSON + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON def test_api_gateway(): @@ -93,7 +100,7 @@ def test_api_gateway_v2(): def my_path() -> Response: assert isinstance(app.current_event, APIGatewayProxyEventV2) post_data = app.current_event.json_body - return Response(200, "plain/text", post_data["username"]) + return Response(200, content_types.PLAIN_TEXT, post_data["username"]) # WHEN calling the event handler result = app(load_event("apiGatewayProxyV2Event.json"), {}) @@ -101,7 +108,7 @@ def my_path() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEventV2 assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "plain/text" + assert result["headers"]["Content-Type"] == content_types.PLAIN_TEXT assert result["body"] == "tom" @@ -215,7 +222,7 @@ def test_compress(): @app.get("/my/request", compress=True) def with_compression() -> Response: - return Response(200, APPLICATION_JSON, expected_value) + return Response(200, content_types.APPLICATION_JSON, expected_value) def handler(event, context): return app.resolve(event, context) @@ -261,7 +268,7 @@ def test_compress_no_accept_encoding(): @app.get("/my/path", compress=True) def return_text() -> Response: - return Response(200, "text/plain", expected_value) + return Response(200, content_types.PLAIN_TEXT, expected_value) # WHEN calling the event handler result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) @@ -327,7 +334,7 @@ def rest_func() -> Dict: # THEN automatically process this as a json rest api response assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == APPLICATION_JSON + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder) assert result["body"] == expected_str @@ -382,7 +389,7 @@ def another_one(): # THEN routes by default return the custom cors headers assert "headers" in result headers = result["headers"] - assert headers["Content-Type"] == APPLICATION_JSON + assert headers["Content-Type"] == content_types.APPLICATION_JSON assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS))) assert headers["Access-Control-Allow-Headers"] == expected_allows_headers @@ -429,6 +436,7 @@ def test_no_matches_with_cors(): # AND cors headers are returned assert result["statusCode"] == 404 assert "Access-Control-Allow-Origin" in result["headers"] + assert "Not found" in result["body"] def test_cors_preflight(): @@ -490,3 +498,87 @@ def custom_method(): assert headers["Content-Type"] == TEXT_HTML assert "Access-Control-Allow-Origin" in result["headers"] assert headers["Access-Control-Allow-Methods"] == "CUSTOM" + + +def test_service_error_responses(): + # SCENARIO handling different kind of service errors being raised + app = ApiGatewayResolver(cors=CORSConfig()) + + def json_dump(obj): + return json.dumps(obj, separators=(",", ":")) + + # GIVEN an BadRequestError + @app.get(rule="/bad-request-error", cors=False) + def bad_request_error(): + raise BadRequestError("Missing required parameter") + + # WHEN calling the handler + # AND path is /bad-request-error + result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None) + # THEN return the bad request error response + # AND status code equals 400 + assert result["statusCode"] == 400 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 400, "message": "Missing required parameter"} + assert result["body"] == json_dump(expected) + + # GIVEN an UnauthorizedError + @app.get(rule="/unauthorized-error", cors=False) + def unauthorized_error(): + raise UnauthorizedError("Unauthorized") + + # WHEN calling the handler + # AND path is /unauthorized-error + result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None) + # THEN return the unauthorized error response + # AND status code equals 401 + assert result["statusCode"] == 401 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 401, "message": "Unauthorized"} + assert result["body"] == json_dump(expected) + + # GIVEN an NotFoundError + @app.get(rule="/not-found-error", cors=False) + def not_found_error(): + raise NotFoundError + + # WHEN calling the handler + # AND path is /not-found-error + result = app({"path": "/not-found-error", "httpMethod": "GET"}, None) + # THEN return the not found error response + # AND status code equals 404 + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 404, "message": "Not found"} + assert result["body"] == json_dump(expected) + + # GIVEN an InternalServerError + @app.get(rule="/internal-server-error", cors=False) + def internal_server_error(): + raise InternalServerError("Internal server error") + + # WHEN calling the handler + # AND path is /internal-server-error + result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None) + # THEN return the internal server error response + # AND status code equals 500 + assert result["statusCode"] == 500 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + expected = {"statusCode": 500, "message": "Internal server error"} + assert result["body"] == json_dump(expected) + + # GIVEN an ServiceError with a custom status code + @app.get(rule="/service-error", cors=True) + def service_error(): + raise ServiceError(502, "Something went wrong!") + + # WHEN calling the handler + # AND path is /service-error + result = app({"path": "/service-error", "httpMethod": "GET"}, None) + # THEN return the service error response + # AND status code equals 502 + assert result["statusCode"] == 502 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert "Access-Control-Allow-Origin" in result["headers"] + expected = {"statusCode": 502, "message": "Something went wrong!"} + assert result["body"] == json_dump(expected)