diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index d10c57baa43a..2b22c086099b 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -43,6 +43,7 @@ from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle from apache_beam.utils.timestamp import Timestamp + class _ExecutionContext(object): """Contains the context for the execution of a single PTransform. @@ -389,6 +390,7 @@ def extract_all_timers(self): return self._watermark_manager.extract_all_timers() def is_done(self, transform=None): + # type: (Optional[AppliedPTransform]) -> bool """Checks completion of a step or the pipeline. Args: @@ -407,6 +409,7 @@ def is_done(self, transform=None): return True def _is_transform_done(self, transform): + # type: (AppliedPTransform) -> bool tw = self._watermark_manager.get_watermarks(transform) return tw.output_watermark == WatermarkManager.WATERMARK_POS_INF diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 32a6b32ea561..a49e2e992c4a 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -25,9 +25,13 @@ import sys import threading import traceback +import typing from builtins import object from builtins import range from weakref import WeakValueDictionary +from typing import Generic +from typing import Type +from typing import TypeVar from future.moves import queue from future.utils import raise_ @@ -37,6 +41,11 @@ from apache_beam.transforms import sideinputs from apache_beam.utils import counters +if typing.TYPE_CHECKING: + from apache_beam.runners.direct.evaluation_context import EvaluationContext + +T = TypeVar('T') + class _ExecutorService(object): """Thread pool for executing tasks in parallel.""" @@ -408,8 +417,11 @@ class _ExecutorServiceParallelExecutor(object): NUM_WORKERS = 1 - def __init__(self, value_to_consumers, transform_evaluator_registry, - evaluation_context): + def __init__(self, + value_to_consumers, + transform_evaluator_registry, + evaluation_context # type: EvaluationContext + ): self.executor_service = _ExecutorService( _ExecutorServiceParallelExecutor.NUM_WORKERS) self.transform_executor_services = _TransformExecutorServices( @@ -481,14 +493,16 @@ def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, on_complete, transform_executor_service) transform_executor_service.schedule(transform_executor) - class _TypedUpdateQueue(object): + class _TypedUpdateQueue(Generic[T]): """Type checking update queue with blocking and non-blocking operations.""" def __init__(self, item_type): + # type: (Type[T]) -> None self._item_type = item_type - self._queue = queue.Queue() + self._queue = queue.Queue() # type: queue.Queue[T] def poll(self): + # type: () -> Optional[T] try: item = self._queue.get_nowait() self._queue.task_done() @@ -497,6 +511,7 @@ def poll(self): return None def take(self): + # type: () -> T # The implementation of Queue.Queue.get() does not propagate # KeyboardInterrupts when a timeout is not used. We therefore use a # one-second timeout in the following loop to allow KeyboardInterrupts @@ -510,6 +525,7 @@ def take(self): pass def offer(self, item): + # type: (T) -> None assert isinstance(item, self._item_type) self._queue.put_nowait(item) @@ -548,6 +564,7 @@ class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that pipeline makes progress.""" def __init__(self, executor): + # type: (_ExecutorServiceParallelExecutor) -> None self._executor = executor @property diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 23431f16ddbc..5a0c4179542a 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -20,7 +20,9 @@ from __future__ import absolute_import import threading +import typing from builtins import object +from typing import Dict from apache_beam import pipeline from apache_beam import pvalue @@ -29,6 +31,9 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import TIME_GRANULARITY +if typing.TYPE_CHECKING: + from apache_beam.pipeline import AppliedPTransform + class WatermarkManager(object): """For internal use only; no backwards-compatibility guarantees. @@ -45,7 +50,7 @@ def __init__(self, clock, root_transforms, value_to_consumers, self._value_to_consumers = value_to_consumers self._transform_keyed_states = transform_keyed_states # AppliedPTransform -> TransformWatermarks - self._transform_to_watermarks = {} + self._transform_to_watermarks = {} # type: Dict[AppliedPTransform, _TransformWatermarks] for root_transform in root_transforms: self._transform_to_watermarks[root_transform] = _TransformWatermarks( @@ -73,6 +78,7 @@ def _update_input_transform_watermarks(self, applied_ptransform): input_transform_watermarks) def get_watermarks(self, applied_ptransform): + # type: (AppliedPTransform) -> _TransformWatermarks """Gets the input and output watermarks for an AppliedPTransform. If the applied_ptransform has not processed any elements, return a