Skip to content

Commit

Permalink
More Type Annotations (#8536)
Browse files Browse the repository at this point in the history
* Extend use of type annotations in the events module.

* Add return type of None to more __init__ definitions.

* Still more type annotations adding -> None to __init__

* Tweak per review
  • Loading branch information
peterallenwebb authored and QMalcolm committed Oct 6, 2023
1 parent 64518d0 commit 0d82a31
Show file tree
Hide file tree
Showing 50 changed files with 426 additions and 411 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230831-164435.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Added more type annotations.
time: 2023-08-31T16:44:35.737954-04:00
custom:
Author: peterallenwebb
Issue: "8537"
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):

TYPE: str = NotImplemented

def __init__(self, profile: AdapterRequiredConfig):
def __init__(self, profile: AdapterRequiredConfig) -> None:
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = flags.MP_CONTEXT.RLock()
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class BaseAdapter(metaclass=AdapterMeta):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

def __init__(self, config):
def __init__(self, config) -> None:
self.config = config
self.cache = RelationsCache()
self.connections = self.ConnectionManager(config)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
credentials: Type[Credentials],
include_path: str,
dependencies: Optional[List[str]] = None,
):
) -> None:

self.adapter: Type[AdapterProtocol] = adapter
self.credentials: Type[Credentials] = credentials
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class NodeWrapper:
def __init__(self, node):
def __init__(self, node) -> None:
self._inner_node = node

def __getattr__(self, name):
Expand Down Expand Up @@ -57,7 +57,7 @@ def set(self, comment: Optional[str], append: bool):


class MacroQueryStringSetter:
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest) -> None:
self.manifest = manifest
self.config = config

Expand Down
4 changes: 2 additions & 2 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class _CachedRelation:
:attr BaseRelation inner: The underlying dbt relation.
"""

def __init__(self, inner):
self.referenced_by = {}
def __init__(self, inner) -> None:
self.referenced_by: Dict[_ReferenceKey, _CachedRelation] = {}
self.inner = inner

def __str__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AdapterProtocol( # type: ignore[misc]
ConnectionManager: Type[ConnectionManager_T]
connections: ConnectionManager_T

def __init__(self, config: AdapterRequiredConfig):
def __init__(self, config: AdapterRequiredConfig) -> None:
...

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self,
manifest: Optional[Manifest] = None,
callbacks: Optional[List[Callable[[EventMsg], None]]] = None,
):
) -> None:
self.manifest = manifest

if callbacks is None:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Implementation from: https://stackoverflow.com/a/48394004
# Note MultiOption options must be specified with type=tuple or type=ChoiceTuple (https://github.com/pallets/click/issues/2012)
class MultiOption(click.Option):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.save_other_options = kwargs.pop("save_other_options", True)
nargs = kwargs.pop("nargs", -1)
assert nargs == -1, "nargs, if set, must be -1 not {}".format(nargs)
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class _NullMarker:


class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def __init__(self):
def __init__(self) -> None:
super().__init__()

def __setitem__(self, key, value):
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _get_tests_for_node(manifest: Manifest, unique_id: UniqueID) -> List[UniqueI


class Linker:
def __init__(self, data=None):
def __init__(self, data=None) -> None:
if data is None:
data = {}
self.graph = nx.DiGraph(**data)
Expand Down Expand Up @@ -274,7 +274,7 @@ def get_graph_summary(self, manifest: Manifest) -> Dict[int, Dict[str, Any]]:


class Compiler:
def __init__(self, config):
def __init__(self, config) -> None:
self.config = config

def initialize(self):
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
user_config: UserConfig,
threads: int,
credentials: Credentials,
):
) -> None:
"""Explicitly defining `__init__` to work around bug in Python 3.9.7
https://bugs.python.org/issue45081
"""
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _list_if_none_or_string(value):


class ProjectPostprocessor(Dict[Keypath, Callable[[Any], Any]]):
def __init__(self):
def __init__(self) -> None:
super().__init__()

self[("on-run-start",)] = _list_if_none_or_string
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _get_project_directories(self) -> Iterator[Path]:


class UnsetCredentials(Credentials):
def __init__(self):
def __init__(self) -> None:
super().__init__("", "")

@property
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LazyHandle:
connection, updating the handle on the Connection.
"""

def __init__(self, opener: Callable[[Connection], Connection]):
def __init__(self, opener: Callable[[Connection], Connection]) -> None:
self.opener = opener

def resolve(self, connection: Connection) -> Connection:
Expand Down
16 changes: 8 additions & 8 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def find_unique_id_for_package(storage, key, package: Optional[PackageName]):


class DocLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -119,7 +119,7 @@ def perform_lookup(self, unique_id: UniqueID, manifest) -> Documentation:


class SourceLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -156,7 +156,7 @@ class RefableLookup(dbtClassMixin):
_lookup_types: ClassVar[set] = set(NodeType.refable())
_versioned_types: ClassVar[set] = set(NodeType.versioned())

def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -267,7 +267,7 @@ def _find_unique_ids_for_package(self, key, package: Optional[PackageName]) -> L


class MetricLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -306,7 +306,7 @@ class SemanticModelByMeasureLookup(dbtClassMixin):
the semantic models in a manifest.
"""

def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: DefaultDict[str, Dict[PackageName, UniqueID]] = defaultdict(dict)
self.populate(manifest)

Expand Down Expand Up @@ -355,7 +355,7 @@ def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SemanticM

# This handles both models/seeds/snapshots and sources/metrics/exposures/semantic_models
class DisabledLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest"):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, List[Any]]] = {}
self.populate(manifest)

Expand Down Expand Up @@ -1427,12 +1427,12 @@ def __reduce_ex__(self, protocol):


class MacroManifest(MacroMethods):
def __init__(self, macros):
def __init__(self, macros) -> None:
self.macros = macros
self.metadata = ManifestMetadata()
# This is returned by the 'graph' context property
# in the ProviderContext class.
self.flat_graph = {}
self.flat_graph: Dict[str, Any] = {}


AnyManifest = Union[Manifest, MacroManifest]
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class MetricReference(object):
def __init__(self, metric_name, package_name=None):
def __init__(self, metric_name, package_name=None) -> None:
self.metric_name = metric_name
self.package_name = package_name

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/semantic_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class SemanticManifest:
def __init__(self, manifest):
def __init__(self, manifest) -> None:
self.manifest = manifest

def validate(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def to_msg_dict(self):

# This is a context manager
class collect_timing_info:
def __init__(self, name: str, callback: Callable[[TimingInfo], None]):
def __init__(self, name: str, callback: Callable[[TimingInfo], None]) -> None:
self.timing_info = TimingInfo(name=name)
self.callback = callback

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path):
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
self.target_path: Path = target_path
self.project_root: Path = project_root
Expand Down
12 changes: 6 additions & 6 deletions core/dbt/events/adapter_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,32 @@
class AdapterLogger:
name: str

def debug(self, msg, *args):
def debug(self, msg, *args) -> None:
event = AdapterEventDebug(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def info(self, msg, *args):
def info(self, msg, *args) -> None:
event = AdapterEventInfo(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def warning(self, msg, *args):
def warning(self, msg, *args) -> None:
event = AdapterEventWarning(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

def error(self, msg, *args):
def error(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
fire_event(event)

# The default exc_info=True is what makes this method different
def exception(self, msg, *args):
def exception(self, msg, *args) -> None:
exc_info = str(traceback.format_exc())
event = AdapterEventError(
name=self.name,
Expand All @@ -51,7 +51,7 @@ def exception(self, msg, *args):
)
fire_event(event)

def critical(self, msg, *args):
def critical(self, msg, *args) -> None:
event = AdapterEventError(
name=self.name, base_msg=str(msg), args=list(args), node_info=get_node_info()
)
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/events/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_pid() -> int:
return os.getpid()


# in theory threads can change so we don't cache them.
# in theory threads can change, so we don't cache them.
def get_thread_name() -> str:
return threading.current_thread().name

Expand All @@ -55,7 +55,7 @@ class EventLevel(str, Enum):
class BaseEvent:
"""BaseEvent for proto message generated python events"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
class_name = type(self).__name__
msg_cls = getattr(types_pb2, class_name)
if class_name == "Formatting" and len(args) > 0:
Expand Down Expand Up @@ -100,9 +100,9 @@ def to_dict(self):
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)

def to_json(self):
def to_json(self) -> str:
return MessageToJson(
self.pb_msg, preserving_proto_field_name=True, including_default_valud_fields=True
self.pb_msg, preserving_proto_field_name=True, including_default_value_fields=True
)

def level_tag(self) -> EventLevel:
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/events/eventmgr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import traceback
from typing import Callable, List, Optional, Protocol
from typing import Callable, List, Optional, Protocol, Tuple
from uuid import uuid4

from dbt.events.base_types import BaseEvent, EventLevel, msg_from_base_event, EventMsg
Expand Down Expand Up @@ -38,14 +38,15 @@ def add_logger(self, config: LoggerConfig) -> None:
)
self.loggers.append(logger)

def flush(self):
def flush(self) -> None:
for logger in self.loggers:
logger.flush()


class IEventManager(Protocol):
callbacks: List[Callable[[EventMsg], None]]
invocation_id: str
loggers: List[_Logger]

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
...
Expand All @@ -55,8 +56,9 @@ def add_logger(self, config: LoggerConfig) -> None:


class TestEventManager(IEventManager):
def __init__(self):
self.event_history = []
def __init__(self) -> None:
self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = []
self.loggers = []

def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None:
self.event_history.append((e, level))
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/events/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def _pluralize(string: Union[str, NodeType]) -> str:
return convert.pluralize()


def pluralize(count, string: Union[str, NodeType]):
def pluralize(count, string: Union[str, NodeType]) -> str:
pluralized: str = str(string)
if count != 1:
pluralized = _pluralize(string)
return f"{count} {pluralized}"


def timestamp_to_datetime_string(ts):
def timestamp_to_datetime_string(ts) -> str:
timestamp_dt = datetime.fromtimestamp(ts.seconds + ts.nanos / 1e9)
return timestamp_dt.strftime("%H:%M:%S.%f")
Loading

0 comments on commit 0d82a31

Please sign in to comment.