Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Routinely add type hints #2771

Merged
merged 2 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions samtranslator/feature_toggle/dialup.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import hashlib
from abc import ABCMeta, abstractmethod


class BaseDialup(object):
class BaseDialup(object, metaclass=ABCMeta):
"""BaseDialup class to provide an interface for all dialup classes"""

def __init__(self, region_config, **kwargs): # type: ignore[no-untyped-def]
self.region_config = region_config

def is_enabled(self): # type: ignore[no-untyped-def]
@abstractmethod
def is_enabled(self) -> bool:
"""
Returns a bool on whether this dialup is enabled or not
"""
raise NotImplementedError

def __str__(self): # type: ignore[no-untyped-def]
def __str__(self) -> str:
return self.__class__.__name__


Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(self, region_config, account_id, feature_name, **kwargs): # type:
self.account_id = account_id
self.feature_name = feature_name

def _get_account_percentile(self): # type: ignore[no-untyped-def]
def _get_account_percentile(self) -> int:
"""
Get account percentile based on sha256 hash of account ID and feature_name

Expand All @@ -65,10 +66,10 @@ def _get_account_percentile(self): # type: ignore[no-untyped-def]
m.update(self.feature_name.encode())
return int(m.hexdigest(), 16) % 100

def is_enabled(self): # type: ignore[no-untyped-def]
def is_enabled(self) -> bool:
"""
Enable when account_percentile falls within target_percentile
Meaning only (target_percentile)% of accounts will be enabled
"""
target_percentile = self.region_config.get("enabled-%", 0)
return self._get_account_percentile() < target_percentile # type: ignore[no-untyped-call]
target_percentile: int = self.region_config.get("enabled-%", 0)
return self._get_account_percentile() < target_percentile
22 changes: 13 additions & 9 deletions samtranslator/feature_toggle/feature_toggle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, cast

import boto3
import logging

Expand Down Expand Up @@ -87,15 +90,16 @@ def is_enabled(self, feature_name: str) -> bool:
return is_enabled


class FeatureToggleConfigProvider:
class FeatureToggleConfigProvider(ABC):
"""Interface for all FeatureToggle config providers"""

def __init__(self) -> None:
pass

@property
def config(self): # type: ignore[no-untyped-def]
raise NotImplementedError
@abstractmethod
def config(self) -> Dict[str, Any]:
pass


class FeatureToggleDefaultConfigProvider(FeatureToggleConfigProvider):
Expand All @@ -105,7 +109,7 @@ def __init__(self) -> None:
FeatureToggleConfigProvider.__init__(self)

@property
def config(self): # type: ignore[no-untyped-def]
def config(self) -> Dict[str, Any]:
return {}


Expand All @@ -116,10 +120,10 @@ def __init__(self, local_config_path): # type: ignore[no-untyped-def]
FeatureToggleConfigProvider.__init__(self)
with open(local_config_path, "r", encoding="utf-8") as f:
config_json = f.read()
self.feature_toggle_config = json.loads(config_json)
self.feature_toggle_config = cast(Dict[str, Any], json.loads(config_json))

@property
def config(self): # type: ignore[no-untyped-def]
def config(self) -> Dict[str, Any]:
return self.feature_toggle_config


Expand Down Expand Up @@ -147,13 +151,13 @@ def __init__(self, application_id, environment_id, configuration_profile_id, app
ClientId="FeatureToggleAppConfigConfigProvider",
)
binary_config_string = response["Content"].read()
self.feature_toggle_config = json.loads(binary_config_string.decode("utf-8"))
self.feature_toggle_config = cast(Dict[str, Any], json.loads(binary_config_string.decode("utf-8")))
LOG.info("Finished loading feature toggle config from AppConfig.")
except Exception as ex:
LOG.error("Failed to load config from AppConfig: {}. Using empty config.".format(ex))
# There is chance that AppConfig is not available in a particular region.
self.feature_toggle_config = json.loads("{}")
self.feature_toggle_config = {}

@property
def config(self): # type: ignore[no-untyped-def]
def config(self) -> Dict[str, Any]:
return self.feature_toggle_config
4 changes: 2 additions & 2 deletions samtranslator/intrinsics/resource_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def get_all(self, logical_id): # type: ignore[no-untyped-def]
"""
return self._refs.get(logical_id, None)

def __len__(self): # type: ignore[no-untyped-def]
def __len__(self) -> int:
"""
To make len(this_object) work
:return: Number of resource references available
"""
return len(self._refs)

def __str__(self): # type: ignore[no-untyped-def]
def __str__(self) -> str:
return str(self._refs)
4 changes: 2 additions & 2 deletions samtranslator/metrics/method_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class MetricsMethodWrapperSingleton:
_METRICS_INSTANCE = _DUMMY_INSTANCE

@staticmethod
def set_instance(metrics): # type: ignore[no-untyped-def]
def set_instance(metrics: Metrics) -> None:
MetricsMethodWrapperSingleton._METRICS_INSTANCE = metrics

@staticmethod
def get_instance(): # type: ignore[no-untyped-def]
def get_instance() -> Metrics:
"""
Return the instance, if nothing is set return a dummy one
"""
Expand Down
9 changes: 5 additions & 4 deletions samtranslator/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import logging
from datetime import datetime
from typing import Any, Dict

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,7 +102,7 @@ def __init__(self, name, value, unit, dimensions=None, timestamp=None): # type:
self.dimensions = dimensions if dimensions else []
self.timestamp = timestamp if timestamp else datetime.utcnow()

def get_metric_data(self): # type: ignore[no-untyped-def]
def get_metric_data(self) -> Dict[str, Any]:
return {
"MetricName": self.name,
"Value": self.value,
Expand All @@ -123,13 +124,13 @@ def __init__(self, namespace="ServerlessTransform", metrics_publisher=None): #
self.metrics_cache = {}
self.namespace = namespace

def __del__(self): # type: ignore[no-untyped-def]
def __del__(self) -> None:
if len(self.metrics_cache) > 0:
# attempting to publish if user forgot to call publish in code
LOG.warning(
"There are unpublished metrics. Please make sure you call publish after you record all metrics."
)
self.publish() # type: ignore[no-untyped-call]
self.publish()

def _record_metric(self, name, value, unit, dimensions=None, timestamp=None): # type: ignore[no-untyped-def]
"""
Expand Down Expand Up @@ -167,7 +168,7 @@ def record_latency(self, name, value, dimensions=None, timestamp=None): # type:
"""
self._record_metric(name, value, Unit.Milliseconds, dimensions, timestamp) # type: ignore[no-untyped-call]

def publish(self): # type: ignore[no-untyped-def]
def publish(self) -> None:
"""Calls publish method from the configured metrics publisher to publish metrics"""
# flatten the key->list dict into a flat list; we don't care about the key as it's
# the metric name which is also in the MetricDatum object
Expand Down
10 changes: 5 additions & 5 deletions samtranslator/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" CloudFormation Resource serialization, deserialization, and validation """
import re
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from samtranslator.intrinsics.resolver import IntrinsicsResolver
from samtranslator.model.exceptions import ExpectedType, InvalidResourceException, InvalidResourcePropertyTypeException
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_supported_resource_attributes(cls): # type: ignore[no-untyped-def]
return tuple(cls._supported_resource_attributes)

@classmethod
def get_pass_through_attributes(cls): # type: ignore[no-untyped-def]
def get_pass_through_attributes(cls) -> Tuple[str, ...]:
"""
A getter method for the resource attributes to be passed to auto-generated resources
returns: a tuple that contains the name of all pass through attributes
Expand Down Expand Up @@ -254,11 +254,11 @@ def to_dict(self) -> Dict[str, Dict[str, Any]]:
"""
self.validate_properties()

resource_dict = self._generate_resource_dict() # type: ignore[no-untyped-call]
resource_dict = self._generate_resource_dict()

return {self.logical_id: resource_dict}

def _generate_resource_dict(self): # type: ignore[no-untyped-def]
def _generate_resource_dict(self) -> Dict[str, Any]:
"""Generates the resource dict for this Resource, the value associated with the logical id in a CloudFormation
template's Resources section.

Expand Down Expand Up @@ -383,7 +383,7 @@ def get_passthrough_resource_attributes(self) -> Dict[str, Any]:
:return: Dictionary of resource attributes.
"""
attributes = {}
for resource_attribute in self.get_pass_through_attributes(): # type: ignore[no-untyped-call]
for resource_attribute in self.get_pass_through_attributes():
if resource_attribute in self.resource_attributes:
attributes[resource_attribute] = self.resource_attributes.get(resource_attribute)
return attributes
Expand Down
4 changes: 2 additions & 2 deletions samtranslator/model/api/api_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ def _get_authorizers(self, authorizers_config, default_authorizer=None): # type
# The dict below will eventually become part of swagger/openapi definition, thus requires using Py27Dict()
authorizers = Py27Dict()
if default_authorizer == "AWS_IAM":
authorizers[default_authorizer] = ApiGatewayAuthorizer( # type: ignore[no-untyped-call]
authorizers[default_authorizer] = ApiGatewayAuthorizer(
api_logical_id=self.logical_id, name=default_authorizer, is_aws_iam_authorizer=True
)

Expand All @@ -1131,7 +1131,7 @@ def _get_authorizers(self, authorizers_config, default_authorizer=None): # type
for authorizer_name, authorizer in authorizers_config.items():
sam_expect(authorizer, self.logical_id, f"Auth.Authorizers.{authorizer_name}").to_be_a_map()

authorizers[authorizer_name] = ApiGatewayAuthorizer( # type: ignore[no-untyped-call]
authorizers[authorizer_name] = ApiGatewayAuthorizer(
api_logical_id=self.logical_id,
name=authorizer_name,
user_pool_arn=authorizer.get("UserPoolArn"),
Expand Down
Loading