diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 391b1e4a2c2..b6e9cd4698b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,7 +1,9 @@ import base64 import json import logging +import os import re +import traceback import zlib from enum import Enum from http import HTTPStatus @@ -9,6 +11,8 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import ServiceError +from aws_lambda_powertools.shared import constants +from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -28,43 +32,46 @@ class ProxyEventType(Enum): class CORSConfig(object): """CORS Config - Examples -------- Simple cors example using the default permissive cors, not this should only be used during early prototyping - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver + ```python + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver - app = ApiGatewayResolver() + app = ApiGatewayResolver() - @app.get("/my/path", cors=True) - def with_cors(): - return {"message": "Foo"} + @app.get("/my/path", cors=True) + def with_cors(): + return {"message": "Foo"} + ``` Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors` do not include any cors headers. - from aws_lambda_powertools.event_handler.api_gateway import ( - ApiGatewayResolver, CORSConfig - ) - - cors_config = CORSConfig( - allow_origin="https://wwww.example.com/", - expose_headers=["x-exposed-response-header"], - allow_headers=["x-custom-request-header"], - max_age=100, - allow_credentials=True, - ) - app = ApiGatewayResolver(cors=cors_config) - - @app.get("/my/path") - def with_cors(): - return {"message": "Foo"} + ```python + from aws_lambda_powertools.event_handler.api_gateway import ( + ApiGatewayResolver, CORSConfig + ) + + cors_config = CORSConfig( + allow_origin="https://wwww.example.com/", + expose_headers=["x-exposed-response-header"], + allow_headers=["x-custom-request-header"], + max_age=100, + allow_credentials=True, + ) + app = ApiGatewayResolver(cors=cors_config) + + @app.get("/my/path") + def with_cors(): + return {"message": "Foo"} - @app.get("/another-one", cors=False) - def without_cors(): - return {"message": "Foo"} + @app.get("/another-one", cors=False) + def without_cors(): + return {"message": "Foo"} + ``` """ _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] @@ -240,7 +247,12 @@ def lambda_handler(event, context): current_event: BaseProxyEvent lambda_context: LambdaContext - def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: CORSConfig = None): + def __init__( + self, + proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, + cors: CORSConfig = None, + debug: Optional[bool] = None, + ): """ Parameters ---------- @@ -248,12 +260,18 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: Proxy request type, defaults to API Gateway V1 cors: CORSConfig Optionally configure and enabled CORS. Not each route will need to have to cors=True + debug: Optional[bool] + Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG" + environment variable """ self._proxy_type = proxy_type self._routes: List[Route] = [] self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} + self._debug = resolve_truthy_env_var_choice( + choice=debug, env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false") + ) def get(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Get route decorator with GET `method` @@ -416,6 +434,8 @@ def resolve(self, event, context) -> Dict[str, Any]: dict Returns the dict response """ + if self._debug: + print(self._json_dump(event)) self.current_event = self._to_proxy_event(event) self.lambda_context = context return self._resolve().build(self.current_event, self._cors) @@ -489,6 +509,19 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: ), route, ) + except Exception: + if self._debug: + # If the user has turned on debug mode, + # we'll let the original exception propagate so + # they get more information about what went wrong. + return ResponseBuilder( + Response( + status_code=500, + content_type=content_types.TEXT_PLAIN, + body="".join(traceback.format_exc()), + ) + ) + raise def _to_response(self, result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response @@ -509,7 +542,9 @@ def _to_response(self, result: Union[Dict, Response]) -> Response: body=self._json_dump(result), ) - @staticmethod - def _json_dump(obj: Any) -> str: - """Does a concise json serialization""" - return json.dumps(obj, separators=(",", ":"), cls=Encoder) + def _json_dump(self, obj: Any) -> str: + """Does a concise json serialization or pretty print when in debug mode""" + if self._debug: + return json.dumps(obj, indent=4, cls=Encoder) + else: + return json.dumps(obj, separators=(",", ":"), cls=Encoder) diff --git a/aws_lambda_powertools/event_handler/content_types.py b/aws_lambda_powertools/event_handler/content_types.py index 00ec3db168e..0f55b1088ad 100644 --- a/aws_lambda_powertools/event_handler/content_types.py +++ b/aws_lambda_powertools/event_handler/content_types.py @@ -1,4 +1,5 @@ # use mimetypes library to be certain, e.g., mimetypes.types_map[".json"] APPLICATION_JSON = "application/json" -PLAIN_TEXT = "text/plain" +TEXT_PLAIN = "text/plain" +TEXT_HTML = "text/html" diff --git a/aws_lambda_powertools/shared/constants.py b/aws_lambda_powertools/shared/constants.py index eaad5640dfd..8388eded654 100644 --- a/aws_lambda_powertools/shared/constants.py +++ b/aws_lambda_powertools/shared/constants.py @@ -10,11 +10,12 @@ METRICS_NAMESPACE_ENV: str = "POWERTOOLS_METRICS_NAMESPACE" +EVENT_HANDLER_DEBUG_ENV: str = "POWERTOOLS_EVENT_HANDLER_DEBUG" + SAM_LOCAL_ENV: str = "AWS_SAM_LOCAL" CHALICE_LOCAL_ENV: str = "AWS_CHALICE_CLI_MODE" SERVICE_NAME_ENV: str = "POWERTOOLS_SERVICE_NAME" XRAY_TRACE_ID_ENV: str = "_X_AMZN_TRACE_ID" - -XRAY_SDK_MODULE = "aws_xray_sdk" -XRAY_SDK_CORE_MODULE = "aws_xray_sdk.core" +XRAY_SDK_MODULE: str = "aws_xray_sdk" +XRAY_SDK_CORE_MODULE: str = "aws_xray_sdk.core" diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index e542483d73e..b39dccc6084 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Dict +import pytest + from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.api_gateway import ( ApiGatewayResolver, @@ -20,6 +22,7 @@ ServiceError, UnauthorizedError, ) +from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from tests.functional.utils import load_event @@ -31,7 +34,6 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") -TEXT_HTML = "text/html" def test_alb_event(): @@ -42,7 +44,7 @@ def test_alb_event(): def foo(): assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} - return Response(200, TEXT_HTML, "foo") + return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler result = app(load_event("albEvent.json"), {}) @@ -50,7 +52,7 @@ def foo(): # THEN process event correctly # AND set the current_event type as ALBEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "foo" @@ -80,7 +82,7 @@ def test_api_gateway(): @app.get("/my/path") def get_lambda() -> Response: assert isinstance(app.current_event, APIGatewayProxyEvent) - return Response(200, TEXT_HTML, "foo") + return Response(200, content_types.TEXT_HTML, "foo") # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -88,7 +90,7 @@ def get_lambda() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "foo" @@ -100,7 +102,7 @@ def test_api_gateway_v2(): def my_path() -> Response: assert isinstance(app.current_event, APIGatewayProxyEventV2) post_data = app.current_event.json_body - return Response(200, content_types.PLAIN_TEXT, post_data["username"]) + return Response(200, content_types.TEXT_PLAIN, post_data["username"]) # WHEN calling the event handler result = app(load_event("apiGatewayProxyV2Event.json"), {}) @@ -108,7 +110,7 @@ def my_path() -> Response: # THEN process event correctly # AND set the current_event type as APIGatewayProxyEventV2 assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == content_types.PLAIN_TEXT + assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN assert result["body"] == "tom" @@ -119,14 +121,14 @@ def test_include_rule_matching(): @app.get("//") def get_lambda(my_id: str, name: str) -> Response: assert name == "my" - return Response(200, TEXT_HTML, my_id) + return Response(200, content_types.TEXT_HTML, my_id) # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "path" @@ -187,11 +189,11 @@ def test_cors(): @app.get("/my/path", cors=True) def with_cors() -> Response: - return Response(200, TEXT_HTML, "test") + return Response(200, content_types.TEXT_HTML, "test") @app.get("/without-cors") def without_cors() -> Response: - return Response(200, TEXT_HTML, "test") + return Response(200, content_types.TEXT_HTML, "test") def handler(event, context): return app.resolve(event, context) @@ -202,7 +204,7 @@ def handler(event, context): # THEN the headers should include cors headers assert "headers" in result headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) @@ -268,7 +270,7 @@ def test_compress_no_accept_encoding(): @app.get("/my/path", compress=True) def return_text() -> Response: - return Response(200, content_types.PLAIN_TEXT, expected_value) + return Response(200, content_types.TEXT_PLAIN, expected_value) # WHEN calling the event handler result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) @@ -284,7 +286,7 @@ def test_cache_control_200(): @app.get("/success", cache_control="max-age=600") def with_cache_control() -> Response: - return Response(200, TEXT_HTML, "has 200 response") + return Response(200, content_types.TEXT_HTML, "has 200 response") def handler(event, context): return app.resolve(event, context) @@ -295,7 +297,7 @@ def handler(event, context): # THEN return the set Cache-Control headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Cache-Control"] == "max-age=600" @@ -305,7 +307,7 @@ def test_cache_control_non_200(): @app.delete("/fails", cache_control="max-age=600") def with_cache_control_has_500() -> Response: - return Response(503, TEXT_HTML, "has 503 response") + return Response(503, content_types.TEXT_HTML, "has 503 response") def handler(event, context): return app.resolve(event, context) @@ -316,7 +318,7 @@ def handler(event, context): # THEN return a Cache-Control of "no-cache" headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert headers["Cache-Control"] == "no-cache" @@ -479,7 +481,7 @@ def test_custom_preflight_response(): def custom_preflight(): return Response( status_code=200, - content_type=TEXT_HTML, + content_type=content_types.TEXT_HTML, body="Foo", headers={"Access-Control-Allow-Methods": "CUSTOM"}, ) @@ -495,7 +497,7 @@ def custom_method(): assert result["statusCode"] == 200 assert result["body"] == "Foo" headers = result["headers"] - assert headers["Content-Type"] == TEXT_HTML + assert headers["Content-Type"] == content_types.TEXT_HTML assert "Access-Control-Allow-Origin" in result["headers"] assert headers["Access-Control-Allow-Methods"] == "CUSTOM" @@ -582,3 +584,83 @@ def service_error(): assert "Access-Control-Allow-Origin" in result["headers"] expected = {"statusCode": 502, "message": "Something went wrong!"} assert result["body"] == json_dump(expected) + + +def test_debug_unhandled_exceptions_debug_on(): + # GIVEN debug is enabled + # AND an unhandled exception is raised + app = ApiGatewayResolver(debug=True) + assert app._debug + + @app.get("/raises-error") + def raises_error(): + raise RuntimeError("Foo") + + # WHEN calling the handler + result = app({"path": "/raises-error", "httpMethod": "GET"}, None) + + # THEN return a 500 + # AND Content-Type is set to text/plain + # AND include the exception traceback in the response + assert result["statusCode"] == 500 + assert "Traceback (most recent call last)" in result["body"] + headers = result["headers"] + assert headers["Content-Type"] == content_types.TEXT_PLAIN + + +def test_debug_unhandled_exceptions_debug_off(): + # GIVEN debug is disabled + # AND an unhandled exception is raised + app = ApiGatewayResolver(debug=False) + assert not app._debug + + @app.get("/raises-error") + def raises_error(): + raise RuntimeError("Foo") + + # WHEN calling the handler + # THEN raise the original exception + with pytest.raises(RuntimeError) as e: + app({"path": "/raises-error", "httpMethod": "GET"}, None) + + # AND include the original error + assert e.value.args == ("Foo",) + + +def test_debug_mode_environment_variable(monkeypatch): + # GIVEN a debug mode environment variable is set + monkeypatch.setenv(constants.EVENT_HANDLER_DEBUG_ENV, "true") + app = ApiGatewayResolver() + + # WHEN calling app._debug + # THEN the debug mode is enabled + assert app._debug + + +def test_debug_json_formatting(): + # GIVEN debug is True + app = ApiGatewayResolver(debug=True) + response = {"message": "Foo"} + + @app.get("/foo") + def foo(): + return response + + # WHEN calling the handler + result = app({"path": "/foo", "httpMethod": "GET"}, None) + + # THEN return a pretty print json in the body + assert result["body"] == json.dumps(response, indent=4) + + +def test_debug_print_event(capsys): + # GIVE debug is True + app = ApiGatewayResolver(debug=True) + + # WHEN calling resolve + event = {"path": "/foo", "httpMethod": "GET"} + app(event, None) + + # THEN print the event + out, err = capsys.readouterr() + assert json.loads(out) == event diff --git a/tests/functional/py.typed b/tests/functional/py.typed new file mode 100644 index 00000000000..e69de29bb2d