Skip to content

Commit

Permalink
chore: adding types part 1 (#2746)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaythapa authored Dec 16, 2022
1 parent 0bca270 commit a7d8ae5
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 27 deletions.
6 changes: 3 additions & 3 deletions samtranslator/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,20 +533,20 @@ def __init__(self, *modules): # type: ignore[no-untyped-def]
):
self.resource_types[resource_class.resource_type] = resource_class

def can_resolve(self, resource_dict): # type: ignore[no-untyped-def]
def can_resolve(self, resource_dict: Dict[str, Any]) -> bool:
if not isinstance(resource_dict, dict) or not isinstance(resource_dict.get("Type"), str):
return False

return resource_dict["Type"] in self.resource_types

def resolve_resource_type(self, resource_dict): # type: ignore[no-untyped-def]
def resolve_resource_type(self, resource_dict: Dict[str, Any]) -> Any:
"""Returns the Resource class corresponding to the 'Type' key in the given resource dict.
:param dict resource_dict: the resource dict to resolve
:returns: the resolved Resource class
:rtype: class
"""
if not self.can_resolve(resource_dict): # type: ignore[no-untyped-call]
if not self.can_resolve(resource_dict):
raise TypeError(
"Resource dict has missing or invalid value for key Type. Event Type is: {}.".format(
resource_dict.get("Type")
Expand Down
4 changes: 2 additions & 2 deletions samtranslator/model/eventsources/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def resources_to_link(self, resources): # type: ignore[no-untyped-def]
permitted_stage = "*"
stage_suffix = "AllStages"
explicit_api = None
rest_api_id = self.get_rest_api_id_string(self.RestApiId) # type: ignore[attr-defined, no-untyped-call]
rest_api_id = self.get_rest_api_id_string(self.RestApiId) # type: ignore[attr-defined]
if isinstance(rest_api_id, str):

if (
Expand Down Expand Up @@ -952,7 +952,7 @@ def _add_swagger_integration(self, api, api_id, function, intrinsics_resolver):
api["DefinitionBody"] = editor.swagger

@staticmethod
def get_rest_api_id_string(rest_api_id): # type: ignore[no-untyped-def]
def get_rest_api_id_string(rest_api_id: Any) -> Any:
"""
rest_api_id can be either a string or a dictionary where the actual api id is the value at key "Ref".
If rest_api_id is a dictionary with key "Ref", returns value at key "Ref". Otherwise, return rest_api_id.
Expand Down
6 changes: 3 additions & 3 deletions samtranslator/model/sam_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def _event_resources_to_link(self, resources): # type: ignore[no-untyped-def]
if self.Events:
for logical_id, event_dict in self.Events.items():
try:
event_source = self.event_resolver.resolve_resource_type(event_dict).from_dict( # type: ignore[no-untyped-call]
event_source = self.event_resolver.resolve_resource_type(event_dict).from_dict(
self.logical_id + logical_id, event_dict, logical_id
)
except (TypeError, AttributeError) as e:
Expand Down Expand Up @@ -742,7 +742,7 @@ def _generate_event_resources( # type: ignore[no-untyped-def]
if self.Events:
for logical_id, event_dict in sorted(self.Events.items(), key=SamFunction.order_events):
try:
eventsource = self.event_resolver.resolve_resource_type(event_dict).from_dict( # type: ignore[no-untyped-call]
eventsource = self.event_resolver.resolve_resource_type(event_dict).from_dict(
lambda_function.logical_id + logical_id, event_dict, logical_id
)
except TypeError as e:
Expand Down Expand Up @@ -1763,7 +1763,7 @@ def _event_resources_to_link(self, resources): # type: ignore[no-untyped-def]
if self.Events:
for logical_id, event_dict in self.Events.items():
try:
event_source = self.event_resolver.resolve_resource_type(event_dict).from_dict( # type: ignore[no-untyped-call]
event_source = self.event_resolver.resolve_resource_type(event_dict).from_dict(
self.logical_id + logical_id, event_dict, logical_id
)
except (TypeError, AttributeError) as e:
Expand Down
2 changes: 1 addition & 1 deletion samtranslator/model/stepfunctions/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def resources_to_link(self, resources): # type: ignore[no-untyped-def]
permitted_stage = "*"
stage_suffix = "AllStages"
explicit_api = None
rest_api_id = PushApi.get_rest_api_id_string(self.RestApiId) # type: ignore[attr-defined, no-untyped-call]
rest_api_id = PushApi.get_rest_api_id_string(self.RestApiId) # type: ignore[attr-defined]
if isinstance(rest_api_id, str):

if (
Expand Down
2 changes: 1 addition & 1 deletion samtranslator/plugins/api/implicit_api_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _get_api_id(self, event_properties): # type: ignore[no-untyped-def]
Handles case where API id is not specified or is a reference to a logical id.
"""
api_id = event_properties.get(self.api_id_property)
return Api.get_rest_api_id_string(api_id) # type: ignore[no-untyped-call]
return Api.get_rest_api_id_string(api_id)

def _maybe_add_condition_to_implicit_api(self, template_dict): # type: ignore[no-untyped-def]
"""
Expand Down
5 changes: 3 additions & 2 deletions samtranslator/sdk/parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import boto3
from typing import Dict, Any
import copy

from samtranslator.translator.arn_generator import ArnGenerator, NoRegionFound
Expand All @@ -9,7 +10,7 @@ class SamParameterValues(object):
Class representing SAM parameter values.
"""

def __init__(self, parameter_values): # type: ignore[no-untyped-def]
def __init__(self, parameter_values: Dict[Any, Any]):
"""
Initialize the object given the parameter values as a dictionary
Expand All @@ -18,7 +19,7 @@ def __init__(self, parameter_values): # type: ignore[no-untyped-def]

self.parameter_values = copy.deepcopy(parameter_values)

def add_default_parameter_values(self, sam_template): # type: ignore[no-untyped-def]
def add_default_parameter_values(self, sam_template: Dict[str, Any]) -> Any:
"""
Method to read default values for template parameters and merge with user supplied values.
Expand Down
2 changes: 1 addition & 1 deletion samtranslator/translator/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def transform(input_fragment, parameter_values, managed_policy_loader, feature_t
sam_parser = Parser()
to_py27_compatible_template(input_fragment, parameter_values) # type: ignore[no-untyped-call]
translator = Translator(managed_policy_loader.load(), sam_parser) # type: ignore[no-untyped-call]
transformed = translator.translate( # type: ignore[no-untyped-call]
transformed = translator.translate(
input_fragment,
parameter_values=parameter_values,
feature_toggle=feature_toggle,
Expand Down
38 changes: 24 additions & 14 deletions samtranslator/translator/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from samtranslator.metrics.method_decorator import MetricsMethodWrapperSingleton
from samtranslator.metrics.metrics import DummyMetricsPublisher, Metrics

from typing import Dict, Any, Optional, List
from samtranslator.feature_toggle.feature_toggle import (
FeatureToggle,
FeatureToggleDefaultConfigProvider,
Expand Down Expand Up @@ -55,7 +55,9 @@ def __init__(self, managed_policy_map, sam_parser, plugins=None, boto_session=No
if self.boto_session:
ArnGenerator.BOTO_SESSION_REGION_NAME = self.boto_session.region_name

def _get_function_names(self, resource_dict, intrinsics_resolver): # type: ignore[no-untyped-def]
def _get_function_names(
self, resource_dict: Dict[str, Any], intrinsics_resolver: IntrinsicsResolver
) -> Dict[str, str]:
"""
:param resource_dict: AWS::Serverless::Function resource is provided as input
:param intrinsics_resolver: to resolve intrinsics for function_name
Expand All @@ -71,19 +73,25 @@ def _get_function_names(self, resource_dict, intrinsics_resolver): # type: igno
item_properties = item.get("Properties", {})
if item.get("Type") == "Api" and item_properties.get("RestApiId"):
rest_api = item_properties.get("RestApiId")
api_name = Api.get_rest_api_id_string(rest_api) # type: ignore[no-untyped-call]
api_name = Api.get_rest_api_id_string(rest_api)
if isinstance(api_name, str):
resource_dict_copy = copy.deepcopy(resource_dict)
function_name = intrinsics_resolver.resolve_parameter_refs(
resource_dict_copy.get("Properties").get("FunctionName")
resource_dict_copy.get("Properties", {}).get("FunctionName")
)
if function_name:
self.function_names[api_name] = str(self.function_names.get(api_name, "")) + str(
function_name
)
return self.function_names

def translate(self, sam_template, parameter_values, feature_toggle=None, passthrough_metadata=False): # type: ignore[no-untyped-def]
def translate(
self,
sam_template: Dict[str, Any],
parameter_values: Dict[Any, Any],
feature_toggle: Optional[FeatureToggle] = None,
passthrough_metadata: Optional[bool] = False,
) -> Dict[str, Any]:
"""Loads the SAM resources from the given SAM manifest, replaces them with their corresponding
CloudFormation resources, and returns the resulting CloudFormation template.
Expand All @@ -103,10 +111,10 @@ def translate(self, sam_template, parameter_values, feature_toggle=None, passthr
if feature_toggle
else FeatureToggle(FeatureToggleDefaultConfigProvider(), stage=None, account_id=None, region=None) # type: ignore[no-untyped-call, no-untyped-call]
)
self.function_names = {}
self.function_names: Dict[Any, Any] = {}
self.redeploy_restapi_parameters = {}
sam_parameter_values = SamParameterValues(parameter_values) # type: ignore[no-untyped-call]
sam_parameter_values.add_default_parameter_values(sam_template) # type: ignore[no-untyped-call]
sam_parameter_values = SamParameterValues(parameter_values)
sam_parameter_values.add_default_parameter_values(sam_template)
sam_parameter_values.add_pseudo_parameter_values(self.boto_session) # type: ignore[no-untyped-call]
parameter_values = sam_parameter_values.parameter_values
# Create & Install plugins
Expand All @@ -130,10 +138,10 @@ def translate(self, sam_template, parameter_values, feature_toggle=None, passthr
shared_api_usage_plan = SharedApiUsagePlan()
document_errors = []
changed_logical_ids = {}
route53_record_set_groups = {} # type: ignore[var-annotated]
for logical_id, resource_dict in self._get_resources_to_iterate(sam_template, macro_resolver): # type: ignore[no-untyped-call]
route53_record_set_groups: Dict[Any, Any] = {}
for logical_id, resource_dict in self._get_resources_to_iterate(sam_template, macro_resolver):
try:
macro = macro_resolver.resolve_resource_type(resource_dict).from_dict( # type: ignore[no-untyped-call]
macro = macro_resolver.resolve_resource_type(resource_dict).from_dict(
logical_id, resource_dict, sam_plugins=sam_plugins
)

Expand All @@ -146,7 +154,7 @@ def translate(self, sam_template, parameter_values, feature_toggle=None, passthr
kwargs["resource_resolver"] = resource_resolver
kwargs["original_template"] = sam_template
# add the value of FunctionName property if the function is referenced with the api resource
self.redeploy_restapi_parameters["function_names"] = self._get_function_names( # type: ignore[no-untyped-call]
self.redeploy_restapi_parameters["function_names"] = self._get_function_names(
resource_dict, intrinsics_resolver
)
kwargs["redeploy_restapi_parameters"] = self.redeploy_restapi_parameters
Expand Down Expand Up @@ -181,7 +189,7 @@ def translate(self, sam_template, parameter_values, feature_toggle=None, passthr
if deployment_preference_collection.needs_resource_condition(): # type: ignore[no-untyped-call]
new_conditions = deployment_preference_collection.create_aggregate_deployment_condition() # type: ignore[no-untyped-call]
if new_conditions:
template.get("Conditions").update(new_conditions)
template.get("Conditions", {}).update(new_conditions)

if not deployment_preference_collection.can_skip_service_role(): # type: ignore[no-untyped-call]
template["Resources"].update(deployment_preference_collection.get_codedeploy_iam_role().to_dict()) # type: ignore[no-untyped-call]
Expand Down Expand Up @@ -211,7 +219,9 @@ def translate(self, sam_template, parameter_values, feature_toggle=None, passthr
raise InvalidDocumentException(document_errors)

# private methods
def _get_resources_to_iterate(self, sam_template, macro_resolver): # type: ignore[no-untyped-def]
def _get_resources_to_iterate(
self, sam_template: Dict[str, Any], macro_resolver: ResourceTypeResolver
) -> List[Any]:
"""
Returns a list of resources to iterate, order them based on the following order:
Expand Down

0 comments on commit a7d8ae5

Please sign in to comment.