Skip to content

Commit

Permalink
fix(typing): improve overloads to ensure the return type follows the …
Browse files Browse the repository at this point in the history
…default_value type (aws-powertools#4114)
  • Loading branch information
Wurstnase authored Apr 12, 2024
1 parent 32e733b commit c0622a5
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]:
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")

@overload
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]:
a pipeline resolver."""
return self.get("stash")

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ def http_method(self) -> str:
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
return self["httpMethod"]

@overload
def get_query_string_value(self, name: str, default_value: str) -> str: ...

@overload
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...

def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
"""Get query string value by name
Expand Down
20 changes: 18 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/kafka_event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
from functools import cached_property
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]:
"""Decodes the headers as a single dictionary."""
return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()}

@overload
def get_header_value(
self,
name: str,
default_value: Optional[Any] = None,
default_value: str,
case_sensitive: bool = True,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = True,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = True,
) -> Optional[str]:
"""Get a decoded header value by name."""
Expand Down
18 changes: 17 additions & 1 deletion aws_lambda_powertools/utilities/data_classes/s3_object_event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]:
The case of the original headers is retained in this map."""
return self["headers"]

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
42 changes: 38 additions & 4 deletions aws_lambda_powertools/utilities/data_classes/shared_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import base64
from typing import Any, Dict
from typing import Any, Dict, overload


def base64_decode(value: str) -> str:
Expand All @@ -21,11 +21,29 @@ def base64_decode(value: str) -> str:
return base64.b64decode(value).decode("UTF-8")


@overload
def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None,
case_sensitive: bool | None,
default_value: str,
case_sensitive: bool | None = False,
) -> str: ...


@overload
def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None = None,
case_sensitive: bool | None = False,
) -> str | None: ...


def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None = None,
case_sensitive: bool | None = False,
) -> str | None:
"""
Get the value of a header by its name.
Expand All @@ -39,7 +57,7 @@ def get_header_value(
default_value: str, optional
The default value to return if the header is not found. Default is None.
case_sensitive: bool, optional
Indicates whether the header name should be case-sensitive. Default is None.
Indicates whether the header name should be case-sensitive. Default is False.
Returns
-------
Expand All @@ -62,6 +80,22 @@ def get_header_value(
)


@overload
def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
default_value: str,
) -> str: ...


@overload
def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
default_value: str | None = None,
) -> str | None: ...


def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
Expand Down
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def http_method(self) -> str:
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
return self["method"]

@overload
def get_query_string_value(self, name: str, default_value: str) -> str: ...

@overload
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...

def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
"""Get query string value by name
Expand Down
4 changes: 2 additions & 2 deletions examples/event_handler_graphql/src/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class Location(TypedDict, total=False):
class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False)

@property
def api_key(self) -> str:
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False)


@app.resolver(type_name="Query", field_name="listLocations")
Expand Down

0 comments on commit c0622a5

Please sign in to comment.