diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index a47c32ee07f..c5391880122 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -10,6 +10,7 @@ from .connect_contact_flow_event import ConnectContactFlowEvent from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent +from .event_source import event_source from .kinesis_stream_event import KinesisStreamEvent from .s3_event import S3Event from .ses_event import SESEvent @@ -31,4 +32,5 @@ "SESEvent", "SNSEvent", "SQSEvent", + "event_source", ] diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 73e064d0f26..159779c86a7 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -6,6 +6,7 @@ class ALBEventRequestContext(DictWrapper): @property def elb_target_group_arn(self) -> str: + """Target group arn for your Lambda function""" return self["requestContext"]["elb"]["targetGroupArn"] @@ -15,6 +16,7 @@ class ALBEvent(BaseProxyEvent): Documentation: -------------- - https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html + - https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html """ @property diff --git a/aws_lambda_powertools/utilities/data_classes/event_source.py b/aws_lambda_powertools/utilities/data_classes/event_source.py new file mode 100644 index 00000000000..3968f923573 --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/event_source.py @@ -0,0 +1,39 @@ +from typing import Any, Callable, Dict, Type + +from aws_lambda_powertools.middleware_factory import lambda_handler_decorator +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from 
aws_lambda_powertools.utilities.typing import LambdaContext + + +@lambda_handler_decorator +def event_source( + handler: Callable[[Any, LambdaContext], Any], + event: Dict[str, Any], + context: LambdaContext, + data_class: Type[DictWrapper], +): + """Middleware to create an instance of the passed in event source data class + + Parameters + ---------- + handler: Callable + Lambda's handler + event: Dict + Lambda's Event + context: Dict + Lambda's Context + data_class: Type[DictWrapper] + Data class type to instantiate + + Example + -------- + + **Sample usage** + + from aws_lambda_powertools.utilities.data_classes import S3Event, event_source + + @event_source(data_class=S3Event) + def handler(event: S3Event, context): + return {"key": event.object_key} + """ + return handler(data_class(event), context) diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index 0cbd34213c1..31aef6dc0f2 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -224,6 +224,7 @@ def _generate_hash(self, data: Any) -> str: Hashed representation of the provided data """ + data = getattr(data, "raw_event", data) # could be a data class depending on decorator order hashed_data = self.hash_function(json.dumps(data, cls=Encoder).encode()) return hashed_data.hexdigest() diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index 0fc33d3a3f7..5b0d0db8c0a 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -21,22 +21,35 @@ Lambda function. ### Utilizing the data classes -The classes are initialized by passing in the Lambda event object into the constructor of the appropriate data class. +The classes are initialized by passing in the Lambda event object into the constructor of the appropriate data class or +by using the `event_source` decorator. 
For example, if your Lambda function is being triggered by an API Gateway proxy integration, you can use the `APIGatewayProxyEvent` class. === "app.py" - ```python hl_lines="1 4" - from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent +```python hl_lines="1 4" +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent - def lambda_handler(event, context): - event: APIGatewayProxyEvent = APIGatewayProxyEvent(event) +def lambda_handler(event: dict, context): + event = APIGatewayProxyEvent(event) + if 'helloworld' in event.path and event.http_method == 'GET': + do_something_with(event.body, user) +``` - if 'helloworld' in event.path and event.http_method == 'GET': - do_something_with(event.body, user) - ``` +Same example as above, but using the `event_source` decorator + +=== "app.py" + +```python hl_lines="1 3" +from aws_lambda_powertools.utilities.data_classes import event_source, APIGatewayProxyEvent + +@event_source(data_class=APIGatewayProxyEvent) +def lambda_handler(event: APIGatewayProxyEvent, context): + if 'helloworld' in event.path and event.http_method == 'GET': + do_something_with(event.body, user) +``` **Autocomplete with self-documented properties and methods** @@ -49,7 +62,8 @@ For example, if your Lambda function is being triggered by an API Gateway proxy Event Source | Data_class ------------------------------------------------- | --------------------------------------------------------------------------------- [API Gateway Proxy](#api-gateway-proxy) | `APIGatewayProxyEvent` -[API Gateway Proxy event v2](#api-gateway-proxy-v2) | `APIGatewayProxyEventV2` +[API Gateway Proxy V2](#api-gateway-proxy-v2) | `APIGatewayProxyEventV2` +[Application Load Balancer](#application-load-balancer) | `ALBEvent` [AppSync Resolver](#appsync-resolver) | `AppSyncResolverEvent` [CloudWatch Logs](#cloudwatch-logs) | `CloudWatchLogsEvent` [CodePipeline Job Event](#codepipeline-job) | `CodePipelineJobEvent` @@ -76,34 +90,47 @@ It is 
used for either API Gateway REST API or HTTP API using v1 proxy event. === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source, APIGatewayProxyEvent - def lambda_handler(event, context): - event: APIGatewayProxyEvent = APIGatewayProxyEvent(event) +@event_source(data_class=APIGatewayProxyEvent) +def lambda_handler(event: APIGatewayProxyEvent, context): + if "helloworld" in event.path and event.http_method == "GET": request_context = event.request_context identity = request_context.identity + user = identity.user + do_something_with(event.json_body, user) +``` - if 'helloworld' in event.path and event.http_method == 'GET': - user = identity.user - do_something_with(event.body, user) - ``` +### API Gateway Proxy V2 -### API Gateway Proxy v2 +It is used for HTTP API using v2 proxy event. === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEventV2 +```python +from aws_lambda_powertools.utilities.data_classes import event_source, APIGatewayProxyEventV2 - def lambda_handler(event, context): - event: APIGatewayProxyEventV2 = APIGatewayProxyEventV2(event) - request_context = event.request_context - query_string_parameters = event.query_string_parameters +@event_source(data_class=APIGatewayProxyEventV2) +def lambda_handler(event: APIGatewayProxyEventV2, context): + if "helloworld" in event.path and event.http_method == "POST": + do_something_with(event.json_body, event.query_string_parameters) +``` - if 'helloworld' in event.raw_path and request_context.http.method == 'POST': - do_something_with(event.body, query_string_parameters) - ``` +### Application Load Balancer + +Is it used for Application load balancer event. 
+ +=== "app.py" + +```python +from aws_lambda_powertools.utilities.data_classes import event_source, ALBEvent + +@event_source(data_class=ALBEvent) +def lambda_handler(event: ALBEvent, context): + if "helloworld" in event.path and event.http_method == "POST": + do_something_with(event.json_body, event.query_string_parameters) +``` ### AppSync Resolver @@ -210,18 +237,17 @@ decompress and parse json data from the event. === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import CloudWatchLogsEvent - from aws_lambda_powertools.utilities.data_classes.cloud_watch_logs_event import CloudWatchLogsDecodedData +```python +from aws_lambda_powertools.utilities.data_classes import event_source, CloudWatchLogsEvent +from aws_lambda_powertools.utilities.data_classes.cloud_watch_logs_event import CloudWatchLogsDecodedData - def lambda_handler(event, context): - event: CloudWatchLogsEvent = CloudWatchLogsEvent(event) - - decompressed_log: CloudWatchLogsDecodedData = event.parse_logs_data - log_events = decompressed_log.log_events - for event in log_events: - do_something_with(event.timestamp, event.message) - ``` +@event_source(data_class=CloudWatchLogsEvent) +def lambda_handler(event: CloudWatchLogsEvent, context): + decompressed_log: CloudWatchLogsDecodedData = event.parse_logs_data + log_events = decompressed_log.log_events + for event in log_events: + do_something_with(event.timestamp, event.message) +``` ### CodePipeline Job @@ -229,51 +255,50 @@ Data classes and utility functions to help create continuous delivery pipelines === "app.py" - ```python - from aws_lambda_powertools import Logger - from aws_lambda_powertools.utilities.data_classes import CodePipelineJobEvent +```python +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.data_classes import event_source, CodePipelineJobEvent - logger = Logger() +logger = Logger() +@event_source(data_class=CodePipelineJobEvent) +def lambda_handler(event, context): + """The Lambda 
function handler - def lambda_handler(event, context): - """The Lambda function handler - - If a continuing job then checks the CloudFormation stack status - and updates the job accordingly. - - If a new job then kick of an update or creation of the target - CloudFormation stack. - """ - event: CodePipelineJobEvent = CodePipelineJobEvent(event) - - # Extract the Job ID - job_id = event.get_id - - # Extract the params - params: dict = event.decoded_user_parameters - stack = params["stack"] - artifact_name = params["artifact"] - template_file = params["file"] - - try: - if event.data.continuation_token: - # If we're continuing then the create/update has already been triggered - # we just need to check if it has finished. - check_stack_update_status(job_id, stack) - else: - template = event.get_artifact(artifact_name, template_file) - # Kick off a stack update or create - start_update_or_create(job_id, stack, template) - except Exception as e: - # If any other exceptions which we didn't expect are raised - # then fail the job and log the exception message. - logger.exception("Function failed due to exception.") - put_job_failure(job_id, "Function exception: " + str(e)) - - logger.debug("Function complete.") - return "Complete." - ``` + If a continuing job then checks the CloudFormation stack status + and updates the job accordingly. + + If a new job then kick of an update or creation of the target + CloudFormation stack. + """ + + # Extract the Job ID + job_id = event.get_id + + # Extract the params + params: dict = event.decoded_user_parameters + stack = params["stack"] + artifact_name = params["artifact"] + template_file = params["file"] + + try: + if event.data.continuation_token: + # If we're continuing then the create/update has already been triggered + # we just need to check if it has finished. 
+ check_stack_update_status(job_id, stack) + else: + template = event.get_artifact(artifact_name, template_file) + # Kick off a stack update or create + start_update_or_create(job_id, stack, template) + except Exception as e: + # If any other exceptions which we didn't expect are raised + # then fail the job and log the exception message. + logger.exception("Function failed due to exception.") + put_job_failure(job_id, "Function exception: " + str(e)) + + logger.debug("Function complete.") + return "Complete." +``` ### Cognito User Pool @@ -297,15 +322,15 @@ Verify Auth Challenge | `data_classes.cognito_user_pool_event.VerifyAuthChalleng === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import PostConfirmationTriggerEvent +```python +from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import PostConfirmationTriggerEvent - def lambda_handler(event, context): - event: PostConfirmationTriggerEvent = PostConfirmationTriggerEvent(event) +def lambda_handler(event, context): + event: PostConfirmationTriggerEvent = PostConfirmationTriggerEvent(event) - user_attributes = event.request.user_attributes - do_something_with(user_attributes) - ``` + user_attributes = event.request.user_attributes + do_something_with(user_attributes) +``` #### Define Auth Challenge Example @@ -470,17 +495,18 @@ This example is based on the AWS Cognito docs for [Create Auth Challenge Lambda === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import CreateAuthChallengeTriggerEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source +from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import CreateAuthChallengeTriggerEvent - def handler(event: dict, context) -> dict: - event: CreateAuthChallengeTriggerEvent = CreateAuthChallengeTriggerEvent(event) - if event.request.challenge_name == "CUSTOM_CHALLENGE": - 
event.response.public_challenge_parameters = {"captchaUrl": "url/123.jpg"} - event.response.private_challenge_parameters = {"answer": "5"} - event.response.challenge_metadata = "CAPTCHA_CHALLENGE" - return event.raw_event - ``` +@event_source(data_class=CreateAuthChallengeTriggerEvent) +def handler(event: CreateAuthChallengeTriggerEvent, context) -> dict: + if event.request.challenge_name == "CUSTOM_CHALLENGE": + event.response.public_challenge_parameters = {"captchaUrl": "url/123.jpg"} + event.response.private_challenge_parameters = {"answer": "5"} + event.response.challenge_metadata = "CAPTCHA_CHALLENGE" + return event.raw_event +``` #### Verify Auth Challenge Response Example @@ -488,16 +514,17 @@ This example is based on the AWS Cognito docs for [Verify Auth Challenge Respons === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import VerifyAuthChallengeResponseTriggerEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source +from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import VerifyAuthChallengeResponseTriggerEvent - def handler(event: dict, context) -> dict: - event: VerifyAuthChallengeResponseTriggerEvent = VerifyAuthChallengeResponseTriggerEvent(event) - event.response.answer_correct = ( - event.request.private_challenge_parameters.get("answer") == event.request.challenge_answer - ) - return event.raw_event - ``` +@event_source(data_class=VerifyAuthChallengeResponseTriggerEvent) +def handler(event: VerifyAuthChallengeResponseTriggerEvent, context) -> dict: + event.response.answer_correct = ( + event.request.private_challenge_parameters.get("answer") == event.request.challenge_answer + ) + return event.raw_event +``` ### Connect Contact Flow @@ -505,21 +532,21 @@ This example is based on the AWS Cognito docs for [Verify Auth Challenge Respons === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes.connect_contact_flow_event import ( 
- ConnectContactFlowChannel, - ConnectContactFlowEndpointType, - ConnectContactFlowEvent, - ConnectContactFlowInitiationMethod, - ) - - def lambda_handler(event, context): - event: ConnectContactFlowEvent = ConnectContactFlowEvent(event) - assert event.contact_data.attributes == {"Language": "en-US"} - assert event.contact_data.channel == ConnectContactFlowChannel.VOICE - assert event.contact_data.customer_endpoint.endpoint_type == ConnectContactFlowEndpointType.TELEPHONE_NUMBER - assert event.contact_data.initiation_method == ConnectContactFlowInitiationMethod.API - ``` +```python +from aws_lambda_powertools.utilities.data_classes.connect_contact_flow_event import ( + ConnectContactFlowChannel, + ConnectContactFlowEndpointType, + ConnectContactFlowEvent, + ConnectContactFlowInitiationMethod, +) + +def lambda_handler(event, context): + event: ConnectContactFlowEvent = ConnectContactFlowEvent(event) + assert event.contact_data.attributes == {"Language": "en-US"} + assert event.contact_data.channel == ConnectContactFlowChannel.VOICE + assert event.contact_data.customer_endpoint.endpoint_type == ConnectContactFlowEndpointType.TELEPHONE_NUMBER + assert event.contact_data.initiation_method == ConnectContactFlowInitiationMethod.API +``` ### DynamoDB Streams @@ -529,34 +556,34 @@ attributes values (`AttributeValue`), as well as enums for stream view type (`St === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( - DynamoDBStreamEvent, - DynamoDBRecordEventName - ) +```python +from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( + DynamoDBStreamEvent, + DynamoDBRecordEventName +) - def lambda_handler(event, context): - event: DynamoDBStreamEvent = DynamoDBStreamEvent(event) +def lambda_handler(event, context): + event: DynamoDBStreamEvent = DynamoDBStreamEvent(event) - # Multiple records can be delivered in a single event - for record in event.records: - if record.event_name == 
DynamoDBRecordEventName.MODIFY: - do_something_with(record.dynamodb.new_image) - do_something_with(record.dynamodb.old_image) - ``` + # Multiple records can be delivered in a single event + for record in event.records: + if record.event_name == DynamoDBRecordEventName.MODIFY: + do_something_with(record.dynamodb.new_image) + do_something_with(record.dynamodb.old_image) +``` ### EventBridge === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import EventBridgeEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source, EventBridgeEvent - def lambda_handler(event, context): - event: EventBridgeEvent = EventBridgeEvent(event) - do_something_with(event.detail) +@event_source(data_class=EventBridgeEvent) +def lambda_handler(event: EventBridgeEvent, context): + do_something_with(event.detail) - ``` +``` ### Kinesis streams @@ -565,40 +592,40 @@ or plain text, depending on the original payload. === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import KinesisStreamEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source, KinesisStreamEvent - def lambda_handler(event, context): - event: KinesisStreamEvent = KinesisStreamEvent(event) - kinesis_record = next(event.records).kinesis +@event_source(data_class=KinesisStreamEvent) +def lambda_handler(event: KinesisStreamEvent, context): + kinesis_record = next(event.records).kinesis - # if data was delivered as text - data = kinesis_record.data_as_text() + # if data was delivered as text + data = kinesis_record.data_as_text() - # if data was delivered as json - data = kinesis_record.data_as_json() + # if data was delivered as json + data = kinesis_record.data_as_json() - do_something_with(data) - ``` + do_something_with(data) +``` ### S3 === "app.py" - ```python - from urllib.parse import unquote_plus - from aws_lambda_powertools.utilities.data_classes import S3Event +```python +from urllib.parse import unquote_plus +from 
aws_lambda_powertools.utilities.data_classes import event_source, S3Event - def lambda_handler(event, context): - event: S3Event = S3Event(event) - bucket_name = event.bucket_name +@event_source(data_class=S3Event) +def lambda_handler(event: S3Event, context): + bucket_name = event.bucket_name - # Multiple records can be delivered in a single event - for record in event.records: - object_key = unquote_plus(record.s3.get_object.key) + # Multiple records can be delivered in a single event + for record in event.records: + object_key = unquote_plus(record.s3.get_object.key) - do_something_with(f'{bucket_name}/{object_key}') - ``` + do_something_with(f"{bucket_name}/{object_key}") +``` ### S3 Object Lambda @@ -606,84 +633,81 @@ This example is based on the AWS Blog post [Introducing Amazon S3 Object Lambda === "app.py" - ```python hl_lines="5-6 12 14" - import boto3 - import requests +```python hl_lines="5-6 12 14" +import boto3 +import requests - from aws_lambda_powertools import Logger - from aws_lambda_powertools.logging.correlation_paths import S3_OBJECT_LAMBDA - from aws_lambda_powertools.utilities.data_classes.s3_object_event import S3ObjectLambdaEvent +from aws_lambda_powertools import Logger +from aws_lambda_powertools.logging.correlation_paths import S3_OBJECT_LAMBDA +from aws_lambda_powertools.utilities.data_classes.s3_object_event import S3ObjectLambdaEvent - logger = Logger() - session = boto3.Session() - s3 = session.client("s3") +logger = Logger() +session = boto3.Session() +s3 = session.client("s3") - @logger.inject_lambda_context(correlation_id_path=S3_OBJECT_LAMBDA, log_event=True) - def lambda_handler(event, context): - event = S3ObjectLambdaEvent(event) +@logger.inject_lambda_context(correlation_id_path=S3_OBJECT_LAMBDA, log_event=True) +def lambda_handler(event, context): + event = S3ObjectLambdaEvent(event) - # Get object from S3 - response = requests.get(event.input_s3_url) - original_object = response.content.decode("utf-8") + # Get object from S3 
+ response = requests.get(event.input_s3_url) + original_object = response.content.decode("utf-8") - # Make changes to the object about to be returned - transformed_object = original_object.upper() + # Make changes to the object about to be returned + transformed_object = original_object.upper() - # Write object back to S3 Object Lambda - s3.write_get_object_response( - Body=transformed_object, RequestRoute=event.request_route, RequestToken=event.request_token - ) + # Write object back to S3 Object Lambda + s3.write_get_object_response( + Body=transformed_object, RequestRoute=event.request_route, RequestToken=event.request_token + ) - return {"status_code": 200} - ``` + return {"status_code": 200} +``` ### SES === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import SESEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source, SESEvent - def lambda_handler(event, context): - event: SESEvent = SESEvent(event) +@event_source(data_class=SESEvent) +def lambda_handler(event: SESEvent, context): + # Multiple records can be delivered in a single event + for record in event.records: + mail = record.ses.mail + common_headers = mail.common_headers - # Multiple records can be delivered in a single event - for record in event.records: - mail = record.ses.mail - common_headers = mail.common_headers - - do_something_with(common_headers.to, common_headers.subject) - ``` + do_something_with(common_headers.to, common_headers.subject) +``` ### SNS === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import SNSEvent +```python +from aws_lambda_powertools.utilities.data_classes import event_source, SNSEvent - def lambda_handler(event, context): - event: SNSEvent = SNSEvent(event) +@event_source(data_class=SNSEvent) +def lambda_handler(event: SNSEvent, context): + # Multiple records can be delivered in a single event + for record in event.records: + message = record.sns.message + subject = 
record.sns.subject - # Multiple records can be delivered in a single event - for record in event.records: - message = record.sns.message - subject = record.sns.subject - - do_something_with(subject, message) - ``` + do_something_with(subject, message) +``` ### SQS === "app.py" - ```python - from aws_lambda_powertools.utilities.data_classes import SQSEvent - - def lambda_handler(event, context): - event: SQSEvent = SQSEvent(event) +```python +from aws_lambda_powertools.utilities.data_classes import event_source, SQSEvent - # Multiple records can be delivered in a single event - for record in event.records: - do_something_with(record.body) - ``` +@event_source(data_class=SQSEvent) +def lambda_handler(event: SQSEvent, context): + # Multiple records can be delivered in a single event + for record in event.records: + do_something_with(record.body) +``` diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index 25f76af48be..0cf19ab9de0 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -1,4 +1,5 @@ import copy +import hashlib import json import sys from hashlib import md5 @@ -7,6 +8,7 @@ import pytest from botocore import stub +from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEventV2, event_source from aws_lambda_powertools.utilities.idempotency import DynamoDBPersistenceLayer, IdempotencyConfig from aws_lambda_powertools.utilities.idempotency.exceptions import ( IdempotencyAlreadyInProgressError, @@ -19,6 +21,7 @@ from aws_lambda_powertools.utilities.idempotency.idempotency import idempotent from aws_lambda_powertools.utilities.idempotency.persistence.base import BasePersistenceLayer, DataRecord from aws_lambda_powertools.utilities.validation import envelopes, validator +from tests.functional.utils import load_event TABLE_NAME = "TEST_TABLE" @@ -223,7 +226,7 @@ def lambda_handler(event, context): def 
test_idempotent_lambda_first_execution_cached( idempotency_config: IdempotencyConfig, persistence_store: DynamoDBPersistenceLayer, - lambda_apigw_event: DynamoDBPersistenceLayer, + lambda_apigw_event, expected_params_update_item, expected_params_put_item, lambda_response, @@ -845,3 +848,41 @@ def handler(event, context): handler({}, lambda_context) assert "No data found to create a hashed idempotency_key" == e.value.args[0] + + +class MockPersistenceLayer(BasePersistenceLayer): + def __init__(self, expected_idempotency_key: str): + self.expected_idempotency_key = expected_idempotency_key + super(MockPersistenceLayer, self).__init__() + + def _put_record(self, data_record: DataRecord) -> None: + assert data_record.idempotency_key == self.expected_idempotency_key + + def _update_record(self, data_record: DataRecord) -> None: + assert data_record.idempotency_key == self.expected_idempotency_key + + def _get_record(self, idempotency_key) -> DataRecord: + ... + + def _delete_record(self, data_record: DataRecord) -> None: + ... 
+
+
+def test_idempotent_lambda_event_source(lambda_context):
+    # Scenario to validate that we can use the event_source decorator before or after the idempotent decorator
+    mock_event = load_event("apiGatewayProxyV2Event.json")
+    persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest())
+    expected_result = {"message": "Foo"}
+
+    # GIVEN an event_source decorator
+    # AND then an idempotent decorator
+    @event_source(data_class=APIGatewayProxyEventV2)
+    @idempotent(persistence_store=persistence_layer)
+    def lambda_handler(event, _):
+        assert isinstance(event, APIGatewayProxyEventV2)
+        return expected_result
+
+    # WHEN calling the lambda handler
+    result = lambda_handler(mock_event, lambda_context)
+    # THEN we expect the handler to execute successfully
+    assert result == expected_result
diff --git a/tests/functional/test_data_classes.py b/tests/functional/test_data_classes.py
index 07648f84ee9..60dfc591897 100644
--- a/tests/functional/test_data_classes.py
+++ b/tests/functional/test_data_classes.py
@@ -62,6 +62,7 @@
     DynamoDBStreamEvent,
     StreamViewType,
 )
+from aws_lambda_powertools.utilities.data_classes.event_source import event_source
 from aws_lambda_powertools.utilities.data_classes.s3_object_event import S3ObjectLambdaEvent
 from tests.functional.utils import load_event
 
@@ -1237,3 +1238,15 @@ def download_file(bucket: str, key: str, tmp_name: str):
         }
     )
     assert artifact_str == file_contents
+
+
+def test_reflected_types():
+    # GIVEN an event_source decorator
+    @event_source(data_class=APIGatewayProxyEventV2)
+    def lambda_handler(event: APIGatewayProxyEventV2, _):
+        # THEN we expect the event to be of the passed-in data class type
+        assert isinstance(event, APIGatewayProxyEventV2)
+        assert event.get_header_value("x-foo") == "Foo"
+
+    # WHEN calling the lambda handler
+    lambda_handler({"headers": {"X-Foo": "Foo"}}, None)