diff --git a/aws_lambda_decorators/classes.py b/aws_lambda_decorators/classes.py index 99f19f2..5d39312 100644 --- a/aws_lambda_decorators/classes.py +++ b/aws_lambda_decorators/classes.py @@ -1,6 +1,9 @@ -"""All the classes used as parameters for the decorators.""" +from typing import Any, Callable, List, Tuple + from aws_lambda_decorators.decoders import decode from aws_lambda_decorators.utils import is_valid_variable_name +from aws_lambda_decorators.validators import Validator + PATH_DIVIDER = "/" ANNOTATIONS_START = "[" @@ -10,37 +13,38 @@ class ExceptionHandler: """Class mapping a friendly error message to a given Exception.""" - def __init__(self, exception, friendly_message=None, status_code=400): + def __init__(self, exception: Exception, friendly_message: str = None, status_code: int = 400): """ Sets the private variables of the ExceptionHandler object. Args: exception (object|Exception): An exception to be handled. friendly_message (str): Friendly Message to be returned if the exception is caught. + status_code (str): HTTP status code to be returned if the exception is caught. """ self._exception = exception self._friendly_message = friendly_message self._status_code = status_code @property - def friendly_message(self): + def friendly_message(self) -> str: """Getter for the friendly message parameter.""" return self._friendly_message @property - def exception(self): + def exception(self) -> Exception: """Getter for the exception parameter.""" return self._exception @property - def status_code(self): + def status_code(self) -> int: """Getter for the status code parameter.""" return self._status_code class BaseParameter: # noqa: pylint - too-few-public-methods """Parent class of all parameter classes.""" - def __init__(self, var_name): + def __init__(self, var_name: str): """ Set the private variables of the BaseParameter object. @@ -49,7 +53,7 @@ def __init__(self, var_name): """ self._name = var_name - def get_var_name(self): + def get_var_name(self) -> str: """Gets the name of the variable that represents the parameter.""" if self._name and not is_valid_variable_name(self._name): raise SyntaxError(self._name) @@ -59,7 +63,7 @@ def get_var_name(self): class SSMParameter(BaseParameter): """Class used for defining the key and, optionally, the variable name for ssm parameter extraction.""" - def __init__(self, ssm_name, var_name=None): + def __init__(self, ssm_name: str, var_name: str = None): """ Set the private variables of the SSMParameter object. @@ -70,7 +74,7 @@ def __init__(self, ssm_name, var_name=None): self._ssm_name = ssm_name BaseParameter.__init__(self, var_name if var_name else ssm_name) - def get_ssm_name(self): + def get_ssm_name(self) -> str: """Getter for the ssm_name parameter.""" return self._ssm_name @@ -78,7 +82,7 @@ def get_ssm_name(self): class ValidatedParameter: """Class used to encapsulate the validation methods parameter data.""" - def __init__(self, func_param_name=None, validators=None): + def __init__(self, func_param_name: str = None, validators: List[Validator] = None): """ Sets the private variables of the ValidatedParameter object. Args: @@ -91,16 +95,16 @@ def fun(event, context). To extract from context func_param_name has to be "cont self._validators = validators @property - def func_param_name(self): + def func_param_name(self) -> str: """Getter for the func_param_name parameter.""" return self._func_param_name @func_param_name.setter - def func_param_name(self, value): + def func_param_name(self, value: str): """Setter for the func_param_name parameter.""" self._func_param_name = value - def validate(self, value, group_errors): + def validate(self, value: Any, group_errors: bool) -> List[str]: """ Validates a value against the passed in validators @@ -130,7 +134,8 @@ def validate(self, value, group_errors): class Parameter(ValidatedParameter, BaseParameter): """Class used to encapsulate the extract methods parameter data.""" - def __init__(self, path="", func_param_name=None, validators=None, var_name=None, default=None, transform=None): # noqa: pylint - too-many-arguments + def __init__(self, path="", func_param_name: str = None, validators: List[Validator] = None, # noqa: pylint - too-many-arguments + var_name: str = None, default: Any = None, transform: Callable = None): """ Sets the private variables of the Parameter object. @@ -163,11 +168,11 @@ def fun(event, context). To extract from context func_param_name has to be "cont BaseParameter.__init__(self, var_name) @property - def path(self): + def path(self) -> str: """Getter for the path parameter.""" return self._path - def extract_value(self, dict_value): + def extract_value(self, dict_value: dict) -> Any: """ Calculate and decode the value of the variable in the given path. @@ -194,7 +199,7 @@ def extract_value(self, dict_value): return dict_value - def validate_path(self, value, group_errors=False): + def validate_path(self, value: Any, group_errors: bool = False) -> List[str]: """ Validates a value against the passed in validators @@ -213,7 +218,7 @@ def validate_path(self, value, group_errors=False): return {key: errors} if errors else {} @staticmethod - def get_annotations_from_key(key): + def get_annotations_from_key(key: str) -> Tuple[str, str]: """ Extract the key and the encoding type (annotation) from the string. diff --git a/aws_lambda_decorators/decoders.py b/aws_lambda_decorators/decoders.py index 3c1bdd1..a3542f5 100644 --- a/aws_lambda_decorators/decoders.py +++ b/aws_lambda_decorators/decoders.py @@ -3,8 +3,10 @@ import json import logging import sys + import jwt + LOGGER = logging.getLogger() LOGGER.setLevel(logging.INFO) @@ -12,7 +14,7 @@ DECODE_FUNC_MISSING_ERROR = "Missing decode function for annotation: %s" -def decode(annotation, value): +def decode(annotation: str, value: str) -> dict: """ Converts an annotated string to a python dictionary. @@ -47,12 +49,12 @@ def decode(annotation, value): @functools.lru_cache() -def decode_json(value): +def decode_json(value: str) -> dict: """Convert a json to a dictionary.""" return json.loads(value) @functools.lru_cache() -def decode_jwt(value): +def decode_jwt(value: str) -> dict: """Convert a jwt to a dictionary.""" return jwt.decode(value, verify=False) diff --git a/aws_lambda_decorators/decorators.py b/aws_lambda_decorators/decorators.py index 3d5a8f0..0944468 100644 --- a/aws_lambda_decorators/decorators.py +++ b/aws_lambda_decorators/decorators.py @@ -4,9 +4,13 @@ A set of Python decorators to ease the development of AWS lambda functions. """ -import json from http import HTTPStatus +import json +from typing import Callable, List + import boto3 + +from aws_lambda_decorators.classes import BaseParameter, ExceptionHandler, SSMParameter, ValidatedParameter from aws_lambda_decorators.utils import (full_name, all_func_args, find_key_case_insensitive, failure, get_logger, find_websocket_connection_id, get_websocket_endpoint) @@ -30,7 +34,8 @@ UNKNOWN = "Unknown" -def extract_from_event(parameters, group_errors=False, allow_none_defaults=False): +def extract_from_event(parameters: List[BaseParameter], group_errors: bool = False, + allow_none_defaults: bool = False) -> Callable: """ Extracts a set of parameters from the event dictionary in a lambda handler. @@ -51,7 +56,8 @@ def lambda_handler(event, context, my_param=None) return extract(parameters, group_errors, allow_none_defaults) -def extract_from_context(parameters, group_errors=False, allow_none_defaults=False): +def extract_from_context(parameters: List[BaseParameter], group_errors: bool = False, + allow_none_defaults: bool = False): """ Extracts a set of parameters from the context dictionary in a lambda handler. @@ -72,7 +78,7 @@ def lambda_handler(event, context, my_param=None) return extract(parameters, group_errors, allow_none_defaults) -def extract(parameters, group_errors=False, allow_none_defaults=False): +def extract(parameters: List[BaseParameter], group_errors: bool = False, allow_none_defaults: bool = False) -> Callable: """ Extracts a set of parameters from any function parameter passed to an AWS lambda handler. @@ -123,7 +129,7 @@ def wrapper(*args, **kwargs): return decorator -def handle_exceptions(handlers): +def handle_exceptions(handlers: List[ExceptionHandler]) -> Callable: """ Handles exceptions thrown by the wrapped/decorated function. @@ -135,8 +141,8 @@ def lambda_handler(params) Args: handlers (list): A collection of ExceptionHandler type items. """ - def decorator(func): - def wrapper(*args, **kwargs): + def decorator(func: Callable) -> Callable: + def wrapper(*args, **kwargs) -> Callable: try: return func(*args, **kwargs) except tuple(handler.exception for handler in handlers) as ex: # noqa: pylint - catching-non-exception @@ -153,7 +159,7 @@ def wrapper(*args, **kwargs): return decorator -def log(parameters=False, response=False): +def log(parameters: bool = False, response: bool = False) -> Callable: """ Log parameters and/or response of the wrapped/decorated function using logging package @@ -173,7 +179,7 @@ def wrapper(*args, **kwargs): return decorator -def extract_from_ssm(ssm_parameters): +def extract_from_ssm(ssm_parameters: List[SSMParameter]) -> Callable: """ Load given ssm parameters from AWS parameter store to the handler variables. @@ -201,7 +207,7 @@ def wrapper(*args, **kwargs): return decorator -def response_body_as_json(func): +def response_body_as_json(func: Callable) -> Callable: """ Convert the dictionary response of the wrapped/decorated function to a json string literal. @@ -223,7 +229,7 @@ def wrapper(*args, **kwargs): return wrapper -def validate(parameters, group_errors=False): +def validate(parameters: ValidatedParameter, group_errors: bool = False) -> Callable: """ Validates a set of function parameters. @@ -263,7 +269,7 @@ def wrapper(*args, **kwargs): return decorator -def handle_all_exceptions(): +def handle_all_exceptions() -> Callable: """ Handles all exceptions thrown by the wrapped/decorated function. @@ -283,7 +289,8 @@ def wrapper(*args, **kwargs): return decorator -def cors(allow_origin=None, allow_methods=None, allow_headers=None, max_age=None): +def cors(allow_origin: str = None, allow_methods: str = None, allow_headers: str = None, + max_age: int = None) -> Callable: """ Adds CORS headers to the response of the decorated function diff --git a/aws_lambda_decorators/utils.py b/aws_lambda_decorators/utils.py index cefaa29..c25e33a 100644 --- a/aws_lambda_decorators/utils.py +++ b/aws_lambda_decorators/utils.py @@ -1,11 +1,15 @@ """Utility functions.""" + from functools import lru_cache from http import HTTPStatus import inspect import json import keyword import logging +from logging import Logger import os +from typing import Any, Callable, Dict, List +from unicodedata import normalize as normalise import boto3 @@ -13,13 +17,13 @@ LOG_LEVEL = getattr(logging, os.getenv("LOG_LEVEL", "INFO")) -def get_logger(name): +def get_logger(name: str) -> Logger: logger = logging.getLogger(name) logger.setLevel(LOG_LEVEL) return logger -def full_name(class_type): +def full_name(class_type: type) -> str: """ Gets the fully qualified name of a class type. @@ -37,7 +41,7 @@ def full_name(class_type): return f"{module}.{class_type.__class__.__name__}" -def is_type_in_list(item_type, items): +def is_type_in_list(item_type: type, items: list) -> bool: """ Checks if there is an item of a given type in the list of items. @@ -51,7 +55,7 @@ def is_type_in_list(item_type, items): return any(isinstance(item, item_type) for item in items) -def is_valid_variable_name(name): +def is_valid_variable_name(name: str) -> bool: """ Check if the given name is python allowed variable name. @@ -64,7 +68,7 @@ def is_valid_variable_name(name): return name.isidentifier() and not keyword.iskeyword(name) -def all_func_args(func, args, kwargs): +def all_func_args(func: Callable, args: list, kwargs: dict) -> dict: """ Combine arguments and key word arguments to a dictionary. @@ -82,7 +86,7 @@ def all_func_args(func, args, kwargs): return arg_dictionary -def find_key_case_insensitive(key_name, the_dict): +def find_key_case_insensitive(key_name: str, the_dict: Dict[str, Any]) -> str: """ Finds if a dictionary (the_dict) has a string key (key_name) in any string case @@ -94,13 +98,19 @@ def find_key_case_insensitive(key_name, the_dict): The found key name in its original case, if found. Otherwise, returns the searching key name """ + def desensitise(txt: str) -> str: + return normalise("NFC", txt).casefold() + + key_name_lower = desensitise(key_name) + for key in the_dict: - if key.lower() == key_name: + if desensitise(key) == key_name_lower: return key + return key_name -def failure(errors, status_code=HTTPStatus.BAD_REQUEST): +def failure(errors: List[str], status_code: int = HTTPStatus.BAD_REQUEST) -> dict: """ Returns an error to the caller diff --git a/aws_lambda_decorators/validators.py b/aws_lambda_decorators/validators.py index d242764..b16184b 100644 --- a/aws_lambda_decorators/validators.py +++ b/aws_lambda_decorators/validators.py @@ -1,7 +1,10 @@ """Validation rules.""" import datetime import re -from schema import SchemaError +from typing import Any + +from schema import Schema, SchemaError + CURRENCIES = {"LKR", "ETB", "RWF", "NZD", "SBD", "MKD", "NPR", "LAK", "KWD", "INR", "HUF", "AFN", "BTN", "ISK", "MVR", "WST", "MNT", "AZN", "SAR", "JMD", "BIF", "BMD", "CAD", "GEL", "MXN", "BHD", "HKD", "RSD", "PKR", "SLL", @@ -21,7 +24,7 @@ class Validator: # noqa: pylint - too-few-public-methods """Validation rule to check if the given mandatory value exists.""" ERROR_MESSAGE = "Unknown error" - def __init__(self, error_message, condition=None): + def __init__(self, error_message: str, condition: Any = None): """ Validates a parameter @@ -32,7 +35,7 @@ def __init__(self, error_message, condition=None): self._error_message = error_message or self.ERROR_MESSAGE self._condition = condition - def message(self, value=None): # noqa: pylint - unused-argument + def message(self, value: Any = None) -> str: """ Gets the formatted error message for a failed mandatory check @@ -49,7 +52,7 @@ class Mandatory(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if the given mandatory value exists.""" ERROR_MESSAGE = "Missing mandatory value" - def __init__(self, error_message=None): + def __init__(self, error_message: str = None): """ Checks if a parameter has a value @@ -59,7 +62,7 @@ def __init__(self, error_message=None): super().__init__(error_message) @staticmethod - def validate(value=None): + def validate(value: Any = None) -> bool: """ Check if the given mandatory value exists. @@ -73,7 +76,7 @@ class RegexValidator(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a value matches a regular expression.""" ERROR_MESSAGE = "'{value}' does not conform to regular expression '{condition}'" - def __init__(self, regex="", error_message=None): + def __init__(self, regex: str = "", error_message: str = None): """ Compile a regular expression to a regular expression pattern. @@ -84,7 +87,7 @@ def __init__(self, regex="", error_message=None): super().__init__(error_message, regex) self._regexp = re.compile(regex) - def validate(self, value=None): + def validate(self, value: str = None) -> bool: """ Check if a value adheres to the defined regular expression. @@ -101,7 +104,7 @@ class SchemaValidator(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a value matches a regular expression.""" ERROR_MESSAGE = "'{value}' does not validate against schema '{condition}'" - def __init__(self, schema, error_message=None): + def __init__(self, schema: Schema, error_message: str = None): """ Set the schema field. @@ -111,7 +114,7 @@ def __init__(self, schema, error_message=None): """ super().__init__(error_message, schema) - def validate(self, value=None): + def validate(self, value: Any = None) -> bool: """ Check if the object adheres to the defined schema. @@ -131,7 +134,7 @@ class Minimum(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a value is greater than a minimum value.""" ERROR_MESSAGE = "'{value}' is less than minimum value '{condition}'" - def __init__(self, minimum: (float, int), error_message=None): + def __init__(self, minimum: (float, int), error_message: str = None): """ Set the minimum value. @@ -141,7 +144,7 @@ def __init__(self, minimum: (float, int), error_message=None): """ super().__init__(error_message, minimum) - def validate(self, value=None): + def validate(self, value: (float, int) = None) -> bool: # pylint:disable=bad-whitespace """ Check if the value is greater than the minimum. @@ -161,7 +164,7 @@ class Maximum(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a value is less than a maximum value.""" ERROR_MESSAGE = "'{value}' is greater than maximum value '{condition}'" - def __init__(self, maximum: (float, int), error_message=None): + def __init__(self, maximum: (float, int), error_message: str = None): """ Set the maximum value. @@ -171,7 +174,7 @@ def __init__(self, maximum: (float, int), error_message=None): """ super().__init__(error_message, maximum) - def validate(self, value=None): + def validate(self, value: (float, int) = None) -> bool: # pylint:disable=bad-whitespace """ Check if the value is less than the maximum. @@ -191,7 +194,7 @@ class MinLength(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a string is shorter than a minimum length.""" ERROR_MESSAGE = "'{value}' is shorter than minimum length '{condition}'" - def __init__(self, min_length: int, error_message=None): + def __init__(self, min_length: int, error_message: str = None): """ Set the minimum length. @@ -201,7 +204,7 @@ def __init__(self, min_length: int, error_message=None): """ super().__init__(error_message, min_length) - def validate(self, value=None): + def validate(self, value: str = None) -> bool: """ Check if a string is shorter than the minimum length. @@ -218,7 +221,7 @@ class MaxLength(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if a string is longer than a maximum length.""" ERROR_MESSAGE = "'{value}' is longer than maximum length '{condition}'" - def __init__(self, max_length: int, error_message=None): + def __init__(self, max_length: int, error_message: str = None): """ Set the maximum length. @@ -228,7 +231,7 @@ def __init__(self, max_length: int, error_message=None): """ super().__init__(error_message, max_length) - def validate(self, value=None): + def validate(self, value: str = None) -> bool: """ Check if a string is longer than the maximum length. @@ -244,7 +247,7 @@ def validate(self, value=None): class Type(Validator): ERROR_MESSAGE = "'{value}' is not of type '{condition.__name__}'" - def __init__(self, valid_type: type, error_message=None): + def __init__(self, valid_type: type, error_message: str = None): """ Set the valid type. @@ -254,7 +257,7 @@ def __init__(self, valid_type: type, error_message=None): """ super().__init__(error_message, valid_type) - def validate(self, value=None): + def validate(self, value: Any = None) -> bool: """ Check if a value is of the right type @@ -270,7 +273,7 @@ def validate(self, value=None): class EnumValidator(Validator): ERROR_MESSAGE = "'{value}' is not in list '{condition}'" - def __init__(self, *args: list, error_message=None): + def __init__(self, *args: list, error_message: str = None): """ Set the list of valid values. @@ -280,7 +283,7 @@ def __init__(self, *args: list, error_message=None): """ super().__init__(error_message, args) - def validate(self, value=None): + def validate(self, value: Any = None) -> bool: """ Check if a value is in a list of valid values @@ -297,7 +300,7 @@ class NonEmpty(Validator): # noqa: pylint - too-few-public-methods """Validation rule to check if the given value is empty.""" ERROR_MESSAGE = "Value is empty" - def __init__(self, error_message=None): + def __init__(self, error_message: str = None): """ Checks if a parameter has a non empty value @@ -307,7 +310,7 @@ def __init__(self, error_message=None): super().__init__(error_message) @staticmethod - def validate(value=None): + def validate(value: Any = None) -> bool: """ Check if the given value is non empty. @@ -324,7 +327,7 @@ class DateValidator(Validator): """Validation rule to check if a string is a valid date according to some format.""" ERROR_MESSAGE = "'{value}' is not a '{condition}' date" - def __init__(self, date_format: str, error_message=None): + def __init__(self, date_format: str, error_message: str = None): """ Checks if a string is a date with a given format @@ -334,7 +337,7 @@ def __init__(self, date_format: str, error_message=None): """ super().__init__(error_message, date_format) - def validate(self, value=None): + def validate(self, value: str = None) -> bool: """ Check if a string is a date with a given format @@ -356,7 +359,7 @@ class CurrencyValidator(Validator): """Validation rule to check if a string is a valid currency according to ISO 4217 Currency Code.""" ERROR_MESSAGE = "'{value}' is not a valid currency code." - def __init__(self, error_message=None): + def __init__(self, error_message: str = None): """ Checks if a string is a valid currency based on ISO 4217 @@ -366,7 +369,7 @@ def __init__(self, error_message=None): super().__init__(error_message) @staticmethod - def validate(value=None): + def validate(value: str = None) -> bool: """ Check if a string is a valid currency based on ISO 4217 diff --git a/buildspec.yml b/buildspec.yml index 34cbcea..cd52163 100644 --- a/buildspec.yml +++ b/buildspec.yml @@ -4,12 +4,13 @@ phases: install: commands: - pip install --upgrade pip - - pip install -q boto3 bandit coverage==4.5.4 schema pylint_quotes prospector PyJWT==1.7.1 + - pip install -q boto3 bandit coverage==4.5.4 schema pylint_quotes prospector==1.3.1 PyJWT==1.7.1 mypy pre_build: commands: - export LOG_LEVEL=CRITICAL - export OUR_COMMIT_SHA=`git rev-parse HEAD` - bandit -r -q . + - mypy - prospector - coverage run --source='.' -m unittest - coverage report -m --fail-under=100 --omit=*/__init__.py,tests/*,setup.py,examples/* diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..45332da --- /dev/null +++ b/mypy.ini @@ -0,0 +1,11 @@ +[mypy] +python_version = 3.7 +files = aws_lambda_decorators/*.py +warn_return_any = True +warn_unused_configs = True +namespace_packages = True +strict_optional = False +disallow_untyped_calls = True +disallow_untyped_defs = True +ignore_missing_imports = True +explicit_package_bases = True \ No newline at end of file diff --git a/setup.py b/setup.py index b6377ec..b1750a8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ LONG_DESCRIPTION = open("README.md").read() setup(name="aws-lambda-decorators", - version="0.51", + version="0.52", description="A set of python decorators to simplify aws python lambda development", long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown",