Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hinting #113

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .prospector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pep8:
full: true
options:
max-line-length: 120
disable:
- E731
eulogio-gutierrez marked this conversation as resolved.
Show resolved Hide resolved

pep257:
disable:
Expand Down
41 changes: 23 additions & 18 deletions aws_lambda_decorators/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""All the classes used as parameters for the decorators."""
from typing import Any, Callable, List, Tuple

from aws_lambda_decorators.decoders import decode
from aws_lambda_decorators.utils import is_valid_variable_name
from aws_lambda_decorators.validators import Validator


PATH_DIVIDER = "/"
ANNOTATIONS_START = "["
Expand All @@ -10,37 +13,38 @@
class ExceptionHandler:
"""Class mapping a friendly error message to a given Exception."""

def __init__(self, exception, friendly_message=None, status_code=400):
def __init__(self, exception: Exception, friendly_message: str = None, status_code: int = 400):
"""
Sets the private variables of the ExceptionHandler object.

Args:
exception (object|Exception): An exception to be handled.
friendly_message (str): Friendly Message to be returned if the exception is caught.
status_code (str): HTTP status code to be returned if the exception is caught.
"""
self._exception = exception
self._friendly_message = friendly_message
self._status_code = status_code

@property
def friendly_message(self):
def friendly_message(self) -> str:
"""Getter for the friendly message parameter."""
return self._friendly_message

@property
def exception(self):
def exception(self) -> Exception:
"""Getter for the exception parameter."""
return self._exception

@property
def status_code(self):
def status_code(self) -> int:
"""Getter for the status code parameter."""
return self._status_code


class BaseParameter: # noqa: pylint - too-few-public-methods
"""Parent class of all parameter classes."""
def __init__(self, var_name):
def __init__(self, var_name: str):
"""
Set the private variables of the BaseParameter object.

Expand All @@ -49,7 +53,7 @@ def __init__(self, var_name):
"""
self._name = var_name

def get_var_name(self):
def get_var_name(self) -> str:
"""Gets the name of the variable that represents the parameter."""
if self._name and not is_valid_variable_name(self._name):
raise SyntaxError(self._name)
Expand All @@ -59,7 +63,7 @@ def get_var_name(self):
class SSMParameter(BaseParameter):
"""Class used for defining the key and, optionally, the variable name for ssm parameter extraction."""

def __init__(self, ssm_name, var_name=None):
def __init__(self, ssm_name: str, var_name: str = None):
"""
Set the private variables of the SSMParameter object.

Expand All @@ -70,15 +74,15 @@ def __init__(self, ssm_name, var_name=None):
self._ssm_name = ssm_name
BaseParameter.__init__(self, var_name if var_name else ssm_name)

def get_ssm_name(self):
def get_ssm_name(self) -> str:
"""Getter for the ssm_name parameter."""
return self._ssm_name


class ValidatedParameter:
"""Class used to encapsulate the validation methods parameter data."""

def __init__(self, func_param_name=None, validators=None):
def __init__(self, func_param_name: str = None, validators: List[Validator] = None):
"""
Sets the private variables of the ValidatedParameter object.
Args:
Expand All @@ -91,16 +95,16 @@ def fun(event, context). To extract from context func_param_name has to be "cont
self._validators = validators

@property
def func_param_name(self):
def func_param_name(self) -> str:
"""Getter for the func_param_name parameter."""
return self._func_param_name

@func_param_name.setter
def func_param_name(self, value):
def func_param_name(self, value: str):
"""Setter for the func_param_name parameter."""
self._func_param_name = value

def validate(self, value, group_errors):
def validate(self, value: Any, group_errors: bool) -> List[str]:
"""
Validates a value against the passed in validators

Expand Down Expand Up @@ -130,7 +134,8 @@ def validate(self, value, group_errors):
class Parameter(ValidatedParameter, BaseParameter):
"""Class used to encapsulate the extract methods parameter data."""

def __init__(self, path="", func_param_name=None, validators=None, var_name=None, default=None, transform=None): # noqa: pylint - too-many-arguments
def __init__(self, path="", func_param_name: str = None, validators: List[Validator] = None, # noqa: pylint - too-many-arguments
var_name: str = None, default: Any = None, transform: Callable = None):
"""
Sets the private variables of the Parameter object.

Expand Down Expand Up @@ -163,11 +168,11 @@ def fun(event, context). To extract from context func_param_name has to be "cont
BaseParameter.__init__(self, var_name)

@property
def path(self):
def path(self) -> str:
"""Getter for the path parameter."""
return self._path

def extract_value(self, dict_value):
def extract_value(self, dict_value: dict) -> Any:
"""
Calculate and decode the value of the variable in the given path.

Expand All @@ -194,7 +199,7 @@ def extract_value(self, dict_value):

return dict_value

def validate_path(self, value, group_errors=False):
def validate_path(self, value: Any, group_errors: bool = False) -> List[str]:
"""
Validates a value against the passed in validators

Expand All @@ -213,7 +218,7 @@ def validate_path(self, value, group_errors=False):
return {key: errors} if errors else {}

@staticmethod
def get_annotations_from_key(key):
def get_annotations_from_key(key: str) -> Tuple[str, str]:
"""
Extract the key and the encoding type (annotation) from the string.

Expand Down
8 changes: 5 additions & 3 deletions aws_lambda_decorators/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import json
import logging
import sys

import jwt


LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

DECODE_FUNC_NAME = "decode_%s"
DECODE_FUNC_MISSING_ERROR = "Missing decode function for annotation: %s"


def decode(annotation, value):
def decode(annotation: str, value: str) -> dict:
"""
Converts an annotated string to a python dictionary.

Expand Down Expand Up @@ -47,12 +49,12 @@ def decode(annotation, value):


@functools.lru_cache()
def decode_json(value):
def decode_json(value: str) -> dict:
"""Convert a json to a dictionary."""
return json.loads(value)


@functools.lru_cache()
def decode_jwt(value):
def decode_jwt(value: str) -> dict:
"""Convert a jwt to a dictionary."""
return jwt.decode(value, verify=False)
29 changes: 18 additions & 11 deletions aws_lambda_decorators/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
A set of Python decorators to ease the development of AWS lambda functions.

"""
import json
from http import HTTPStatus
import json
from typing import Callable, List

import boto3

from aws_lambda_decorators.classes import BaseParameter, ExceptionHandler, SSMParameter, ValidatedParameter
from aws_lambda_decorators.utils import full_name, all_func_args, find_key_case_insensitive, failure, get_logger


Expand All @@ -27,7 +31,8 @@
UNKNOWN = "Unknown"


def extract_from_event(parameters, group_errors=False, allow_none_defaults=False):
def extract_from_event(parameters: List[BaseParameter], group_errors: bool = False,
allow_none_defaults: bool = False) -> Callable:
"""
Extracts a set of parameters from the event dictionary in a lambda handler.

Expand All @@ -48,7 +53,8 @@ def lambda_handler(event, context, my_param=None)
return extract(parameters, group_errors, allow_none_defaults)


def extract_from_context(parameters, group_errors=False, allow_none_defaults=False):
def extract_from_context(parameters: List[BaseParameter], group_errors: bool = False,
allow_none_defaults: bool = False):
"""
Extracts a set of parameters from the context dictionary in a lambda handler.

Expand All @@ -69,7 +75,7 @@ def lambda_handler(event, context, my_param=None)
return extract(parameters, group_errors, allow_none_defaults)


def extract(parameters, group_errors=False, allow_none_defaults=False):
def extract(parameters: List[BaseParameter], group_errors: bool = False, allow_none_defaults: bool = False) -> Callable:
"""
Extracts a set of parameters from any function parameter passed to an AWS lambda handler.

Expand Down Expand Up @@ -120,7 +126,7 @@ def wrapper(*args, **kwargs):
return decorator


def handle_exceptions(handlers):
def handle_exceptions(handlers: List[ExceptionHandler]) -> Callable:
"""
Handles exceptions thrown by the wrapped/decorated function.

Expand Down Expand Up @@ -150,7 +156,7 @@ def wrapper(*args, **kwargs):
return decorator


def log(parameters=False, response=False):
def log(parameters: bool = False, response: bool = False) -> Callable:
"""
Log parameters and/or response of the wrapped/decorated function using logging package

Expand All @@ -170,7 +176,7 @@ def wrapper(*args, **kwargs):
return decorator


def extract_from_ssm(ssm_parameters):
def extract_from_ssm(ssm_parameters: List[SSMParameter]) -> Callable:
"""
Load given ssm parameters from AWS parameter store to the handler variables.

Expand Down Expand Up @@ -198,7 +204,7 @@ def wrapper(*args, **kwargs):
return decorator


def response_body_as_json(func):
def response_body_as_json(func: Callable) -> Callable:
"""
Convert the dictionary response of the wrapped/decorated function to a json string literal.

Expand All @@ -220,7 +226,7 @@ def wrapper(*args, **kwargs):
return wrapper


def validate(parameters, group_errors=False):
def validate(parameters: ValidatedParameter, group_errors: bool = False) -> Callable:
"""
Validates a set of function parameters.

Expand Down Expand Up @@ -260,7 +266,7 @@ def wrapper(*args, **kwargs):
return decorator


def handle_all_exceptions():
def handle_all_exceptions() -> Callable:
"""
Handles all exceptions thrown by the wrapped/decorated function.

Expand All @@ -280,7 +286,8 @@ def wrapper(*args, **kwargs):
return decorator


def cors(allow_origin=None, allow_methods=None, allow_headers=None, max_age=None):
def cors(allow_origin: str = None, allow_methods: str = None, allow_headers: str = None,
max_age: int = None) -> Callable:
"""
Adds CORS headers to the response of the decorated function

Expand Down
32 changes: 20 additions & 12 deletions aws_lambda_decorators/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""Utility functions."""
import logging
import os
import json
from http import HTTPStatus
import keyword
import inspect
import json
import keyword
import logging
from logging import Logger
import os
from typing import Any, Callable, Dict, List
from unicodedata import normalize as normalise


LOG_LEVEL = getattr(logging, os.getenv("LOG_LEVEL", "INFO"))


def get_logger(name):
def get_logger(name: str) -> Logger:
logger = logging.getLogger(name)
logger.setLevel(LOG_LEVEL)
return logger


def full_name(class_type):
def full_name(class_type: type) -> str:
"""
Gets the fully qualified name of a class type.

Expand All @@ -34,7 +37,7 @@ def full_name(class_type):
return f"{module}.{class_type.__class__.__name__}"


def is_type_in_list(item_type, items):
def is_type_in_list(item_type: type, items: list) -> bool:
"""
Checks if there is an item of a given type in the list of items.

Expand All @@ -48,7 +51,7 @@ def is_type_in_list(item_type, items):
return any(isinstance(item, item_type) for item in items)


def is_valid_variable_name(name):
def is_valid_variable_name(name: str) -> bool:
"""
Check if the given name is python allowed variable name.

Expand All @@ -61,7 +64,7 @@ def is_valid_variable_name(name):
return name.isidentifier() and not keyword.iskeyword(name)


def all_func_args(func, args, kwargs):
def all_func_args(func: Callable, args: list, kwargs: dict) -> dict:
"""
Combine arguments and key word arguments to a dictionary.

Expand All @@ -79,7 +82,7 @@ def all_func_args(func, args, kwargs):
return arg_dictionary


def find_key_case_insensitive(key_name, the_dict):
def find_key_case_insensitive(key_name: str, the_dict: Dict[str, Any]) -> str:
"""
Finds if a dictionary (the_dict) has a string key (key_name) in any string case

Expand All @@ -91,13 +94,18 @@ def find_key_case_insensitive(key_name, the_dict):
The found key name in its original case, if found. Otherwise, returns the searching key name

"""
desentisise = lambda key: normalise('NFC', key).casefold()

key_name_lower = desentisise(key_name)

for key in the_dict:
if key.lower() == key_name:
if desentisise(key) == key_name_lower:
return key

return key_name


def failure(errors, status_code=HTTPStatus.BAD_REQUEST):
def failure(errors: List[str], status_code: int = HTTPStatus.BAD_REQUEST) -> dict:
"""
Returns an error to the caller

Expand Down
Loading