From c25792c83785bce1ed4b01dd531ee1404b899438 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Fri, 25 Oct 2019 12:39:38 -0700 Subject: [PATCH] Add type annotations to bulk of python codebase There are no meaningful runtime changes in this commit. --- sdks/python/apache_beam/coders/coder_impl.py | 106 +++- sdks/python/apache_beam/coders/coders.py | 124 ++++- .../apache_beam/coders/observable_test.py | 4 +- sdks/python/apache_beam/coders/slow_stream.py | 14 +- .../coders/standard_coders_test.py | 4 +- sdks/python/apache_beam/coders/typecoders.py | 13 +- .../examples/cookbook/bigtableio_it_test.py | 6 +- sdks/python/apache_beam/internal/pickler.py | 6 +- sdks/python/apache_beam/internal/util.py | 16 +- sdks/python/apache_beam/io/avroio_test.py | 3 +- sdks/python/apache_beam/io/filebasedsource.py | 16 +- sdks/python/apache_beam/io/fileio.py | 29 +- sdks/python/apache_beam/io/filesystem.py | 8 + sdks/python/apache_beam/io/filesystems.py | 5 + .../flink/flink_streaming_impulse_source.py | 4 +- .../gcp/datastore/v1/query_splitter_test.py | 3 +- .../io/gcp/datastore/v1new/helper.py | 4 +- .../apache_beam/io/gcp/gcsfilesystem.py | 3 + sdks/python/apache_beam/io/gcp/pubsub.py | 40 +- .../python/apache_beam/io/hadoopfilesystem.py | 4 + sdks/python/apache_beam/io/iobase.py | 39 +- sdks/python/apache_beam/io/localfilesystem.py | 3 + .../apache_beam/io/restriction_trackers.py | 14 +- sdks/python/apache_beam/io/textio.py | 15 +- sdks/python/apache_beam/metrics/cells.py | 4 + .../apache_beam/metrics/monitoring_infos.py | 11 +- .../apache_beam/options/pipeline_options.py | 25 +- .../apache_beam/options/value_provider.py | 3 +- sdks/python/apache_beam/pipeline.py | 115 ++++- sdks/python/apache_beam/pvalue.py | 82 +++- sdks/python/apache_beam/runners/common.py | 161 ++++-- .../runners/dataflow/dataflow_runner.py | 4 +- .../runners/direct/bundle_factory.py | 13 +- .../consumer_tracking_pipeline_visitor.py | 17 +- .../runners/direct/direct_runner.py | 1 + .../runners/direct/evaluation_context.py | 80 ++- .../apache_beam/runners/direct/executor.py | 75 ++- .../runners/direct/sdf_direct_runner.py | 3 + .../runners/direct/transform_evaluator.py | 57 ++- .../runners/direct/watermark_manager.py | 52 +- .../runners/interactive/cache_manager.py | 5 +- .../interactive/display/display_manager.py | 4 +- .../interactive/display/pipeline_graph.py | 19 +- .../display/pipeline_graph_renderer.py | 15 + sdks/python/apache_beam/runners/job/utils.py | 2 + .../apache_beam/runners/pipeline_context.py | 43 +- .../portability/abstract_job_service.py | 105 +++- .../runners/portability/artifact_service.py | 35 +- .../runners/portability/fn_api_runner.py | 463 ++++++++++++++---- .../runners/portability/fn_api_runner_test.py | 3 +- .../portability/fn_api_runner_transforms.py | 128 ++++- .../runners/portability/job_server.py | 2 + .../runners/portability/local_job_service.py | 36 +- .../runners/portability/portable_runner.py | 12 +- .../runners/portability/portable_stager.py | 6 +- .../apache_beam/runners/portability/stager.py | 35 +- .../runners/portability/stager_test.py | 3 +- sdks/python/apache_beam/runners/runner.py | 41 +- sdks/python/apache_beam/runners/sdf_common.py | 7 + .../runners/worker/bundle_processor.py | 398 ++++++++++++--- .../apache_beam/runners/worker/data_plane.py | 106 +++- .../apache_beam/runners/worker/log_handler.py | 3 +- .../apache_beam/runners/worker/logger.py | 4 +- .../apache_beam/runners/worker/opcounters.py | 27 +- .../apache_beam/runners/worker/operations.py | 106 +++- 
.../apache_beam/runners/worker/sdk_worker.py | 162 ++++-- .../runners/worker/statesampler.py | 41 +- .../runners/worker/statesampler_slow.py | 22 +- .../runners/worker/worker_id_interceptor.py | 2 + .../runners/worker/worker_pool_main.py | 35 +- .../load_tests/load_test_metrics_utils.py | 3 +- .../python/apache_beam/testing/test_stream.py | 2 +- .../apache_beam/transforms/combiners.py | 4 +- sdks/python/apache_beam/transforms/core.py | 71 ++- sdks/python/apache_beam/transforms/display.py | 20 +- .../python/apache_beam/transforms/external.py | 5 + .../apache_beam/transforms/ptransform.py | 92 +++- .../apache_beam/transforms/sideinputs.py | 21 +- .../apache_beam/transforms/userstate.py | 21 + sdks/python/apache_beam/transforms/util.py | 52 +- sdks/python/apache_beam/transforms/window.py | 31 +- .../apache_beam/typehints/decorators.py | 22 +- .../typehints/decorators_test_py3.py | 4 +- .../typehints/native_type_compatibility.py | 2 +- .../python/apache_beam/typehints/typehints.py | 3 +- sdks/python/apache_beam/utils/counters.py | 12 +- sdks/python/apache_beam/utils/profiler.py | 4 + sdks/python/apache_beam/utils/proto_utils.py | 45 ++ sdks/python/apache_beam/utils/timestamp.py | 34 ++ sdks/python/apache_beam/utils/urns.py | 63 ++- .../apache_beam/utils/windowed_value.py | 34 +- sdks/python/gen_protos.py | 8 +- 92 files changed, 3006 insertions(+), 603 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 561d36d2acbd..8c50c14af746 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -38,6 +38,15 @@ from builtins import chr from builtins import object from io import BytesIO +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from fastavro import parse_schema from fastavro import schemaless_reader @@ -52,6 +61,9 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp +if TYPE_CHECKING: + from apache_beam.transforms.window import IntervalWindow + # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: from .stream import InputStream as create_InputStream @@ -74,6 +86,7 @@ else: is_compiled = False fits_in_64_bits = lambda x: -(1 << 63) <= x <= (1 << 63) - 1 + # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports @@ -81,46 +94,58 @@ MIN_TIMESTAMP_micros = MIN_TIMESTAMP.micros MAX_TIMESTAMP_micros = MAX_TIMESTAMP.micros +IterableStateReader = Callable[[bytes, 'CoderImpl'], Iterable] +IterableStateWriter = Callable[[Iterable, 'CoderImpl'], bytes] +Observables = List[Tuple[observable.ObservableMixin, 'CoderImpl']] class CoderImpl(object): """For internal use only; no backwards-compatibility guarantees.""" def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None """Reads object from potentially-nested encoding in stream.""" raise NotImplementedError def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any """Reads object from potentially-nested encoding in stream.""" raise NotImplementedError def encode(self, value): + # type: (Any) -> bytes """Encodes an object to an unnested string.""" raise NotImplementedError def decode(self, encoded): + # type: (bytes) -> Any """Decodes an object to an 
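Note for reviewers unfamiliar with the conventions in the hunks above: the patch uses PEP 484 comment-style annotations (so the modules remain importable under Python 2) together with typing.TYPE_CHECKING guards for imports that only the type checker needs, such as the IntervalWindow import added at the top of coder_impl.py. The following is a minimal, self-contained sketch of both conventions, not Beam code; ToyCoderImpl and its methods are invented for illustration.

from typing import TYPE_CHECKING
from typing import Any

if TYPE_CHECKING:
  # Imported only for the type checker; never executed at runtime, so it adds
  # no import cost and cannot create an import cycle.
  from decimal import Decimal


class ToyCoderImpl(object):
  """Illustrative only -- not part of this patch."""

  def encode(self, value):
    # type: (Any) -> bytes
    return repr(value).encode('utf-8')

  def scale(self, factor):
    # type: (Decimal) -> Decimal
    # 'Decimal' resolves inside the comment even though its import is guarded,
    # because type comments are never evaluated at runtime.
    return factor * 2

Where a guarded name must also exist at runtime, the patch pairs the guard with a lazy import inside the function body, as IntervalWindowCoderImpl.decode_from_stream does further down.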
unnested string.""" raise NotImplementedError def encode_all(self, values): + # type: (Iterable[Any]) -> bytes out = create_OutputStream() for value in values: self.encode_to_stream(value, out, True) return out.get() def decode_all(self, encoded): + # type: (bytes) -> Iterator[Any] input_stream = create_InputStream(encoded) while input_stream.size() > 0: yield self.decode_from_stream(input_stream, True) def encode_nested(self, value): + # type: (Any) -> bytes out = create_OutputStream() self.encode_to_stream(value, out, True) return out.get() def decode_nested(self, encoded): + # type: (bytes) -> Any return self.decode_from_stream(create_InputStream(encoded), True) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int """Estimates the encoded size of the given value, in bytes.""" out = ByteCountingOutputStream() self.encode_to_stream(value, out, nested) @@ -133,6 +158,7 @@ def _get_nested_size(self, inner_size, nested): return varint_size + inner_size def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] """Returns estimated size of value along with any nested observables. The list of nested observables is returned as a list of 2-tuples of @@ -157,10 +183,12 @@ class SimpleCoderImpl(CoderImpl): Subclass of CoderImpl implementing stream methods using encode/decode.""" def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None """Reads object from potentially-nested encoding in stream.""" stream.write(self.encode(value), nested) def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any """Reads object from potentially-nested encoding in stream.""" return self.decode(stream.read_all(nested)) @@ -171,14 +199,17 @@ class StreamCoderImpl(CoderImpl): Subclass of CoderImpl implementing encode/decode using stream methods.""" def encode(self, value): + # type: (Any) -> bytes out = create_OutputStream() self.encode_to_stream(value, out, False) return out.get() def decode(self, encoded): + # type: (bytes) -> Any return self.decode_from_stream(create_InputStream(encoded), False) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int """Estimates the encoded size of the given value, in bytes.""" out = ByteCountingOutputStream() self.encode_to_stream(value, out, nested) @@ -203,9 +234,11 @@ def _default_size_estimator(self, value): return len(self.encode(value)) def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None return stream.write(self._encoder(value), nested) def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any return self._decoder(stream.read_all(nested)) def encode(self, value): @@ -215,9 +248,11 @@ def decode(self, encoded): return self._decoder(encoded) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int return self._get_nested_size(self._size_estimator(value), nested) def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] # TODO(robertwb): Remove this once all coders are correct. if isinstance(value, observable.ObservableMixin): # CallbackCoderImpl can presumably encode the elements too. 
@@ -252,10 +287,12 @@ def _check_safe(self, value): value, type(value), self._step_label)) def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None self._check_safe(value) return self._underlying_coder.encode_to_stream(value, stream, nested) def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any return self._underlying_coder.decode_from_stream(stream, nested) def encode(self, value): @@ -266,9 +303,11 @@ def decode(self, encoded): return self._underlying_coder.decode(encoded) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int return self._underlying_coder.estimate_size(value, nested) def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] return self._underlying_coder.get_estimated_size_and_observables( value, nested) @@ -328,6 +367,7 @@ def register_iterable_like_type(t): _ITERABLE_LIKE_TYPES.add(t) def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] if isinstance(value, observable.ObservableMixin): # FastPrimitivesCoderImpl can presumably encode the elements too. return 1, [(value, self)] @@ -337,6 +377,7 @@ def get_estimated_size_and_observables(self, value, nested=False): return out.get_count(), [] def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None t = type(value) if value is None: stream.write_byte(NONE_TYPE) @@ -391,6 +432,7 @@ def encode_to_stream(self, value, stream, nested): self.fallback_coder_impl.encode_to_stream(value, stream, nested) def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any t = stream.read_byte() if t == NONE_TYPE: return None @@ -433,9 +475,11 @@ class BytesCoderImpl(CoderImpl): A coder for bytes/str objects.""" def encode_to_stream(self, value, out, nested): + # type: (bytes, create_OutputStream, bool) -> None out.write(value, nested) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> bytes return in_stream.read_all(nested) def encode(self, value): @@ -482,17 +526,21 @@ class FloatCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def encode_to_stream(self, value, out, nested): + # type: (float, create_OutputStream, bool) -> None out.write_bigendian_double(value) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> float return in_stream.read_bigendian_double() def estimate_size(self, unused_value, nested=False): + # type: (Any, bool) -> int # A double is encoded as 8 bytes, regardless of nesting. 
return 8 -IntervalWindow = None +if not TYPE_CHECKING: + IntervalWindow = None class IntervalWindowCoderImpl(StreamCoderImpl): @@ -508,6 +556,7 @@ def _from_normal_time(self, value): return value + _TIME_SHIFT def encode_to_stream(self, value, out, nested): + # type: (IntervalWindow, create_OutputStream, bool) -> None typed_value = value span_millis = (typed_value._end_micros // 1000 - typed_value._start_micros // 1000) @@ -516,9 +565,11 @@ def encode_to_stream(self, value, out, nested): out.write_var_int64(span_millis) def decode_from_stream(self, in_, nested): - global IntervalWindow - if IntervalWindow is None: - from apache_beam.transforms.window import IntervalWindow + # type: (create_InputStream, bool) -> IntervalWindow + if not TYPE_CHECKING: + global IntervalWindow + if IntervalWindow is None: + from apache_beam.transforms.window import IntervalWindow typed_value = IntervalWindow(None, None) typed_value._end_micros = ( 1000 * self._to_normal_time(in_.read_bigendian_uint64())) @@ -527,6 +578,7 @@ def decode_from_stream(self, in_, nested): return typed_value def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int # An IntervalWindow is context-insensitive, with a timestamp (8 bytes) # and a varint timespam. typed_value = value @@ -545,6 +597,7 @@ class TimestampCoderImpl(StreamCoderImpl): """ def encode_to_stream(self, value, out, nested): + # type: (Timestamp, create_OutputStream, bool) -> None millis = value.micros // 1000 if millis >= 0: millis = millis - _TIME_SHIFT @@ -553,6 +606,7 @@ def encode_to_stream(self, value, out, nested): out.write_bigendian_int64(millis) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> Timestamp millis = in_stream.read_bigendian_int64() if millis < 0: millis = millis + _TIME_SHIFT @@ -573,10 +627,12 @@ def __init__(self, payload_coder_impl): self._payload_coder_impl = payload_coder_impl def encode_to_stream(self, value, out, nested): + # type: (dict, create_OutputStream, bool) -> None self._timestamp_coder_impl.encode_to_stream(value['timestamp'], out, True) self._payload_coder_impl.encode_to_stream(value.get('payload'), out, True) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> dict # TODO(robertwb): Consider using a concrete class rather than a dict here. return dict( timestamp=self._timestamp_coder_impl.decode_from_stream( @@ -593,9 +649,11 @@ class VarIntCoderImpl(StreamCoderImpl): A coder for long/int objects.""" def encode_to_stream(self, value, out, nested): + # type: (int, create_OutputStream, bool) -> None out.write_var_int64(value) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> int return in_stream.read_var_int64() def encode(self, value): @@ -612,6 +670,7 @@ def decode(self, encoded): return StreamCoderImpl.decode(self, encoded) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int # Note that VarInts are encoded the same way regardless of nesting. 
return get_varint_size(value) @@ -625,9 +684,11 @@ def __init__(self, value): self._value = value def encode_to_stream(self, value, stream, nested): + # type: (Any, create_OutputStream, bool) -> None pass def decode_from_stream(self, stream, nested): + # type: (create_InputStream, bool) -> Any return self._value def encode(self, value): @@ -638,6 +699,7 @@ def decode(self, encoded): return self._value def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int return 0 @@ -658,6 +720,7 @@ def _construct_from_components(self, components): raise NotImplementedError def encode_to_stream(self, value, out, nested): + # type: (Any, create_OutputStream, bool) -> None values = self._extract_components(value) if len(self._coder_impls) != len(values): raise ValueError( @@ -668,12 +731,14 @@ def encode_to_stream(self, value, out, nested): nested or i + 1 < len(self._coder_impls)) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> Any return self._construct_from_components( [c.decode_from_stream(in_stream, nested or i + 1 < len(self._coder_impls)) for i, c in enumerate(self._coder_impls)]) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int """Estimates the encoded size of the given value, in bytes.""" # TODO(ccy): This ignores sizes of observable components. estimated_size, _ = ( @@ -681,10 +746,11 @@ def estimate_size(self, value, nested=False): return estimated_size def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] """Returns estimated size of value along with any nested observables.""" values = self._extract_components(value) estimated_size = 0 - observables = [] + observables = [] # type: Observables for i in range(0, len(self._coder_impls)): c = self._coder_impls[i] # type cast child_size, child_observables = ( @@ -724,10 +790,12 @@ def _construct_from_components(self, components): class _ConcatSequence(object): def __init__(self, head, tail): + # type: (Iterable[Any], Iterable[Any]) -> None self._head = head self._tail = tail def __iter__(self): + # type: () -> Iterator[Any] for elem in self._head: yield elem for elem in self._tail: @@ -782,8 +850,12 @@ class SequenceCoderImpl(StreamCoderImpl): # Default buffer size of 64kB of handling iterables of unknown length. _DEFAULT_BUFFER_SIZE = 64 * 1024 - def __init__(self, elem_coder, - read_state=None, write_state=None, write_state_threshold=0): + def __init__(self, + elem_coder, # type: CoderImpl + read_state=None, # type: Optional[IterableStateReader] + write_state=None, # type: Optional[IterableStateWriter] + write_state_threshold=0 # type: int + ): self._elem_coder = elem_coder self._read_state = read_state self._write_state = write_state @@ -793,6 +865,7 @@ def _construct_from_sequence(self, values): raise NotImplementedError def encode_to_stream(self, value, out, nested): + # type: (Sequence, create_OutputStream, bool) -> None # Compatible with Java's IterableLikeCoder. 
if hasattr(value, '__len__') and self._write_state is None: out.write_bigendian_int32(len(value)) @@ -838,11 +911,12 @@ def encode_to_stream(self, value, out, nested): out.write_var_int64(0) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> Sequence size = in_stream.read_bigendian_int32() if size >= 0: elements = [self._elem_coder.decode_from_stream(in_stream, True) - for _ in range(size)] + for _ in range(size)] # type: Iterable[Any] else: elements = [] count = in_stream.read_var_int64() @@ -863,6 +937,7 @@ def decode_from_stream(self, in_stream, nested): return self._construct_from_sequence(elements) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int """Estimates the encoded size of the given value, in bytes.""" # TODO(ccy): This ignores element sizes. estimated_size, _ = ( @@ -870,6 +945,7 @@ def estimate_size(self, value, nested=False): return estimated_size def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] """Returns estimated size of value along with any nested observables.""" estimated_size = 0 # Size of 32-bit integer storing number of elements. @@ -877,7 +953,7 @@ def get_estimated_size_and_observables(self, value, nested=False): if isinstance(value, observable.ObservableMixin): return estimated_size, [(value, self._elem_coder)] - observables = [] + observables = [] # type: Observables for elem in value: child_size, child_observables = ( self._elem_coder.get_estimated_size_and_observables( @@ -948,6 +1024,7 @@ def _choose_encoding(self, value): return PaneInfoEncoding.TWO_INDICES def encode_to_stream(self, value, out, nested): + # type: (windowed_value.PaneInfo, create_OutputStream, bool) -> None pane_info = value # cast encoding_type = self._choose_encoding(pane_info) out.write_byte(pane_info._encoded_byte | (encoding_type << 4)) @@ -962,6 +1039,7 @@ def encode_to_stream(self, value, out, nested): raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> windowed_value.PaneInfo encoded_first_byte = in_stream.read_byte() base = windowed_value._BYTE_TO_PANE_INFO[encoded_first_byte & 0xF] assert base is not None @@ -983,6 +1061,7 @@ def decode_from_stream(self, in_stream, nested): base.is_first, base.is_last, base.timing, index, nonspeculative_index) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int """Estimates the encoded size of the given value, in bytes.""" size = 1 encoding_type = self._choose_encoding(value) @@ -1019,6 +1098,7 @@ def __init__(self, value_coder, timestamp_coder, window_coder): self._pane_info_coder = PaneInfoCoderImpl() def encode_to_stream(self, value, out, nested): + # type: (windowed_value.WindowedValue, create_OutputStream, bool) -> None wv = value # type cast # Avoid creation of Timestamp object. restore_sign = -1 if wv.timestamp_micros < 0 else 1 @@ -1041,6 +1121,7 @@ def encode_to_stream(self, value, out, nested): self._value_coder.encode_to_stream(wv.value, out, nested) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> windowed_value.WindowedValue timestamp = self._to_normal_time(in_stream.read_bigendian_uint64()) # Restore MIN/MAX timestamps to their actual values as encoding incurs loss # of precision while converting to millis. 
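The hunks above also annotate empty container locals against module-level aliases, e.g. observables = []  # type: Observables, where Observables, IterableStateReader and IterableStateWriter are the aliases introduced at the top of coder_impl.py. Without such a hint a checker cannot infer an element type for an empty literal and will reject later appends. Below is a self-contained illustration of the same pattern; SizeEstimator, Measurements and measure_all are invented for this sketch and are not part of the patch.

from typing import Callable
from typing import Iterable
from typing import List
from typing import Tuple

# Module-level aliases, analogous to IterableStateReader/Observables above.
SizeEstimator = Callable[[bytes], int]
Measurements = List[Tuple[str, int]]


def measure_all(chunks, estimator):
  # type: (Iterable[bytes], SizeEstimator) -> Measurements
  results = []  # type: Measurements
  for chunk in chunks:
    # Appending a (label, size) pair type-checks against the alias above.
    results.append(('chunk', estimator(chunk)))
  return results


print(measure_all([b'abc', b'de'], len))  # [('chunk', 3), ('chunk', 2)]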
@@ -1067,13 +1148,14 @@ def decode_from_stream(self, in_stream, nested): pane_info) def get_estimated_size_and_observables(self, value, nested=False): + # type: (Any, bool) -> Tuple[int, Observables] """Returns estimated size of value along with any nested observables.""" if isinstance(value, observable.ObservableMixin): # Should never be here. # TODO(robertwb): Remove when coders are set correctly. return 0, [(value, self._value_coder)] estimated_size = 0 - observables = [] + observables = [] # type: Observables value_estimated_size, value_observables = ( self._value_coder.get_estimated_size_and_observables( value.value, nested=nested)) @@ -1094,17 +1176,21 @@ class LengthPrefixCoderImpl(StreamCoderImpl): Coder which prefixes the length of the encoded object in the stream.""" def __init__(self, value_coder): + # type: (CoderImpl) -> None self._value_coder = value_coder def encode_to_stream(self, value, out, nested): + # type: (Any, create_OutputStream, bool) -> None encoded_value = self._value_coder.encode(value) out.write_var_int64(len(encoded_value)) out.write(encoded_value) def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> Any value_length = in_stream.read_var_int64() return self._value_coder.decode(in_stream.read(value_length)) def estimate_size(self, value, nested=False): + # type: (Any, bool) -> int value_size = self._value_coder.estimate_size(value) return get_varint_size(value_size) + value_size diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 35020b624cbe..a2bc85e7834e 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -23,8 +23,19 @@ import base64 import sys -import typing from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import overload import google.protobuf.wrappers_pb2 from future.moves import pickle @@ -38,6 +49,11 @@ from apache_beam.typehints import typehints from apache_beam.utils import proto_utils +if TYPE_CHECKING: + from google.protobuf import message # pylint: disable=ungrouped-imports + from apache_beam.coders.typecoders import CoderRegistry + from apache_beam.runners.pipeline_context import PipelineContext + # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: from .stream import get_varint_size @@ -67,6 +83,15 @@ 'WindowedValueCoder' ] +T = TypeVar('T') +CoderT = TypeVar('CoderT', bound='Coder') +ProtoCoderT = TypeVar('ProtoCoderT', bound='ProtoCoder') +ConstructorFn = Callable[ + [Optional[Any], + List['Coder'], + 'PipelineContext'], + Any] + def serialize_coder(coder): from apache_beam.internal import pickler @@ -84,6 +109,7 @@ class Coder(object): """Base class for coders.""" def encode(self, value): + # type: (Any) -> bytes """Encodes the given object into a byte string.""" raise NotImplementedError('Encode not implemented: %s.' % self) @@ -100,6 +126,7 @@ def decode_nested(self, encoded): return self.get_impl().decode_nested(encoded) def is_deterministic(self): + # type: () -> bool """Whether this coder is guaranteed to encode values deterministically. 
A deterministic coder is required for key coders in GroupByKey operations @@ -152,6 +179,7 @@ def estimate_size(self, value): # =========================================================================== def _create_impl(self): + # type: () -> coder_impl.CoderImpl """Creates a CoderImpl to do the actual encoding and decoding. """ return coder_impl.CallbackCoderImpl(self.encode, self.decode, @@ -182,25 +210,30 @@ def to_type_hint(self): @classmethod def from_type_hint(cls, unused_typehint, unused_registry): + # type: (Type[CoderT], Any, CoderRegistry) -> CoderT # If not overridden, just construct the coder without arguments. return cls() def is_kv_coder(self): + # () -> bool return False def key_coder(self): + # type: () -> Coder if self.is_kv_coder(): raise NotImplementedError('key_coder: %s' % self) else: raise ValueError('Not a KV coder: %s.' % self) def value_coder(self): + # type: () -> Coder if self.is_kv_coder(): raise NotImplementedError('value_coder: %s' % self) else: raise ValueError('Not a KV coder: %s.' % self) def _get_component_coders(self): + # type: () -> Sequence[Coder] """For internal use only; no backwards-compatibility guarantees. Returns the internal component coders of this coder.""" @@ -248,7 +281,26 @@ def __ne__(self, other): def __hash__(self): return hash(type(self)) - _known_urns = {} + _known_urns = {} # type: Dict[str, Tuple[type, ConstructorFn]] + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Optional[Type[T]] + ): + # type: (...) -> Callable[[Callable[[T, List[Coder], PipelineContext], Any]], Callable[[T, List[Coder], PipelineContext], Any]] + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Optional[Type[T]] + fn # type: Callable[[T, List[Coder], PipelineContext], Any] + ): + # type: (...) -> None + pass @classmethod def register_urn(cls, urn, parameter_type, fn=None): @@ -274,6 +326,7 @@ def register(fn): return register def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.Coder urn, typed_param, components = self.to_runner_api_parameter(context) return beam_runner_api_pb2.Coder( spec=beam_runner_api_pb2.FunctionSpec( @@ -285,6 +338,7 @@ def to_runner_api(self, context): @classmethod def from_runner_api(cls, coder_proto, context): + # type: (Type[CoderT], beam_runner_api_pb2.Coder, PipelineContext) -> CoderT """Converts from an FunctionSpec to a Fn object. Prefer registering a urn with its parameter type and constructor. @@ -299,10 +353,11 @@ def from_runner_api(cls, coder_proto, context): context) except Exception: if context.allow_proto_holders: - return RunnerAPICoderHolder(coder_proto) + return RunnerAPICoderHolder(coder_proto) # type: ignore # too ambiguous raise def to_runner_api_parameter(self, context): + # type: (Optional[PipelineContext]) -> Tuple[str, Any, Sequence[Coder]] return ( python_urns.PICKLED_CODER, google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)), @@ -310,6 +365,7 @@ def to_runner_api_parameter(self, context): @staticmethod def register_structured_urn(urn, cls): + # type: (str, Type[Coder]) -> None """Register a coder that's completely defined by its urn and its component(s), if any, which are passed to construct the instance. 
""" @@ -403,6 +459,7 @@ def _create_impl(self): return coder_impl.BytesCoderImpl() def is_deterministic(self): + # type: () -> bool return True def to_type_hint(self): @@ -450,6 +507,7 @@ def _create_impl(self): return coder_impl.VarIntCoderImpl() def is_deterministic(self): + # type: () -> bool return True def to_type_hint(self): @@ -477,6 +535,7 @@ def _create_impl(self): return coder_impl.FloatCoderImpl() def is_deterministic(self): + # type: () -> bool return True def to_type_hint(self): @@ -499,6 +558,7 @@ def _create_impl(self): return coder_impl.TimestampCoderImpl() def is_deterministic(self): + # () -> bool return True def __eq__(self, other): @@ -513,15 +573,18 @@ class _TimerCoder(FastCoder): For internal use.""" def __init__(self, payload_coder): + # type: (Coder) -> None self._payload_coder = payload_coder def _get_component_coders(self): + # type: () -> List[Coder] return [self._payload_coder] def _create_impl(self): return coder_impl.TimerCoderImpl(self._payload_coder.get_impl()) def is_deterministic(self): + # () -> bool return self._payload_coder.is_deterministic() def __eq__(self, other): @@ -546,6 +609,7 @@ def _create_impl(self): return coder_impl.SingletonCoderImpl(self._value) def is_deterministic(self): + # () -> bool return True def __eq__(self, other): @@ -577,6 +641,7 @@ class _PickleCoderBase(FastCoder): """Base class for pickling coders.""" def is_deterministic(self): + # () -> bool # Note that the default coder, the PickleCoder, is not deterministic (for # example, the ordering of picked entries in maps may vary across # executions), and so is not in general suitable for usage as a key coder in @@ -602,6 +667,7 @@ def as_cloud_object(self, coders_context=None, is_pair_like=True): # we can't always infer the return values of lambdas in ParDo operations, the # result of which may be used in a GroupBykey. def is_kv_coder(self): + # () -> bool return True def key_coder(self): @@ -630,7 +696,7 @@ def as_deterministic_coder(self, step_label, error_message=None): return DeterministicFastPrimitivesCoder(self, step_label) def to_type_hint(self): - return typing.Any + return Any class DillCoder(_PickleCoderBase): @@ -652,9 +718,11 @@ def _create_impl(self): self._underlying_coder.get_impl(), self._step_label) def is_deterministic(self): + # () -> bool return True def is_kv_coder(self): + # () -> bool return True def key_coder(self): @@ -664,7 +732,7 @@ def value_coder(self): return self def to_type_hint(self): - return typing.Any + return Any class FastPrimitivesCoder(FastCoder): @@ -673,6 +741,7 @@ class FastPrimitivesCoder(FastCoder): For unknown types, falls back to another coder (e.g. PickleCoder). 
""" def __init__(self, fallback_coder=PickleCoder()): + # type: (Coder) -> None self._fallback_coder = fallback_coder def _create_impl(self): @@ -680,6 +749,7 @@ def _create_impl(self): self._fallback_coder.get_impl()) def is_deterministic(self): + # () -> bool return self._fallback_coder.is_deterministic() def as_deterministic_coder(self, step_label, error_message=None): @@ -689,7 +759,7 @@ def as_deterministic_coder(self, step_label, error_message=None): return DeterministicFastPrimitivesCoder(self, step_label) def to_type_hint(self): - return typing.Any + return Any def as_cloud_object(self, coders_context=None, is_pair_like=True): value = super(FastCoder, self).as_cloud_object(coders_context) @@ -710,6 +780,7 @@ def as_cloud_object(self, coders_context=None, is_pair_like=True): # since we can't always infer the return values of lambdas in ParDo # operations, the result of which may be used in a GroupBykey. def is_kv_coder(self): + # () -> bool return True def key_coder(self): @@ -737,6 +808,7 @@ def decode(self, encoded): return pickle.loads(base64.b64decode(encoded)) def is_deterministic(self): + # () -> bool # Note that the Base64PickleCoder is not deterministic. See the # corresponding comments for PickleCoder above. return False @@ -771,12 +843,14 @@ class ProtoCoder(FastCoder): """ def __init__(self, proto_message_type): + # type: (google.protobuf.message.Message) -> None self.proto_message_type = proto_message_type def _create_impl(self): return coder_impl.ProtoCoderImpl(self.proto_message_type) def is_deterministic(self): + # () -> bool # TODO(vikasrk): A proto message can be deterministic if it does not contain # a Map. return False @@ -813,6 +887,7 @@ def _create_impl(self): return coder_impl.DeterministicProtoCoderImpl(self.proto_message_type) def is_deterministic(self): + # () -> bool return True def as_deterministic_coder(self, step_label, error_message=None): @@ -857,12 +932,14 @@ class TupleCoder(FastCoder): """Coder of tuple objects.""" def __init__(self, components): + # type: (Iterable[Coder]) -> None self._coders = tuple(components) def _create_impl(self): return coder_impl.TupleCoderImpl([c.get_impl() for c in self._coders]) def is_deterministic(self): + # () -> bool return all(c.is_deterministic() for c in self._coders) def as_deterministic_coder(self, step_label, error_message=None): @@ -877,6 +954,7 @@ def to_type_hint(self): @staticmethod def from_type_hint(typehint, registry): + # type: (typehints.TupleConstraint, CoderRegistry) -> TupleCoder return TupleCoder([registry.get_coder(t) for t in typehint.tuple_types]) def as_cloud_object(self, coders_context=None): @@ -895,20 +973,25 @@ def as_cloud_object(self, coders_context=None): return super(TupleCoder, self).as_cloud_object(coders_context) def _get_component_coders(self): + # type: () -> Tuple[Coder, ...] return self.coders() def coders(self): + # type: () -> Tuple[Coder, ...] 
return self._coders def is_kv_coder(self): + # () -> bool return len(self._coders) == 2 def key_coder(self): + # type: () -> Coder if len(self._coders) != 2: raise ValueError('TupleCoder does not have exactly 2 components.') return self._coders[0] def value_coder(self): + # type: () -> Coder if len(self._coders) != 2: raise ValueError('TupleCoder does not have exactly 2 components.') return self._coders[1] @@ -938,6 +1021,7 @@ class TupleSequenceCoder(FastCoder): """Coder of homogeneous tuple objects.""" def __init__(self, elem_coder): + # type: (Coder) -> None self._elem_coder = elem_coder def value_coder(self): @@ -947,6 +1031,7 @@ def _create_impl(self): return coder_impl.TupleSequenceCoderImpl(self._elem_coder.get_impl()) def is_deterministic(self): + # () -> bool return self._elem_coder.is_deterministic() def as_deterministic_coder(self, step_label, error_message=None): @@ -958,9 +1043,11 @@ def as_deterministic_coder(self, step_label, error_message=None): @staticmethod def from_type_hint(typehint, registry): + # type: (Any, CoderRegistry) -> TupleSequenceCoder return TupleSequenceCoder(registry.get_coder(typehint.inner_type)) def _get_component_coders(self): + # type: () -> Tuple[Coder, ...] return (self._elem_coder,) def __repr__(self): @@ -978,12 +1065,14 @@ class IterableCoder(FastCoder): """Coder of iterables of homogeneous objects.""" def __init__(self, elem_coder): + # type: (Coder) -> None self._elem_coder = elem_coder def _create_impl(self): return coder_impl.IterableCoderImpl(self._elem_coder.get_impl()) def is_deterministic(self): + # () -> bool return self._elem_coder.is_deterministic() def as_deterministic_coder(self, step_label, error_message=None): @@ -1012,9 +1101,11 @@ def to_type_hint(self): @staticmethod def from_type_hint(typehint, registry): + # type: (Any, CoderRegistry) -> IterableCoder return IterableCoder(registry.get_coder(typehint.inner_type)) def _get_component_coders(self): + # type: () -> Tuple[Coder, ...] 
return (self._elem_coder,) def __repr__(self): @@ -1055,6 +1146,7 @@ def _create_impl(self): return coder_impl.IntervalWindowCoderImpl() def is_deterministic(self): + # () -> bool return True def as_cloud_object(self, coders_context=None): @@ -1077,6 +1169,7 @@ class WindowedValueCoder(FastCoder): """Coder for windowed values.""" def __init__(self, wrapped_value_coder, window_coder=None): + # type: (Coder, Optional[Coder]) -> None if not window_coder: window_coder = PickleCoder() self.wrapped_value_coder = wrapped_value_coder @@ -1090,6 +1183,7 @@ def _create_impl(self): self.window_coder.get_impl()) def is_deterministic(self): + # () -> bool return all(c.is_deterministic() for c in [self.wrapped_value_coder, self.timestamp_coder, self.window_coder]) @@ -1107,15 +1201,19 @@ def as_cloud_object(self, coders_context=None): } def _get_component_coders(self): + # type: () -> List[Coder] return [self.wrapped_value_coder, self.window_coder] def is_kv_coder(self): + # () -> bool return self.wrapped_value_coder.is_kv_coder() def key_coder(self): + # type: () -> Coder return self.wrapped_value_coder.key_coder() def value_coder(self): + # type: () -> Coder return self.wrapped_value_coder.value_coder() def __repr__(self): @@ -1142,12 +1240,14 @@ class LengthPrefixCoder(FastCoder): Coder which prefixes the length of the encoded object in the stream.""" def __init__(self, value_coder): + # type: (Coder) -> None self._value_coder = value_coder def _create_impl(self): return coder_impl.LengthPrefixCoderImpl(self._value_coder.get_impl()) def is_deterministic(self): + # () -> bool return self._value_coder.is_deterministic() def estimate_size(self, value): @@ -1167,6 +1267,7 @@ def as_cloud_object(self, coders_context=None): } def _get_component_coders(self): + # type: () -> Tuple[Coder, ...] return (self._value_coder,) def __repr__(self): @@ -1187,9 +1288,9 @@ def __hash__(self): class StateBackedIterableCoder(FastCoder): def __init__( self, - element_coder, - read_state=None, - write_state=None, + element_coder, # type: Coder + read_state=None, # type: Optional[coder_impl.IterableStateReader] + write_state=None, # type: Optional[coder_impl.IterableStateWriter] write_state_threshold=1): self._element_coder = element_coder self._read_state = read_state @@ -1204,9 +1305,11 @@ def _create_impl(self): self._write_state_threshold) def is_deterministic(self): + # () -> bool return False def _get_component_coders(self): + # type: () -> Tuple[Coder, ...] 
return (self._element_coder,) def __repr__(self): @@ -1221,6 +1324,7 @@ def __hash__(self): return hash((type(self), self._element_coder, self._write_state_threshold)) def to_runner_api_parameter(self, context): + # type: (Optional[PipelineContext]) -> Tuple[str, Any, Sequence[Coder]] return ( common_urns.coders.STATE_BACKED_ITERABLE.urn, str(self._write_state_threshold).encode('ascii'), @@ -1254,4 +1358,4 @@ def to_runner_api(self, context): return self._proto def to_type_hint(self): - return typing.Any + return Any diff --git a/sdks/python/apache_beam/coders/observable_test.py b/sdks/python/apache_beam/coders/observable_test.py index a56a3208486f..fc3c4102f58b 100644 --- a/sdks/python/apache_beam/coders/observable_test.py +++ b/sdks/python/apache_beam/coders/observable_test.py @@ -20,6 +20,8 @@ import logging import unittest +from typing import List +from typing import Optional from apache_beam.coders import observable @@ -27,7 +29,7 @@ class ObservableMixinTest(unittest.TestCase): observed_count = 0 observed_sum = 0 - observed_keys = [] + observed_keys = [] # type: List[Optional[str]] def observer(self, value, key=None): self.observed_count += 1 diff --git a/sdks/python/apache_beam/coders/slow_stream.py b/sdks/python/apache_beam/coders/slow_stream.py index efd5434e8d8d..08a6c8374066 100644 --- a/sdks/python/apache_beam/coders/slow_stream.py +++ b/sdks/python/apache_beam/coders/slow_stream.py @@ -25,6 +25,7 @@ import sys from builtins import chr from builtins import object +from typing import List class OutputStream(object): @@ -33,10 +34,11 @@ class OutputStream(object): A pure Python implementation of stream.OutputStream.""" def __init__(self): - self.data = [] + self.data = [] # type: List[bytes] self.byte_count = 0 def write(self, b, nested=False): + # type: (bytes, bool) -> None assert isinstance(b, bytes) if nested: self.write_var_int64(len(b)) @@ -48,6 +50,7 @@ def write_byte(self, val): self.byte_count += 1 def write_var_int64(self, v): + # type: (int) -> None if v < 0: v += 1 << 64 if v <= 0: @@ -74,12 +77,15 @@ def write_bigendian_double(self, v): self.write(struct.pack('>d', v)) def get(self): + # type: () -> bytes return b''.join(self.data) def size(self): + # type: () -> int return self.byte_count def _clear(self): + # type: () -> None self.data = [] self.byte_count = 0 @@ -95,6 +101,7 @@ def __init__(self): self.count = 0 def write(self, byte_array, nested=False): + # type: (bytes, bool) -> None blen = len(byte_array) if nested: self.write_var_int64(blen) @@ -119,6 +126,7 @@ class InputStream(object): A pure Python implementation of stream.InputStream.""" def __init__(self, data): + # type: (bytes) -> None self.data = data self.pos = 0 @@ -139,18 +147,22 @@ def size(self): return len(self.data) - self.pos def read(self, size): + # type: (int) -> bytes self.pos += size return self.data[self.pos - size : self.pos] def read_all(self, nested): + # type: (bool) -> bytes return self.read(self.read_var_int64() if nested else self.size()) def read_byte_py2(self): + # type: () -> int self.pos += 1 # mypy tests against python 3.x, where this is an error: return ord(self.data[self.pos - 1]) # type: ignore[arg-type] def read_byte_py3(self): + # type: () -> int self.pos += 1 return self.data[self.pos - 1] diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py index 606ca811ed87..fbc80b4d3a8b 100644 --- a/sdks/python/apache_beam/coders/standard_coders_test.py +++ b/sdks/python/apache_beam/coders/standard_coders_test.py 
@@ -27,6 +27,8 @@ import sys import unittest from builtins import map +from typing import Dict +from typing import Tuple import yaml @@ -147,7 +149,7 @@ def json_value_parser(self, coder_spec): # Used when --fix is passed. fix = False - to_fix = {} + to_fix = {} # type: Dict[Tuple[int, bytes], bytes] @classmethod def tearDownClass(cls): diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 6f6f3229f451..b957f67440e2 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -66,6 +66,11 @@ def MakeXyzs(v): from __future__ import absolute_import from builtins import object +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Type from past.builtins import unicode @@ -79,8 +84,8 @@ class CoderRegistry(object): """A coder registry for typehint/coder associations.""" def __init__(self, fallback_coder=None): - self._coders = {} - self.custom_types = [] + self._coders = {} # type: Dict[Any, Type[coders.Coder]] + self.custom_types = [] # type: List[Any] self.register_standard_coders(fallback_coder) def register_standard_coders(self, fallback_coder): @@ -97,9 +102,11 @@ def register_standard_coders(self, fallback_coder): self._fallback_coder = fallback_coder or FirstOf(default_fallback_coders) def _register_coder_internal(self, typehint_type, typehint_coder_class): + # type: (Any, Type[coders.Coder]) -> None self._coders[typehint_type] = typehint_coder_class def register_coder(self, typehint_type, typehint_coder_class): + # type: (Any, Type[coders.Coder]) -> None if not isinstance(typehint_coder_class, type): raise TypeError('Coder registration requires a coder class object. ' 'Received %r instead.' 
% typehint_coder_class) @@ -108,6 +115,7 @@ def register_coder(self, typehint_type, typehint_coder_class): self._register_coder_internal(typehint_type, typehint_coder_class) def get_coder(self, typehint): + # type: (Any) -> coders.Coder coder = self._coders.get( typehint.__class__ if isinstance(typehint, typehints.TypeConstraint) else typehint, None) @@ -164,6 +172,7 @@ class FirstOf(object): A class used to get the first matching coder from a list of coders.""" def __init__(self, coders): + # type: (Iterable[Type[coders.Coder]]) -> None self._coders = coders def from_type_hint(self, typehint, registry): diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index ca5c4a5662d0..3c1204cddcff 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -24,6 +24,8 @@ import string import unittest import uuid +from typing import TYPE_CHECKING +from typing import List import pytz @@ -47,8 +49,10 @@ _microseconds_from_datetime = lambda label_stamp: label_stamp _datetime_from_microseconds = lambda micro: micro +if TYPE_CHECKING: + import google.cloud.bigtable.instance -EXISTING_INSTANCES = [] +EXISTING_INSTANCES = [] # type: List[google.cloud.bigtable.instance.Instance] LABEL_KEY = u'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py index ec8db53e270f..e06e2a1e1359 100644 --- a/sdks/python/apache_beam/internal/pickler.py +++ b/sdks/python/apache_beam/internal/pickler.py @@ -36,6 +36,9 @@ import traceback import types import zlib +from typing import Any +from typing import Dict +from typing import Tuple import dill @@ -157,7 +160,7 @@ def save_module(pickler, obj): # Pickle module dictionaries (commonly found in lambda's globals) # by referencing their module. old_save_module_dict = dill.dill.save_module_dict - known_module_dicts = {} + known_module_dicts = {} # type: Dict[int, Tuple[types.ModuleType, Dict[str, Any]]] @dill.dill.register(dict) def new_save_module_dict(pickler, obj): @@ -227,6 +230,7 @@ def new_log_info(msg, *args, **kwargs): # pickler.loads() being used for data, which results in an unnecessary base64 # encoding. This should be cleaned up. def dumps(o, enable_trace=True): + # type: (...) -> bytes """For internal use only; no backwards-compatibility guarantees.""" try: diff --git a/sdks/python/apache_beam/internal/util.py b/sdks/python/apache_beam/internal/util.py index 499214f445d4..6b13a378bbc4 100644 --- a/sdks/python/apache_beam/internal/util.py +++ b/sdks/python/apache_beam/internal/util.py @@ -27,6 +27,16 @@ import weakref from builtins import object from multiprocessing.pool import ThreadPool +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +T = TypeVar('T') class ArgumentPlaceholder(object): @@ -62,7 +72,11 @@ def __hash__(self): return hash(type(self)) -def remove_objects_from_args(args, kwargs, pvalue_class): +def remove_objects_from_args(args, # type: Iterable[Any] + kwargs, # type: Dict[str, Any] + pvalue_class # type: Union[Type[T], Tuple[Type[T], ...]] + ): + # type: (...) 
-> Tuple[List[Any], Dict[str, Any], List[T]] """For internal use only; no backwards-compatibility guarantees. Replaces all objects of a given type in args/kwargs with a placeholder. diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 5ea9b9b0238a..1b277821a381 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -24,6 +24,7 @@ import tempfile import unittest from builtins import range +from typing import List import sys # patches unittest.TestCase to be python3 compatible @@ -90,7 +91,7 @@ class AvroBase(object): - _temp_files = [] + _temp_files = [] # type: List[str] def __init__(self, methodName='runTest'): super(AvroBase, self).__init__(methodName) diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py index ec16b060c329..0f87d25cca74 100644 --- a/sdks/python/apache_beam/io/filebasedsource.py +++ b/sdks/python/apache_beam/io/filebasedsource.py @@ -28,6 +28,8 @@ from __future__ import absolute_import +from typing import Callable + from past.builtins import long from past.builtins import unicode @@ -71,7 +73,7 @@ def __init__(self, file_pattern (str): the file glob to read a string or a :class:`~apache_beam.options.value_provider.ValueProvider` (placeholder to inject a runtime value). - min_bundle_size (str): minimum size of bundles that should be generated + min_bundle_size (int): minimum size of bundles that should be generated when performing initial splitting on this source. compression_type (str): Used to handle compressed output files. Typical value is :attr:`CompressionTypes.AUTO @@ -128,6 +130,7 @@ def display_data(self): @check_accessible(['_pattern']) def _get_concat_source(self): + # type: () -> concat_source.ConcatSource if self._concat_source is None: pattern = self._pattern.get() @@ -358,6 +361,7 @@ def process(self, element, *args, **kwargs): class _ReadRange(DoFn): def __init__(self, source_from_file): + # type: (Callable[[str], iobase.BoundedSource]) -> None self._source_from_file = source_from_file def process(self, element, *args, **kwargs): @@ -380,9 +384,13 @@ class ReadAllFiles(PTransform): read a PCollection of files. """ - def __init__( - self, splittable, compression_type, desired_bundle_size, min_bundle_size, - source_from_file): + def __init__(self, + splittable, # type: bool + compression_type, + desired_bundle_size, # type: int + min_bundle_size, # type: int + source_from_file, # type: Callable[[str], iobase.BoundedSource] + ): """ Args: splittable: If False, files won't be split into sub-ranges. 
If True, diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index a1a7a589ca84..748cc2f84850 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -94,6 +94,13 @@ import logging import random import uuid +from typing import TYPE_CHECKING +from typing import Any +from typing import BinaryIO +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Tuple from past.builtins import unicode @@ -107,6 +114,9 @@ from apache_beam.transforms.window import GlobalWindow from apache_beam.utils.annotations import experimental +if TYPE_CHECKING: + from apache_beam.transforms.window import BoundedWindow + __all__ = ['EmptyMatchTreatment', 'MatchFiles', 'MatchAll', @@ -261,6 +271,7 @@ class FileSink(object): """ def open(self, fh): + # type: (BinaryIO) -> None raise NotImplementedError def write(self, record): @@ -443,6 +454,7 @@ def __init__(self, @staticmethod def _get_sink_fn(input_sink): + # type: (...) -> Callable[[Any], FileSink] if isinstance(input_sink, FileSink): return lambda x: input_sink elif callable(input_sink): @@ -452,6 +464,7 @@ def _get_sink_fn(input_sink): @staticmethod def _get_destination_fn(destination): + # type: (...) -> Callable[[Any], str] if isinstance(destination, ValueProvider): return lambda elm: destination.get() elif callable(destination): @@ -592,7 +605,11 @@ def _remove_temporary_files(self, writer_key): class _WriteShardedRecordsFn(beam.DoFn): - def __init__(self, base_path, sink_fn, shards): + def __init__(self, + base_path, + sink_fn, # type: Callable[[Any], FileSink] + shards # type: int + ): self.base_path = base_path self.sink_fn = sink_fn self.shards = shards @@ -630,13 +647,16 @@ def process(self, class _AppendShardedDestination(beam.DoFn): - def __init__(self, destination, shards): + def __init__(self, + destination, # type: Callable[[Any], str] + shards # type: int + ): self.destination_fn = destination self.shards = shards # We start the shards for a single destination at an arbitrary point. self._shard_counter = collections.defaultdict( - lambda: random.randrange(self.shards)) + lambda: random.randrange(self.shards)) # type: DefaultDict[str, int] def _next_shard_for_destination(self, destination): self._shard_counter[destination] = ( @@ -656,6 +676,9 @@ class _WriteUnshardedRecordsFn(beam.DoFn): SPILLED_RECORDS = 'spilled_records' WRITTEN_FILES = 'written_files' + _writers_and_sinks = None # type: Dict[Tuple[str, BoundedWindow], Tuple[BinaryIO, FileSink]] + _file_names = None # type: Dict[Tuple[str, BoundedWindow], str] + def __init__(self, base_path, destination_fn, diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py index c2bc312497f4..24f8d826ad80 100644 --- a/sdks/python/apache_beam/io/filesystem.py +++ b/sdks/python/apache_beam/io/filesystem.py @@ -35,6 +35,8 @@ import zlib from builtins import object from builtins import zip +from typing import BinaryIO +from typing import Tuple from future.utils import with_metaclass from past.builtins import long @@ -478,6 +480,7 @@ def scheme(cls): @abc.abstractmethod def join(self, basepath, *paths): + # type: (str, *str) -> str """Join two or more pathname components for the filesystem Args: @@ -490,6 +493,7 @@ def join(self, basepath, *paths): @abc.abstractmethod def split(self, path): + # type: (str) -> Tuple[str, str] """Splits the given path into two parts. 
Splits the path into a pair (head, tail) such that tail contains the last @@ -717,6 +721,7 @@ def _match(pattern, limit): @abc.abstractmethod def create(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a write channel for the given file path. Args: @@ -731,6 +736,7 @@ def create(self, path, mime_type='application/octet-stream', @abc.abstractmethod def open(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a read channel for the given file path. Args: @@ -771,6 +777,7 @@ def rename(self, source_file_names, destination_file_names): @abc.abstractmethod def exists(self, path): + # type: (str) -> bool """Check if the provided path exists on the FileSystem. Args: @@ -782,6 +789,7 @@ def exists(self, path): @abc.abstractmethod def size(self, path): + # type: (str) -> int """Get size in bytes of a file on the FileSystem. Args: diff --git a/sdks/python/apache_beam/io/filesystems.py b/sdks/python/apache_beam/io/filesystems.py index d8b3a4a54542..e907ffef0921 100644 --- a/sdks/python/apache_beam/io/filesystems.py +++ b/sdks/python/apache_beam/io/filesystems.py @@ -21,6 +21,7 @@ import re from builtins import object +from typing import BinaryIO from past.builtins import unicode @@ -82,6 +83,7 @@ def get_scheme(path): @staticmethod def get_filesystem(path): + # type: (str) -> FileSystems """Get the correct filesystem for the specified path """ try: @@ -105,6 +107,7 @@ def get_filesystem(path): @staticmethod def join(basepath, *paths): + # type: (str, *str) -> str """Join two or more pathname components for the filesystem Args: @@ -189,6 +192,7 @@ def match(patterns, limits=None): @staticmethod def create(path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a write channel for the given file path. Args: @@ -205,6 +209,7 @@ def create(path, mime_type='application/octet-stream', @staticmethod def open(path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a read channel for the given file path. 
Args: diff --git a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py index 1edf743408a0..a151409df101 100644 --- a/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py +++ b/sdks/python/apache_beam/io/flink/flink_streaming_impulse_source.py @@ -23,6 +23,8 @@ from __future__ import absolute_import import json +from typing import Any +from typing import Dict from apache_beam import PTransform from apache_beam import Windowing @@ -33,7 +35,7 @@ class FlinkStreamingImpulseSource(PTransform): URN = "flink:transform:streaming_impulse:v1" - config = {} + config = {} # type: Dict[str, Any] def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin), ( diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py index 6bb9c9818f27..6243c4dca0a4 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py @@ -21,6 +21,7 @@ import sys import unittest +from typing import Type # patches unittest.TestCase to be python3 compatible import future.tests.base # pylint: disable=unused-import @@ -65,7 +66,7 @@ def create_query(self, kinds=(), order=False, limit=None, offset=None, test_filter.property_filter.op = PropertyFilter.GREATER_THAN return query - split_error = ValueError + split_error = ValueError # type: Type[Exception] query_splitter = query_splitter def test_get_splits_query_with_multiple_kinds(self): diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py index a5e9ce3f8700..f29851a9207c 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py @@ -28,6 +28,8 @@ import time import uuid from builtins import range +from typing import List +from typing import Union from google.api_core import exceptions from google.cloud import environment_vars @@ -115,7 +117,7 @@ def write_mutations(batch, throttler, rpc_stats_callback, throttle_delay=1): def create_entities(count, id_or_name=False): """Creates a list of entities with random keys.""" if id_or_name: - ids_or_names = [uuid.uuid4().int & ((1 << 63) - 1) for _ in range(count)] + ids_or_names = [uuid.uuid4().int & ((1 << 63) - 1) for _ in range(count)] # type: List[Union[str, int]] else: ids_or_names = [str(uuid.uuid4()) for _ in range(count)] diff --git a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py index a6081c8e406f..4f4abc6d1771 100644 --- a/sdks/python/apache_beam/io/gcp/gcsfilesystem.py +++ b/sdks/python/apache_beam/io/gcp/gcsfilesystem.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from builtins import zip +from typing import BinaryIO from future.utils import iteritems @@ -139,6 +140,7 @@ def _path_open(self, path, mode, mime_type='application/octet-stream', def create(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a write channel for the given file path. Args: @@ -152,6 +154,7 @@ def create(self, path, mime_type='application/octet-stream', def open(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a read channel for the given file path. 
Args: diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index 0711b70da512..b2db2ebb6695 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -26,6 +26,8 @@ import re from builtins import object +from typing import Any +from typing import Optional from future.utils import iteritems from past.builtins import unicode @@ -87,6 +89,7 @@ def __repr__(self): @staticmethod def _from_proto_str(proto_msg): + # type: (bytes) -> PubsubMessage """Construct from serialized form of ``PubsubMessage``. Args: @@ -121,6 +124,7 @@ def _to_proto_str(self): @staticmethod def _from_message(msg): + # type: (Any) -> PubsubMessage """Construct from ``google.cloud.pubsub_v1.subscriber.message.Message``. https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html @@ -134,8 +138,14 @@ class ReadFromPubSub(PTransform): """A ``PTransform`` for reading from Cloud Pub/Sub.""" # Implementation note: This ``PTransform`` is overridden by Directrunner. - def __init__(self, topic=None, subscription=None, id_label=None, - with_attributes=False, timestamp_attribute=None): + def __init__(self, + topic=None, # type: Optional[str] + subscription=None, # type: Optional[str] + id_label=None, # type: Optional[str] + with_attributes=False, # type: bool + timestamp_attribute=None # type: Optional[str] + ): + # type: (...) -> None """Initializes ``ReadFromPubSub``. Args: @@ -242,8 +252,13 @@ class WriteToPubSub(PTransform): """A ``PTransform`` for writing messages to Cloud Pub/Sub.""" # Implementation note: This ``PTransform`` is overridden by Directrunner. - def __init__(self, topic, with_attributes=False, id_label=None, - timestamp_attribute=None): + def __init__(self, + topic, # type: str + with_attributes=False, # type: bool + id_label=None, # type: Optional[str] + timestamp_attribute=None # type: Optional[str] + ): + # type: (...) -> None """Initializes ``WriteToPubSub``. Args: @@ -267,6 +282,7 @@ def __init__(self, topic, with_attributes=False, id_label=None, @staticmethod def to_proto_str(element): + # type: (PubsubMessage) -> bytes if not isinstance(element, PubsubMessage): raise TypeError('Unexpected element. Type: %s (expected: PubsubMessage), ' 'value: %r' % (type(element), element)) @@ -327,8 +343,13 @@ class _PubSubSource(dataflow_io.NativeSource): fetches ``PubsubMessage`` protobufs. """ - def __init__(self, topic=None, subscription=None, id_label=None, - with_attributes=False, timestamp_attribute=None): + def __init__(self, + topic=None, # type: Optional[str] + subscription=None, # type: Optional[str] + id_label=None, # type: Optional[str] + with_attributes=False, # type: bool + timestamp_attribute=None # type: Optional[str] + ): self.coder = coders.BytesCoder() self.full_topic = topic self.full_subscription = subscription @@ -385,7 +406,12 @@ class _PubSubSink(dataflow_io.NativeSink): This ``NativeSource`` is overridden by a native Pubsub implementation. 
""" - def __init__(self, topic, id_label, with_attributes, timestamp_attribute): + def __init__(self, + topic, # type: str + id_label, # type: Optional[str] + with_attributes, # type: bool + timestamp_attribute # type: Optional[str] + ): self.coder = coders.BytesCoder() self.full_topic = topic self.id_label = id_label diff --git a/sdks/python/apache_beam/io/hadoopfilesystem.py b/sdks/python/apache_beam/io/hadoopfilesystem.py index 71d74e893497..efab49df31b9 100644 --- a/sdks/python/apache_beam/io/hadoopfilesystem.py +++ b/sdks/python/apache_beam/io/hadoopfilesystem.py @@ -25,6 +25,7 @@ import posixpath import re from builtins import zip +from typing import BinaryIO import hdfs @@ -207,6 +208,7 @@ def _add_compression(stream, path, mime_type, compression_type): def create(self, url, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """ Returns: A Python File-like object. @@ -224,6 +226,7 @@ def _create(self, path, mime_type='application/octet-stream', def open(self, url, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """ Returns: A Python File-like object. @@ -314,6 +317,7 @@ def rename(self, source_file_names, destination_file_names): raise BeamIOError('Rename operation failed', exceptions) def exists(self, url): + # type: (str) -> bool """Checks existence of url in HDFS. Args: diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 5b6673089fac..259cee85bd73 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -39,6 +39,12 @@ from builtins import object from builtins import range from collections import namedtuple +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import Tuple from apache_beam import coders from apache_beam import pvalue @@ -56,6 +62,11 @@ from apache_beam.utils import urns from apache_beam.utils.windowed_value import WindowedValue +if TYPE_CHECKING: + from apache_beam.io import restriction_trackers + from apache_beam.runners.pipeline_context import PipelineContext + from apache_beam.utils.timestamp import Timestamp + __all__ = ['BoundedSource', 'RangeTracker', 'Read', 'RestrictionTracker', 'Sink', 'Write', 'Writer'] @@ -86,6 +97,10 @@ class SourceBase(HasDisplayData, urns.RunnerApiFn): """ urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE) + def is_bounded(self): + # type: () -> bool + raise NotImplementedError + class BoundedSource(SourceBase): """A source that reads a finite amount of input records. @@ -124,6 +139,7 @@ class BoundedSource(SourceBase): """ def estimate_size(self): + # type: () -> Optional[int] """Estimates the size of source in bytes. An estimate of the total size (in bytes) of the data that would be read @@ -136,7 +152,12 @@ def estimate_size(self): """ raise NotImplementedError - def split(self, desired_bundle_size, start_position=None, stop_position=None): + def split(self, + desired_bundle_size, # type: int + start_position=None, # type: Optional[int] + stop_position=None, # type: Optional[int] + ): + # type: (...) -> Iterator[SourceBundle] """Splits the source into a set of bundles. Bundles should be approximately of size ``desired_bundle_size`` bytes. 
@@ -153,7 +174,11 @@ def split(self, desired_bundle_size, start_position=None, stop_position=None): """ raise NotImplementedError - def get_range_tracker(self, start_position, stop_position): + def get_range_tracker(self, + start_position, # type: Optional[int] + stop_position, # type: Optional[int] + ): + # type: (...) -> RangeTracker """Returns a RangeTracker for a given position range. Framework may invoke ``read()`` method with the RangeTracker object returned @@ -837,6 +862,7 @@ class Read(ptransform.PTransform): """A transform that reads a PCollection.""" def __init__(self, source): + # type: (SourceBase) -> None """Initializes a Read transform. Args: @@ -884,9 +910,11 @@ def split_source(unused_impulse): is_bounded=self.source.is_bounded()) def get_windowing(self, unused_inputs): + # type: (...) -> core.Windowing return core.Windowing(window.GlobalWindows()) def _infer_output_coder(self, input_type=None, input_coder=None): + # type: (...) -> Optional[coders.Coder] if isinstance(self.source, BoundedSource): return self.source.default_output_coder() else: @@ -898,6 +926,7 @@ def display_data(self): 'source_dd': self.source} def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> Tuple[str, beam_runner_api_pb2.ReadPayload] return (common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload( source=self.source.to_runner_api(context), @@ -907,6 +936,7 @@ def to_runner_api_parameter(self, context): @staticmethod def from_runner_api_parameter(parameter, context): + # type: (beam_runner_api_pb2.ReadPayload, PipelineContext) -> Read return Read(SourceBase.from_runner_api(parameter.source, context)) @@ -977,6 +1007,7 @@ class WriteImpl(ptransform.PTransform): """Implements the writing of custom sinks.""" def __init__(self, sink): + # type: (Sink) -> None super(WriteImpl, self).__init__() self.sink = sink @@ -1089,6 +1120,7 @@ def _finalize_write(unused_element, sink, init_result, write_results, class _RoundRobinKeyFn(core.DoFn): def __init__(self, count): + # type: (int) -> None self.count = count def start_bundle(self): @@ -1134,6 +1166,7 @@ def current_restriction(self): raise NotImplementedError def current_progress(self): + # type: () -> RestrictionProgress """Returns a RestrictionProgress object representing the current progress. """ raise NotImplementedError @@ -1275,6 +1308,7 @@ def defer_remainder(self, watermark=None): raise NotImplementedError def deferred_status(self): + # type: () -> Optional[Tuple[restriction_trackers.OffsetRange, Timestamp]] """ Returns deferred_residual with deferred_watermark. TODO(BEAM-7472): Remove defer_status() once SDF.process() uses @@ -1357,6 +1391,7 @@ def __init__(self, restriction): self._weight = restriction.weight def current_progress(self): + # type: () -> RestrictionProgress return RestrictionProgress( fraction=self._delegate_range_tracker.fraction_consumed()) diff --git a/sdks/python/apache_beam/io/localfilesystem.py b/sdks/python/apache_beam/io/localfilesystem.py index 20748a2ff6be..18d32e423ba7 100644 --- a/sdks/python/apache_beam/io/localfilesystem.py +++ b/sdks/python/apache_beam/io/localfilesystem.py @@ -21,6 +21,7 @@ import os import shutil from builtins import zip +from typing import BinaryIO from apache_beam.io.filesystem import BeamIOError from apache_beam.io.filesystem import CompressedFile @@ -139,6 +140,7 @@ def _path_open(self, path, mode, mime_type='application/octet-stream', def create(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) 
-> BinaryIO """Returns a write channel for the given file path. Args: @@ -152,6 +154,7 @@ def create(self, path, mime_type='application/octet-stream', def open(self, path, mime_type='application/octet-stream', compression_type=CompressionTypes.AUTO): + # type: (...) -> BinaryIO """Returns a read channel for the given file path. Args: diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index 0ba5b23550c9..dc20fe9a8ade 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -21,11 +21,17 @@ import threading from builtins import object +from typing import TYPE_CHECKING +from typing import Optional +from typing import Tuple from apache_beam.io.iobase import RestrictionProgress from apache_beam.io.iobase import RestrictionTracker from apache_beam.io.range_trackers import OffsetRangeTracker +if TYPE_CHECKING: + from apache_beam.utils.timestamp import Timestamp + class OffsetRange(object): @@ -67,6 +73,7 @@ def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1): current_split_start = current_split_stop def split_at(self, split_pos): + # type: (...) -> Tuple[OffsetRange, OffsetRange] return OffsetRange(self.start, split_pos), OffsetRange(split_pos, self.stop) def new_tracker(self): @@ -83,12 +90,13 @@ class OffsetRestrictionTracker(RestrictionTracker): """ def __init__(self, offset_range): + # type: (OffsetRange) -> None assert isinstance(offset_range, OffsetRange) self._range = offset_range self._current_position = None self._current_watermark = None self._last_claim_attempt = None - self._deferred_residual = None + self._deferred_residual = None # type: Optional[OffsetRange] self._checkpointed = False self._lock = threading.RLock() @@ -110,6 +118,7 @@ def current_watermark(self): return self._current_watermark def current_progress(self): + # type: () -> RestrictionProgress with self._lock: if self._current_position is None: fraction = 0.0 @@ -185,5 +194,8 @@ def defer_remainder(self, watermark=None): self._deferred_residual = self.checkpoint() def deferred_status(self): + # type: () -> Optional[Tuple[OffsetRange, Timestamp]] if self._deferred_residual: return (self._deferred_residual, self._deferred_watermark) + else: + return None diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 340449f8896e..bf20f2e14d7e 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -24,6 +24,7 @@ from builtins import object from builtins import range from functools import partial +from typing import Optional from past.builtins import long @@ -92,7 +93,7 @@ def __init__(self, min_bundle_size, compression_type, strip_trailing_newlines, - coder, + coder, # type: coders.Coder buffer_size=DEFAULT_READ_BUFFER_SIZE, validate=True, skip_header_lines=0, @@ -338,7 +339,7 @@ def __init__(self, append_trailing_newlines=True, num_shards=0, shard_name_template=None, - coder=coders.ToStringCoder(), + coder=coders.ToStringCoder(), # type: coders.Coder compression_type=CompressionTypes.AUTO, header=None): """Initialize a _TextSink. @@ -440,7 +441,7 @@ def __init__( desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE, compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, - coder=coders.StrUtf8Coder(), + coder=coders.StrUtf8Coder(), # type: coders.Coder skip_header_lines=0, **kwargs): """Initialize the ``ReadAllFromText`` transform. 
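The textio.py changes above, like the restriction_trackers.py ones before them, hang trailing type comments on defaults and attributes whose initial value (None, an empty container, or a concrete default coder) does not by itself pin down the intended type. A small sketch of the same idiom, with hypothetical names that are not taken from Beam:

# Illustrative only: RetryState is a hypothetical class, not a Beam API.
from typing import Dict, List, Optional


class RetryState(object):
    def __init__(self, label, max_attempts=3):
        # type: (str, int) -> None
        self.label = label
        self.max_attempts = max_attempts
        # A bare None or an empty container reveals nothing about the intended
        # contents, so the trailing comment supplies the element types.
        self.last_error = None  # type: Optional[str]
        self.delays = []  # type: List[float]
        self.labels = {}  # type: Dict[str, str]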
@@ -501,7 +502,7 @@ def __init__( min_bundle_size=0, compression_type=CompressionTypes.AUTO, strip_trailing_newlines=True, - coder=coders.StrUtf8Coder(), + coder=coders.StrUtf8Coder(), # type: coders.Coder validate=True, skip_header_lines=0, **kwargs): @@ -556,12 +557,12 @@ class WriteToText(PTransform): def __init__( self, - file_path_prefix, + file_path_prefix, # type: str file_name_suffix='', append_trailing_newlines=True, num_shards=0, - shard_name_template=None, - coder=coders.ToStringCoder(), + shard_name_template=None, # type: Optional[str] + coder=coders.ToStringCoder(), # type: coders.Coder compression_type=CompressionTypes.AUTO, header=None): r"""Initialize a :class:`WriteToText` transform. diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 6dbc1af48c4c..cf7f07f8ce97 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -391,6 +391,7 @@ def singleton(value, timestamp=None): return GaugeData(value, timestamp=timestamp) def to_runner_api(self): + # type: () -> beam_fn_api_pb2.Metrics.User.GaugeData seconds = int(self.timestamp) nanos = int((self.timestamp - seconds) * 10**9) gauge_timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) @@ -399,6 +400,7 @@ def to_runner_api(self): @staticmethod def from_runner_api(proto): + # type: (beam_fn_api_pb2.Metrics.User.GaugeData) -> GaugeData gauge_timestamp = (proto.timestamp.seconds + float(proto.timestamp.nanos) / 10**9) return GaugeData(proto.value, timestamp=gauge_timestamp) @@ -470,11 +472,13 @@ def singleton(value): return DistributionData(value, 1, value, value) def to_runner_api(self): + # type: () -> beam_fn_api_pb2.Metrics.User.DistributionData return beam_fn_api_pb2.Metrics.User.DistributionData( count=self.count, sum=self.sum, min=self.min, max=self.max) @staticmethod def from_runner_api(proto): + # type: (beam_fn_api_pb2.Metrics.User.DistributionData) -> DistributionData return DistributionData(proto.sum, proto.count, proto.min, proto.max) def to_runner_api_monitoring_info(self): diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index 0e73461784c9..16a3b2a5e320 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -23,6 +23,9 @@ import collections import time from functools import reduce +from typing import FrozenSet +from typing import Hashable +from typing import List from google.protobuf import timestamp_pb2 @@ -129,6 +132,7 @@ def create_labels(ptransform=None, tag=None, namespace=None, name=None): def int64_user_counter(namespace, name, metric, ptransform=None, tag=None): + # type: (...) -> metrics_pb2.MonitoringInfo """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -151,6 +155,7 @@ def int64_user_counter(namespace, name, metric, ptransform=None, tag=None): def int64_counter(urn, metric, ptransform=None, tag=None): + # type: (...) -> metrics_pb2.MonitoringInfo """Return the counter monitoring info for the specifed URN, metric and labels. Args: @@ -187,6 +192,7 @@ def int64_user_distribution(namespace, name, metric, ptransform=None, tag=None): def int64_distribution(urn, metric, ptransform=None, tag=None): + # type: (...) -> metrics_pb2.MonitoringInfo """Return a distribution monitoring info for the URN, metric and labels. 
Args: @@ -201,6 +207,7 @@ def int64_distribution(urn, metric, ptransform=None, tag=None): def int64_user_gauge(namespace, name, metric, ptransform=None, tag=None): + # type: (...) -> metrics_pb2.MonitoringInfo """Return the gauge monitoring info for the URN, metric and labels. Args: @@ -218,6 +225,7 @@ def int64_user_gauge(namespace, name, metric, ptransform=None, tag=None): def create_monitoring_info(urn, type_urn, metric_proto, labels=None): + # type: (...) -> metrics_pb2.MonitoringInfo """Return the gauge monitoring info for the URN, type, metric and labels. Args: @@ -301,11 +309,12 @@ def parse_namespace_and_name(monitoring_info_proto): def to_key(monitoring_info_proto): + # type: (metrics_pb2.MonitoringInfo) -> FrozenSet[Hashable] """Returns a key based on the URN and labels. This is useful in maps to prevent reporting the same MonitoringInfo twice. """ - key_items = list(monitoring_info_proto.labels.items()) + key_items = list(monitoring_info_proto.labels.items()) # type: List[Hashable] key_items.append(monitoring_info_proto.urn) return frozenset(key_items) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 02dda26422e0..45dd3b19cd54 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -24,6 +24,13 @@ import logging from builtins import list from builtins import object +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Type +from typing import TypeVar from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.options.value_provider import StaticValueProvider @@ -44,6 +51,8 @@ 'TestOptions', ] +PipelineOptionsT = TypeVar('PipelineOptionsT', bound='PipelineOptions') + def _static_value_provider_of(value_type): """"Helper function to plug a ValueProvider into argparse. @@ -154,7 +163,9 @@ def _add_argparse_args(cls, parser): By default the options classes will use command line arguments to initialize the options. """ - def __init__(self, flags=None, **kwargs): + def __init__(self, + flags=None, # type: Optional[List[str]] + **kwargs): """Initialize an options class. The initializer will traverse all subclasses, add all their argparse @@ -205,6 +216,7 @@ def __init__(self, flags=None, **kwargs): @classmethod def _add_argparse_args(cls, parser): + # type: (_BeamArgumentParser) -> None # Override this in subclasses to provide options. pass @@ -231,7 +243,11 @@ def from_dictionary(cls, options): return cls(flags) - def get_all_options(self, drop_default=False, add_extra_args_fn=None): + def get_all_options(self, + drop_default=False, + add_extra_args_fn=None # type: Optional[Callable[[_BeamArgumentParser], None]] + ): + # type: (...) -> Dict[str, Any] """Returns a dictionary of all defined arguments. Returns a dictionary of all defined arguments (arguments that are defined in @@ -277,6 +293,7 @@ def display_data(self): return self.get_all_options(True) def view_as(self, cls): + # type: (Type[PipelineOptionsT]) -> PipelineOptionsT """Returns a view of current object as provided PipelineOption subclass. 
Example Usage:: @@ -315,10 +332,12 @@ def view_as(self, cls): return view def _visible_option_list(self): + # type: () -> List[str] return sorted(option for option in dir(self._visible_options) if option[0] != '_') def __dir__(self): + # type: () -> List[str] return sorted(dir(type(self)) + list(self.__dict__) + self._visible_option_list()) @@ -918,7 +937,7 @@ class OptionsContext(object): Can also be used as a decorator. """ - overrides = [] + overrides = [] # type: List[Dict[str, Any]] def __init__(self, **options): self.options = options diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index ca6536d73a67..380dd7b8efc8 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -23,6 +23,7 @@ from builtins import object from functools import wraps +from typing import Set from apache_beam import error @@ -79,7 +80,7 @@ def __hash__(self): class RuntimeValueProvider(ValueProvider): runtime_options = None - experiments = set() + experiments = set() # type: Set[str] def __init__(self, option_name, value_type, default_value): self.option_name = option_name diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index cad699c87d49..9d65a68b41b7 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -54,6 +54,15 @@ import tempfile from builtins import object from builtins import zip +from typing import TYPE_CHECKING +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Union from future.utils import with_metaclass @@ -77,6 +86,11 @@ from apache_beam.typehints import typehints from apache_beam.utils.annotations import deprecated +if TYPE_CHECKING: + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.pipeline_context import PipelineContext + from apache_beam.runners.runner import PipelineResult + __all__ = ['Pipeline', 'PTransformOverride'] @@ -95,7 +109,11 @@ class Pipeline(object): (e.g. ``input | "label" >> my_tranform``). """ - def __init__(self, runner=None, options=None, argv=None): + def __init__(self, + runner=None, # type: Optional[Union[str, PipelineRunner]] + options=None, # type: Optional[PipelineOptions] + argv=None # type: Optional[List[str]] + ): """Initialize a pipeline object. Args: @@ -170,7 +188,7 @@ def __init__(self, runner=None, options=None, argv=None): # Set of transform labels (full labels) applied to the pipeline. # If a transform is applied and the full label is already in the set # then the transform will have to be cloned with a new label. 
- self.applied_labels = set() + self.applied_labels = set() # type: Set[str] @property # type: ignore[misc] # decorated property not supported @deprecated(since='First stable release', @@ -180,14 +198,17 @@ def options(self): return self._options def _current_transform(self): + # type: () -> AppliedPTransform """Returns the transform currently on the top of the stack.""" return self.transforms_stack[-1] def _root_transform(self): + # type: () -> AppliedPTransform """Returns the root transform of the transform stack.""" return self.transforms_stack[0] def _remove_labels_recursively(self, applied_transform): + # type: (AppliedPTransform) -> None for part in applied_transform.parts: if part.full_label in self.applied_labels: self.applied_labels.remove(part.full_label) @@ -206,6 +227,7 @@ class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignmen """"A visitor that replaces the matching PTransforms.""" def __init__(self, pipeline): + # type: (Pipeline) -> None self.pipeline = pipeline def _replace_if_needed(self, original_transform_node): @@ -292,9 +314,11 @@ def _replace_if_needed(self, original_transform_node): self.pipeline.transforms_stack.pop() def enter_composite_transform(self, transform_node): + # type: (AppliedPTransform) -> None self._replace_if_needed(transform_node) def visit_transform(self, transform_node): + # type: (AppliedPTransform) -> None self._replace_if_needed(transform_node) self.visit(TransformUpdater(self)) @@ -311,12 +335,15 @@ class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignm """ def __init__(self, pipeline): + # type: (Pipeline) -> None self.pipeline = pipeline def enter_composite_transform(self, transform_node): + # type: (AppliedPTransform) -> None self.visit_transform(transform_node) def visit_transform(self, transform_node): + # type: (AppliedPTransform) -> None if (None in transform_node.outputs and transform_node.outputs[None] in output_map): output_replacements[transform_node] = ( @@ -372,6 +399,7 @@ def visit_transform(self, transform_node): self.visit(ReplacementValidator()) def replace_all(self, replacements): + # type: (Iterable[PTransformOverride]) -> None """ Dynamically replaces PTransforms in the currently populated hierarchy. Currently this only works for replacements where input and output types @@ -396,6 +424,7 @@ def replace_all(self, replacements): self._check_replacement(override) def run(self, test_runner_api=True): + # type: (...) -> PipelineResult """Runs the pipeline. Returns whatever our runner returns after running.""" # When possible, invoke a round trip through the runner API. @@ -426,6 +455,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.run().wait_until_finish() def visit(self, visitor): + # type: (PipelineVisitor) -> None """Visits depth-first every node of a pipeline's DAG. Runner-internal implementation detail; no backwards-compatibility guarantees @@ -443,7 +473,7 @@ def visit(self, visitor): belong to this pipeline instance. 
""" - visited = set() + visited = set() # type: Set[pvalue.PValue] self._root_transform().visit(visitor, self, visited) def apply(self, transform, pvalueish=None, label=None): @@ -606,9 +636,13 @@ def visit_value(self, value, _): self.visit(Visitor()) return Visitor.ok - def to_runner_api( - self, return_context=False, context=None, use_fake_coders=False, - default_environment=None): + def to_runner_api(self, + return_context=False, + context=None, # type: Optional[PipelineContext] + use_fake_coders=False, + default_environment=None # type: Optional[beam_runner_api_pb2.Environment] + ): + # type: (...) -> beam_runner_api_pb2.Pipeline """For internal use only; no backwards-compatibility guarantees.""" from apache_beam.runners import pipeline_context from apache_beam.portability.api import beam_runner_api_pb2 @@ -661,13 +695,18 @@ def visit_transform(self, transform_node): proto.components.transforms[root_transform_id].unique_name = ( root_transform_id) if return_context: - return proto, context + return proto, context # type: ignore # too complicated for now else: return proto @staticmethod - def from_runner_api(proto, runner, options, return_context=False, - allow_proto_holders=False): + def from_runner_api(proto, # type: beam_runner_api_pb2.Pipeline + runner, # type: PipelineRunner + options, # type: PipelineOptions + return_context=False, + allow_proto_holders=False + ): + # type: (...) -> Pipeline """For internal use only; no backwards-compatibility guarantees.""" p = Pipeline(runner=runner, options=options) from apache_beam.runners import pipeline_context @@ -695,7 +734,7 @@ def from_runner_api(proto, runner, options, return_context=False, transform.inputs = (pvalue.PBegin(p),) if return_context: - return p, context + return p, context # type: ignore # too complicated for now else: return p @@ -708,6 +747,7 @@ class PipelineVisitor(object): """ def visit_value(self, value, producer_node): + # type: (pvalue.PValue, AppliedPTransform) -> None """Callback for visiting a PValue in the pipeline DAG. Args: @@ -718,14 +758,17 @@ def visit_value(self, value, producer_node): pass def visit_transform(self, transform_node): + # type: (AppliedPTransform) -> None """Callback for visiting a transform leaf node in the pipeline DAG.""" pass def enter_composite_transform(self, transform_node): + # type: (AppliedPTransform) -> None """Callback for entering traversal of a composite transform node.""" pass def leave_composite_transform(self, transform_node): + # type: (AppliedPTransform) -> None """Callback for leaving traversal of a composite transform node.""" pass @@ -737,7 +780,12 @@ class AppliedPTransform(object): (used internally by Pipeline for bookeeping purposes). """ - def __init__(self, parent, transform, full_label, inputs): + def __init__(self, + parent, + transform, # type: ptransform.PTransform + full_label, # type: str + inputs # type: Optional[Sequence[Union[pvalue.PBegin, pvalue.PCollection]]] + ): self.parent = parent self.transform = transform # Note that we want the PipelineVisitor classes to use the full_label, @@ -747,15 +795,19 @@ def __init__(self, parent, transform, full_label, inputs): # any interference. This is particularly useful for composite transforms. self.full_label = full_label self.inputs = inputs or () - self.side_inputs = () if transform is None else tuple(transform.side_inputs) - self.outputs = {} - self.parts = [] + self.side_inputs = () if transform is None else tuple(transform.side_inputs) # type: Tuple[pvalue.AsSideInput, ...] 
+ self.outputs = {} # type: Dict[Union[str, int, None], pvalue.PValue] + self.parts = [] # type: List[AppliedPTransform] def __repr__(self): return "%s(%s, %s)" % (self.__class__.__name__, self.full_label, type(self.transform).__name__) - def replace_output(self, output, tag=None): + def replace_output(self, + output, # type: Union[pvalue.PValue, pvalue.DoOutputsTuple] + tag=None # type: Union[str, int, None] + ): + # type: (...) -> None """Replaces the output defined by the given tag with the given output. Args: @@ -769,7 +821,11 @@ def replace_output(self, output, tag=None): else: raise TypeError("Unexpected output type: %s" % output) - def add_output(self, output, tag=None): + def add_output(self, + output, # type: Union[pvalue.DoOutputsTuple, pvalue.PValue] + tag=None # type: Union[str, int, None] + ): + # type: (...) -> None if isinstance(output, pvalue.DoOutputsTuple): self.add_output(output[output._main_tag]) elif isinstance(output, pvalue.PValue): @@ -782,10 +838,12 @@ def add_output(self, output, tag=None): raise TypeError("Unexpected output type: %s" % output) def add_part(self, part): + # type: (AppliedPTransform) -> None assert isinstance(part, AppliedPTransform) self.parts.append(part) def is_composite(self): + # type: () -> bool """Returns whether this is a composite transform. A composite transform has parts (inner transforms) or isn't the @@ -795,7 +853,12 @@ def is_composite(self): return bool(self.parts) or all( pval.producer is not self for pval in self.outputs.values()) - def visit(self, visitor, pipeline, visited): + def visit(self, + visitor, # type: PipelineVisitor + pipeline, # type: Pipeline + visited # type: Set[pvalue.PValue] + ): + # type: (...) -> None """Visits all nodes reachable from the current node.""" for in_pval in self.inputs: @@ -807,7 +870,8 @@ def visit(self, visitor, pipeline, visited): # Visit side inputs. for side_input in self.side_inputs: - if isinstance(side_input, pvalue.AsSideInput) and side_input.pvalue not in visited: + if isinstance(side_input, pvalue.AsSideInput) \ + and side_input.pvalue not in visited: pval = side_input.pvalue # Unpack marker-object-wrapped pvalue. if pval.producer is not None: pval.producer.visit(visitor, pipeline, visited) @@ -844,6 +908,7 @@ def visit(self, visitor, pipeline, visited): visitor.visit_value(v, self) def named_inputs(self): + # type: () -> Dict[str, pvalue.PCollection] # TODO(BEAM-1833): Push names up into the sdk construction. main_inputs = {str(ix): input for ix, input in enumerate(self.inputs) @@ -853,10 +918,12 @@ def named_inputs(self): return dict(main_inputs, **side_inputs) def named_outputs(self): + # type: () -> Dict[str, pvalue.PCollection] return {str(tag): output for tag, output in self.outputs.items() if isinstance(output, pvalue.PCollection)} def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.PTransform # External tranforms require more splicing than just setting the spec. from apache_beam.transforms import external if isinstance(self.transform, external.ExternalTransform): @@ -864,7 +931,10 @@ def to_runner_api(self, context): from apache_beam.portability.api import beam_runner_api_pb2 - def transform_to_runner_api(transform, context): + def transform_to_runner_api(transform, # type: Optional[ptransform.PTransform] + context # type: PipelineContext + ): + # type: (...) 
-> Optional[beam_runner_api_pb2.FunctionSpec] if transform is None: return None else: @@ -884,7 +954,10 @@ def transform_to_runner_api(transform, context): display_data=None) @staticmethod - def from_runner_api(proto, context): + def from_runner_api(proto, # type: beam_runner_api_pb2.PTransform + context # type: PipelineContext + ): + # type: (...) -> AppliedPTransform def is_side_input(tag): # As per named_inputs() above. return tag.startswith('side') @@ -938,6 +1011,7 @@ class PTransformOverride(with_metaclass(abc.ABCMeta, object)): @abc.abstractmethod def matches(self, applied_ptransform): + # type: (AppliedPTransform) -> bool """Determines whether the given AppliedPTransform matches. Note that the matching will happen *after* Runner API proto translation. @@ -957,6 +1031,7 @@ def matches(self, applied_ptransform): @abc.abstractmethod def get_replacement_transform(self, ptransform): + # type: (AppliedPTransform) -> AppliedPTransform """Provides a runner specific override for a given PTransform. Args: diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 9916877b0209..23397f4ba944 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -28,9 +28,17 @@ import collections import itertools -import typing from builtins import hex from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import Optional +from typing import Sequence +from typing import TypeVar +from typing import Union from past.builtins import unicode @@ -40,6 +48,14 @@ from apache_beam.portability import python_urns from apache_beam.portability.api import beam_runner_api_pb2 +if TYPE_CHECKING: + from apache_beam.transforms import sideinputs + from apache_beam.transforms.core import ParDo + from apache_beam.transforms.core import Windowing + from apache_beam.pipeline import AppliedPTransform + from apache_beam.pipeline import Pipeline + from apache_beam.runners.pipeline_context import PipelineContext + __all__ = [ 'PCollection', 'TaggedOutput', @@ -50,6 +66,8 @@ 'EmptySideInput', ] +T = TypeVar('T') + class PValue(object): """Base class for PCollection. @@ -63,8 +81,13 @@ class PValue(object): (3) Has a value which is meaningful if the transform was executed. """ - def __init__(self, pipeline, tag=None, element_type=None, windowing=None, - is_bounded=True): + def __init__(self, + pipeline, # type: Pipeline + tag=None, # type: Optional[str] + element_type=None, # type: Optional[type] + windowing=None, # type: Optional[Windowing] + is_bounded=True, + ): """Initializes a PValue with all arguments hidden behind keyword arguments. Args: @@ -78,7 +101,7 @@ def __init__(self, pipeline, tag=None, element_type=None, windowing=None, # The AppliedPTransform instance for the application of the PTransform # generating this PValue. The field gets initialized when a transform # gets applied. - self.producer = None + self.producer = None # type: Optional[AppliedPTransform] self.is_bounded = is_bounded if windowing: self._windowing = windowing @@ -113,7 +136,7 @@ def __or__(self, ptransform): return self.pipeline.apply(ptransform, self) -class PCollection(PValue, typing.Generic[typing.TypeVar('T')]): +class PCollection(PValue, Generic[T]): """A multiple values (potentially huge) container. 
Dataflow users should not construct PCollection objects directly in their @@ -133,6 +156,7 @@ def __hash__(self): @property def windowing(self): + # type: () -> Windowing if not hasattr(self, '_windowing'): self._windowing = self.producer.transform.get_windowing( self.producer.inputs) @@ -146,6 +170,7 @@ def __reduce_ex__(self, unused_version): @staticmethod def from_(pcoll): + # type: (PValue) -> PCollection """Create a PCollection, using another PCollection as a starting point. Transfers relevant attributes. @@ -153,6 +178,7 @@ def from_(pcoll): return PCollection(pcoll.pipeline, is_bounded=pcoll.is_bounded) def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.PCollection return beam_runner_api_pb2.PCollection( unique_name=self._unique_name(), coder_id=context.coder_id_from_element_type(self.element_type), @@ -163,6 +189,7 @@ def to_runner_api(self, context): self.windowing)) def _unique_name(self): + # type: () -> str if self.producer: return '%d%s.%s' % ( len(self.producer.full_label), self.producer.full_label, self.tag) @@ -171,6 +198,7 @@ def _unique_name(self): @staticmethod def from_runner_api(proto, context): + # type: (beam_runner_api_pb2.PCollection, PipelineContext) -> PCollection # Producer and tag will be filled in later, the key point is that the # same object is returned for the same pcollection id. return PCollection( @@ -204,7 +232,12 @@ class PDone(PValue): class DoOutputsTuple(object): """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" - def __init__(self, pipeline, transform, tags, main_tag): + def __init__(self, + pipeline, # type: Pipeline + transform, # type: ParDo + tags, # type: Sequence[str] + main_tag # type: Optional[str] + ): self._pipeline = pipeline self._tags = tags self._main_tag = main_tag @@ -212,9 +245,9 @@ def __init__(self, pipeline, transform, tags, main_tag): # The ApplyPTransform instance for the application of the multi FlatMap # generating this value. The field gets initialized when a transform # gets applied. - self.producer = None + self.producer = None # type: Optional[AppliedPTransform] # Dictionary of PCollections already associated with tags. - self._pcolls = {} + self._pcolls = {} # type: Dict[Optional[str], PValue] def __str__(self): return '<%s>' % self._str_internal() @@ -227,6 +260,7 @@ def _str_internal(self): self.__class__.__name__, self._main_tag, self._tags, self._transform) def __iter__(self): + # type: () -> Iterator[PValue] """Iterates over tags returning for each call a (tag, pvalue) pair.""" if self._main_tag is not None: yield self[self._main_tag] @@ -234,6 +268,7 @@ def __iter__(self): yield self[tag] def __getattr__(self, tag): + # type: (str) -> PValue # Special methods which may be accessed before the object is # fully constructed (e.g. in unpickling). if tag[:2] == tag[-2:] == '__': @@ -241,6 +276,7 @@ def __getattr__(self, tag): return self[tag] def __getitem__(self, tag): + # type: (Union[int, str, None]) -> PValue # Accept int tags so that we can look at Partition tags with the # same ints that we used in the partition function. # TODO(gildea): Consider requiring string-based tags everywhere. 
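pvalue.py above, like pipeline.py, iobase.py, and runners/common.py earlier in the patch, moves imports that are needed only by type comments under an if TYPE_CHECKING: block, so they add no runtime import cost and cannot create circular imports. A minimal sketch of the pattern, using a stand-in import rather than a Beam module:

# Illustrative only: OrderedDict stands in for any annotation-only import.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never at runtime, which is what
    # makes the guard safe even when the import would otherwise be circular.
    from collections import OrderedDict


def first_key(mapping):
    # type: (OrderedDict) -> object
    """Returns the first key of an ordered mapping."""
    return next(iter(mapping))

This is the same guard that lets pvalue.py mention AppliedPTransform and PipelineContext in type comments without importing apache_beam.pipeline at module load time.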
@@ -258,9 +294,10 @@ def __getitem__(self, tag): if tag in self._pcolls: return self._pcolls[tag] + assert self.producer is not None if tag is not None: self._transform.output_tags.add(tag) - pcoll = PCollection(self._pipeline, tag=tag) + pcoll = PCollection(self._pipeline, tag=tag) # type: PValue # Transfer the producer from the DoOutputsTuple to the resulting # PCollection. pcoll.producer = self.producer.parts[0] @@ -286,6 +323,7 @@ class TaggedOutput(object): """ def __init__(self, tag, value): + # type: (str, Any) -> None if not isinstance(tag, (str, unicode)): raise TypeError( 'Attempting to create a TaggedOutput with non-string tag %s' % (tag,)) @@ -305,6 +343,7 @@ class AsSideInput(object): """ def __init__(self, pcoll): + # type: (PCollection) -> None from apache_beam.transforms import sideinputs self.pvalue = pcoll self._window_mapping_fn = sideinputs.default_window_mapping_fn( @@ -327,6 +366,7 @@ def element_type(self): # TODO(robertwb): Get rid of _from_runtime_iterable and _view_options # in favor of _side_input_data(). def _side_input_data(self): + # type: () -> SideInputData view_options = self._view_options() from_runtime_iterable = type(self)._from_runtime_iterable return SideInputData( @@ -335,19 +375,28 @@ def _side_input_data(self): lambda iterable: from_runtime_iterable(iterable, view_options)) def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.SideInput return self._side_input_data().to_runner_api(context) @staticmethod - def from_runner_api(proto, context): + def from_runner_api(proto, # type: beam_runner_api_pb2.SideInput + context # type: PipelineContext + ): + # type: (...) -> _UnpickledSideInput return _UnpickledSideInput( SideInputData.from_runner_api(proto, context)) + @staticmethod + def _from_runtime_iterable(it, options): + raise NotImplementedError + def requires_keyed_input(self): return False class _UnpickledSideInput(AsSideInput): def __init__(self, side_input_data): + # type: (SideInputData) -> None self._data = side_input_data self._window_mapping_fn = side_input_data.window_mapping_fn @@ -368,12 +417,17 @@ def _side_input_data(self): class SideInputData(object): """All of the data about a side input except for the bound PCollection.""" - def __init__(self, access_pattern, window_mapping_fn, view_fn): + def __init__(self, + access_pattern, # type: str + window_mapping_fn, # type: sideinputs.WindowMappingFn + view_fn + ): self.access_pattern = access_pattern self.window_mapping_fn = window_mapping_fn self.view_fn = view_fn def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.SideInput return beam_runner_api_pb2.SideInput( access_pattern=beam_runner_api_pb2.FunctionSpec( urn=self.access_pattern), @@ -390,6 +444,7 @@ def to_runner_api(self, context): @staticmethod def from_runner_api(proto, unused_context): + # type: (beam_runner_api_pb2.SideInput, PipelineContext) -> SideInputData assert proto.view_fn.spec.urn == python_urns.PICKLED_VIEWFN assert (proto.window_mapping_fn.spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN) @@ -418,6 +473,7 @@ class AsSingleton(AsSideInput): _NO_DEFAULT = object() def __init__(self, pcoll, default_value=_NO_DEFAULT): + # type: (PCollection, Any) -> None super(AsSingleton, self).__init__(pcoll) self.default_value = default_value @@ -469,6 +525,7 @@ def _from_runtime_iterable(it, options): return it def _side_input_data(self): + # type: () -> SideInputData return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, @@ -499,6 +556,7 @@ 
def _from_runtime_iterable(it, options): return list(it) def _side_input_data(self): + # type: () -> SideInputData return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, @@ -526,6 +584,7 @@ def _from_runtime_iterable(it, options): return dict(it) def _side_input_data(self): + # type: () -> SideInputData return SideInputData( common_urns.side_inputs.ITERABLE.urn, self._window_mapping_fn, @@ -552,6 +611,7 @@ def _from_runtime_iterable(it, options): return result def _side_input_data(self): + # type: () -> SideInputData return SideInputData( common_urns.side_inputs.MULTIMAP.urn, self._window_mapping_fn, diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 541959a5cf91..70486915b2d4 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -28,6 +28,14 @@ from builtins import next from builtins import object from builtins import zip +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple from future.utils import raise_with_traceback from past.builtins import unicode @@ -47,11 +55,17 @@ from apache_beam.utils.timestamp import Timestamp from apache_beam.utils.windowed_value import WindowedValue +if TYPE_CHECKING: + from apache_beam.io import iobase + from apache_beam.transforms import sideinputs + from apache_beam.transforms.core import TimerSpec + class NameContext(object): """Holds the name information for a step.""" def __init__(self, step_name): + # type: (str) -> None """Creates a new step NameContext. Args: @@ -130,6 +144,7 @@ class Receiver(object): """ def receive(self, windowed_value): + # type: (WindowedValue) -> None raise NotImplementedError @@ -160,11 +175,11 @@ def __init__(self, obj_to_invoke, method_name): self.method_value = getattr(obj_to_invoke, method_name) self.has_userstate_arguments = False - self.state_args_to_replace = {} - self.timer_args_to_replace = {} - self.timestamp_arg_name = None - self.window_arg_name = None - self.key_arg_name = None + self.state_args_to_replace = {} # type: Dict[str, core.StateSpec] + self.timer_args_to_replace = {} # type: Dict[str, core.TimerSpec] + self.timestamp_arg_name = None # type: Optional[str] + self.window_arg_name = None # type: Optional[str] + self.key_arg_name = None # type: Optional[str] self.restriction_provider = None self.restriction_provider_arg_name = None @@ -224,6 +239,7 @@ class DoFnSignature(object): """ def __init__(self, do_fn): + # type: (core.DoFn) -> None # We add a property here for all methods defined by Beam DoFn features. assert isinstance(do_fn, core.DoFn) @@ -253,7 +269,7 @@ def __init__(self, do_fn): # Handle stateful DoFns. self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn) - self.timer_methods = {} + self.timer_methods = {} # type: Dict[TimerSpec, MethodWrapper] if self._is_stateful_dofn: # Populate timer firing methods, keyed by TimerSpec. 
_, all_timer_specs = userstate.get_dofn_specs(do_fn) @@ -262,6 +278,7 @@ def __init__(self, do_fn): self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__) def get_restriction_provider(self): + # type: () -> RestrictionProvider return self.process_method.restriction_provider def _validate(self): @@ -293,12 +310,15 @@ def _validate_stateful_dofn(self): userstate.validate_stateful_dofn(self.do_fn) def is_splittable_dofn(self): + # type: () -> bool return self.get_restriction_provider() is not None def is_stateful_dofn(self): + # type: () -> bool return self._is_stateful_dofn def has_timers(self): + # type: () -> bool _, all_timer_specs = userstate.get_dofn_specs(self.do_fn) return bool(all_timer_specs) @@ -309,7 +329,11 @@ class DoFnInvoker(object): A DoFnInvoker describes a particular way for invoking methods of a DoFn represented by a given DoFnSignature.""" - def __init__(self, output_processor, signature): + def __init__(self, + output_processor, # type: Optional[_OutputProcessor] + signature # type: DoFnSignature + ): + # type: (...) -> None """ Initializes `DoFnInvoker` @@ -319,17 +343,21 @@ def __init__(self, output_processor, signature): """ self.output_processor = output_processor self.signature = signature - self.user_state_context = None - self.bundle_finalizer_param = None + self.user_state_context = None # type: Optional[userstate.UserStateContext] + self.bundle_finalizer_param = None # type: Optional[core._BundleFinalizerParam] @staticmethod def create_invoker( - signature, - output_processor=None, - context=None, side_inputs=None, input_args=None, input_kwargs=None, + signature, # type: DoFnSignature + output_processor=None, # type: Optional[_OutputProcessor] + context=None, # type: Optional[DoFnContext] + side_inputs=None, # type: Optional[List[sideinputs.SideInputMap]] + input_args=None, input_kwargs=None, process_invocation=True, - user_state_context=None, - bundle_finalizer_param=None): + user_state_context=None, # type: Optional[userstate.UserStateContext] + bundle_finalizer_param=None # type: Optional[core._BundleFinalizerParam] + ): + # type: (...) -> DoFnInvoker """ Creates a new DoFnInvoker based on given arguments. Args: @@ -362,14 +390,21 @@ def create_invoker( if use_simple_invoker: return SimpleInvoker(output_processor, signature) else: + if context is None: + raise TypeError("Must provide context when not using SimpleInvoker") return PerWindowInvoker( output_processor, signature, context, side_inputs, input_args, input_kwargs, user_state_context, bundle_finalizer_param) - def invoke_process(self, windowed_value, restriction_tracker=None, - output_processor=None, - additional_args=None, additional_kwargs=None): + def invoke_process(self, + windowed_value, # type: WindowedValue + restriction_tracker=None, # type: Optional[iobase.RestrictionTracker] + output_processor=None, # type: Optional[OutputProcessor] + additional_args=None, + additional_kwargs=None + ): + # type: (...) -> Optional[Tuple[WindowedValue, Timestamp]] """Invokes the DoFn.process() function. Args: @@ -385,23 +420,27 @@ def invoke_process(self, windowed_value, restriction_tracker=None, raise NotImplementedError def invoke_setup(self): + # type: () -> None """Invokes the DoFn.setup() method """ self.signature.setup_lifecycle_method.method_value() def invoke_start_bundle(self): + # type: () -> None """Invokes the DoFn.start_bundle() method. 
""" self.output_processor.start_bundle_outputs( self.signature.start_bundle_method.method_value()) def invoke_finish_bundle(self): + # type: () -> None """Invokes the DoFn.finish_bundle() method. """ self.output_processor.finish_bundle_outputs( self.signature.finish_bundle_method.method_value()) def invoke_teardown(self): + # type: () -> None """Invokes the DoFn.teardown() method """ self.signature.teardown_lifecycle_method.method_value() @@ -428,13 +467,22 @@ def invoke_create_tracker(self, restriction): class SimpleInvoker(DoFnInvoker): """An invoker that processes elements ignoring windowing information.""" - def __init__(self, output_processor, signature): + def __init__(self, + output_processor, # type: Optional[_OutputProcessor] + signature # type: DoFnSignature + ): + # type: (...) -> None super(SimpleInvoker, self).__init__(output_processor, signature) self.process_method = signature.process_method.method_value - def invoke_process(self, windowed_value, restriction_tracker=None, - output_processor=None, - additional_args=None, additional_kwargs=None): + def invoke_process(self, + windowed_value, # type: WindowedValue + restriction_tracker=None, # type: Optional[iobase.RestrictionTracker] + output_processor=None, # type: Optional[OutputProcessor] + additional_args=None, + additional_kwargs=None + ): + # type: (...) -> None if not output_processor: output_processor = self.output_processor output_processor.process_outputs( @@ -444,9 +492,16 @@ def invoke_process(self, windowed_value, restriction_tracker=None, class PerWindowInvoker(DoFnInvoker): """An invoker that processes elements considering windowing information.""" - def __init__(self, output_processor, signature, context, - side_inputs, input_args, input_kwargs, user_state_context, - bundle_finalizer_param): + def __init__(self, + output_processor, # type: Optional[_OutputProcessor] + signature, # type: DoFnSignature + context, # type: DoFnContext + side_inputs, # type: Iterable[sideinputs.SideInputMap] + input_args, + input_kwargs, + user_state_context, # type: Optional[userstate.UserStateContext] + bundle_finalizer_param # type: Optional[core._BundleFinalizerParam] + ): super(PerWindowInvoker, self).__init__(output_processor, signature) self.side_inputs = side_inputs self.context = context @@ -458,8 +513,8 @@ def __init__(self, output_processor, signature, context, signature.is_stateful_dofn()) self.user_state_context = user_state_context self.is_splittable = signature.is_splittable_dofn() - self.restriction_tracker = None - self.current_windowed_value = None + self.restriction_tracker = None # type: Optional[iobase.RestrictionTracker] + self.current_windowed_value = None # type: Optional[WindowedValue] self.bundle_finalizer_param = bundle_finalizer_param self.is_key_param_required = False @@ -536,9 +591,14 @@ def __init__(self, placeholder): self.args_for_process = args_with_placeholders self.kwargs_for_process = input_kwargs - def invoke_process(self, windowed_value, restriction_tracker=None, - output_processor=None, - additional_args=None, additional_kwargs=None): + def invoke_process(self, + windowed_value, # type: WindowedValue + restriction_tracker=None, + output_processor=None, # type: Optional[OutputProcessor] + additional_args=None, + additional_kwargs=None + ): + # type: (...) 
-> Optional[Tuple[WindowedValue, Timestamp]] if not additional_args: additional_args = [] if not additional_kwargs: @@ -586,10 +646,15 @@ def invoke_process(self, windowed_value, restriction_tracker=None, else: self._invoke_process_per_window( windowed_value, additional_args, additional_kwargs, output_processor) - - def _invoke_process_per_window( - self, windowed_value, additional_args, - additional_kwargs, output_processor): + return None + + def _invoke_process_per_window(self, + windowed_value, # type: WindowedValue + additional_args, + additional_kwargs, + output_processor # type: OutputProcessor + ): + # type: (...) -> Optional[Tuple[WindowedValue, Timestamp]] if self.has_windowed_inputs: window, = windowed_value.windows side_inputs = [si[window] for si in self.side_inputs] @@ -636,9 +701,11 @@ def _invoke_process_per_window( elif core.DoFn.TimestampParam == p: args_for_process[i] = windowed_value.timestamp elif isinstance(p, core.DoFn.StateParam): + assert self.user_state_context is not None args_for_process[i] = ( self.user_state_context.get_state(p.state_spec, key, window)) elif isinstance(p, core.DoFn.TimerParam): + assert self.user_state_context is not None args_for_process[i] = ( self.user_state_context.get_timer(p.timer_spec, key, window)) elif core.DoFn.BundleFinalizerParam == p: @@ -660,6 +727,7 @@ def _invoke_process_per_window( windowed_value, self.process_method(*args_for_process)) if self.is_splittable: + assert self.restriction_tracker is not None deferred_status = self.restriction_tracker.deferred_status() if deferred_status: deferred_restriction, deferred_watermark = deferred_status @@ -669,6 +737,7 @@ def _invoke_process_per_window( return ( windowed_value.with_value(((element, deferred_restriction), size)), deferred_watermark) + return None def try_split(self, fraction): restriction_tracker = self.restriction_tracker @@ -694,9 +763,12 @@ def try_split(self, fraction): current_watermark)) def current_element_progress(self): + # type: () -> Optional[iobase.RestrictionProgress] restriction_tracker = self.restriction_tracker if restriction_tracker: return restriction_tracker.current_progress() + else: + return None class DoFnRunner(Receiver): @@ -706,18 +778,19 @@ class DoFnRunner(Receiver): """ def __init__(self, - fn, + fn, # type: core.DoFn args, kwargs, - side_inputs, + side_inputs, # type: Iterable[sideinputs.SideInputMap] windowing, - tagged_receivers=None, - step_name=None, + tagged_receivers=None, # type: Optional[Mapping[Optional[str], Receiver]] + step_name=None, # type: Optional[str] logging_context=None, state=None, scoped_metrics_container=None, operation_name=None, - user_state_context=None): + user_state_context=None # type: Optional[userstate.UserStateContext] + ): """Initializes a DoFnRunner. 
Args: @@ -773,15 +846,19 @@ def __init__(self, bundle_finalizer_param=self.bundle_finalizer_param) def receive(self, windowed_value): + # type: (WindowedValue) -> None self.process(windowed_value) def process(self, windowed_value): + # type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]] try: return self.do_fn_invoker.invoke_process(windowed_value) except BaseException as exn: self._reraise_augmented(exn) + return None def process_with_sized_restriction(self, windowed_value): + # type: (WindowedValue) -> Optional[Tuple[WindowedValue, Timestamp]] (element, restriction), _ = windowed_value.value return self.do_fn_invoker.invoke_process( windowed_value.with_value(element), @@ -792,6 +869,7 @@ def try_split(self, fraction): return self.do_fn_invoker.try_split(fraction) def current_element_progress(self): + # type: () -> Optional[iobase.RestrictionProgress] return self.do_fn_invoker.current_element_progress() def process_user_timer(self, timer_spec, key, window, timestamp): @@ -852,6 +930,7 @@ def _reraise_augmented(self, exn): class OutputProcessor(object): def process_outputs(self, windowed_input_element, results): + # type: (WindowedValue, Iterable[Any]) -> None raise NotImplementedError @@ -860,8 +939,8 @@ class _OutputProcessor(OutputProcessor): def __init__(self, window_fn, - main_receivers, - tagged_receivers, + main_receivers, # type: Receiver + tagged_receivers, # type: Mapping[Optional[str], Receiver] per_element_output_counter): """Initializes ``_OutputProcessor``. @@ -878,6 +957,7 @@ def __init__(self, self.per_element_output_counter = per_element_output_counter def process_outputs(self, windowed_input_element, results): + # type: (WindowedValue, Iterable[Any]) -> None """Dispatch the result of process computation to the appropriate receivers. 
A value wrapped in a TaggedOutput object will be unwrapped and @@ -1008,6 +1088,7 @@ def __init__(self, label, element=None, state=None): self.set_element(element) def set_element(self, windowed_value): + # type: (Optional[WindowedValue]) -> None self.windowed_value = windowed_value @property diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 4928550143d6..0646da3cb377 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -64,10 +64,10 @@ from apache_beam.utils import proto_utils from apache_beam.utils.plugin import BeamPlugin -try: # Python 3 +if sys.version_info[0] > 2: unquote_to_bytes = urllib.parse.unquote_to_bytes quote = urllib.parse.quote -except AttributeError: # Python 2 +else: # pylint: disable=deprecated-urllib-function unquote_to_bytes = urllib.unquote quote = urllib.quote diff --git a/sdks/python/apache_beam/runners/direct/bundle_factory.py b/sdks/python/apache_beam/runners/direct/bundle_factory.py index 558e925df3e3..13c19669b72d 100644 --- a/sdks/python/apache_beam/runners/direct/bundle_factory.py +++ b/sdks/python/apache_beam/runners/direct/bundle_factory.py @@ -20,6 +20,10 @@ from __future__ import absolute_import from builtins import object +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Union from apache_beam import pvalue from apache_beam.runners import common @@ -38,12 +42,15 @@ class BundleFactory(object): """ def __init__(self, stacked): + # type: (bool) -> None self._stacked = stacked def create_bundle(self, output_pcollection): + # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle return _Bundle(output_pcollection, self._stacked) def create_empty_committed_bundle(self, output_pcollection): + # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle bundle = self.create_bundle(output_pcollection) bundle.commit(None) return bundle @@ -103,6 +110,7 @@ def add_value(self, value): self._appended_values.append(value) def windowed_values(self): + # type: () -> Iterator[WindowedValue] # yield first windowed_value as is, then iterate through # _appended_values to yield WindowedValue on the fly. yield self._initial_windowed_value @@ -111,14 +119,16 @@ def windowed_values(self): self._initial_windowed_value.windows) def __init__(self, pcollection, stacked=True): + # type: (Union[pvalue.PBegin, pvalue.PCollection], bool) -> None assert isinstance(pcollection, (pvalue.PBegin, pvalue.PCollection)) self._pcollection = pcollection - self._elements = [] + self._elements = [] # type: List[Union[WindowedValue, _Bundle._StackedWindowedValues]] self._stacked = stacked self._committed = False self._tag = None # optional tag information for this bundle def get_elements_iterable(self, make_copy=False): + # type: (bool) -> Iterable[WindowedValue] """Returns iterable elements. 
Args: @@ -189,6 +199,7 @@ def output(self, element): self.add(element) def receive(self, element): + # type: (WindowedValue) -> None self.add(element) def commit(self, synchronized_processing_time): diff --git a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py index d625d3ce5cee..df3c1f87adad 100644 --- a/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py +++ b/sdks/python/apache_beam/runners/direct/consumer_tracking_pipeline_visitor.py @@ -19,9 +19,17 @@ from __future__ import absolute_import +from typing import TYPE_CHECKING +from typing import Dict +from typing import List +from typing import Set + from apache_beam import pvalue from apache_beam.pipeline import PipelineVisitor +if TYPE_CHECKING: + from apache_beam.pipeline import AppliedPTransform + class ConsumerTrackingPipelineVisitor(PipelineVisitor): """For internal use only; no backwards-compatibility guarantees. @@ -34,14 +42,15 @@ class ConsumerTrackingPipelineVisitor(PipelineVisitor): """ def __init__(self): - self.value_to_consumers = {} # Map from PValue to [AppliedPTransform]. - self.root_transforms = set() # set of (root) AppliedPTransforms. - self.views = [] # list of side inputs. - self.step_names = {} # Map from AppliedPTransform to String. + self.value_to_consumers = {} # type: Dict[pvalue.PValue, List[AppliedPTransform]] + self.root_transforms = set() # type: Set[AppliedPTransform] + self.views = [] # type: List[pvalue.AsSideInput] + self.step_names = {} # type: Dict[AppliedPTransform, str] self._num_transforms = 0 def visit_transform(self, applied_ptransform): + # type: (AppliedPTransform) -> None inputs = list(applied_ptransform.inputs) if inputs: for input_value in inputs: diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index d85fc97db098..745debfb5a5a 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -247,6 +247,7 @@ def __init__(self, source): def _infer_output_coder(self, unused_input_type=None, unused_input_coder=None): + # type: (...) 
-> typing.Optional[coders.Coder] return coders.BytesCoder() def get_windowing(self, inputs): diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py index 54397b89cee2..e3b67c12ab6c 100644 --- a/sdks/python/apache_beam/runners/direct/evaluation_context.py +++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py @@ -22,6 +22,15 @@ import collections import threading from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import DefaultDict +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from apache_beam.runners.direct.direct_metrics import DirectMetrics from apache_beam.runners.direct.executor import TransformExecutor @@ -30,6 +39,15 @@ from apache_beam.transforms.trigger import InMemoryUnmergedState from apache_beam.utils import counters +if TYPE_CHECKING: + from apache_beam import pvalue + from apache_beam.pipeline import AppliedPTransform + from apache_beam.runners.direct.bundle_factory import BundleFactory, _Bundle + from apache_beam.runners.direct.util import TimerFiring + from apache_beam.runners.direct.util import TransformResult + from apache_beam.runners.direct.watermark_manager import _TransformWatermarks + from apache_beam.utils.timestamp import Timestamp + class _ExecutionContext(object): """Contains the context for the execution of a single PTransform. @@ -37,7 +55,10 @@ class _ExecutionContext(object): It holds the watermarks for that transform, as well as keyed states. """ - def __init__(self, watermarks, keyed_states): + def __init__(self, + watermarks, # type: _TransformWatermarks + keyed_states + ): self.watermarks = watermarks self.keyed_states = keyed_states @@ -75,10 +96,11 @@ class _SideInputsContainer(object): """ def __init__(self, side_inputs): + # type: (Iterable[pvalue.AsSideInput]) -> None self._lock = threading.Lock() - self._views = {} - self._transform_to_side_inputs = collections.defaultdict(list) - self._side_input_to_blocked_tasks = collections.defaultdict(list) + self._views = {} # type: Dict[pvalue.AsSideInput, _SideInputView] + self._transform_to_side_inputs = collections.defaultdict(list) # type: DefaultDict[Optional[AppliedPTransform], List[pvalue.AsSideInput]] + self._side_input_to_blocked_tasks = collections.defaultdict(list) # type: ignore # usused? for side in side_inputs: self._views[side] = _SideInputView(side) @@ -89,7 +111,12 @@ def __repr__(self): if self._views else '[]') return '_SideInputsContainer(_views=%s)' % views_string - def get_value_or_block_until_ready(self, side_input, task, block_until): + def get_value_or_block_until_ready(self, + side_input, + task, # type: TransformExecutor + block_until # type: Timestamp + ): + # type: (...) -> Any """Returns the value of a view whose task is unblocked or blocks its task. It gets the value of a view whose watermark has been updated and @@ -121,6 +148,7 @@ def add_values(self, side_input, values): def update_watermarks_for_transform_and_unblock_tasks(self, ptransform, watermark): + # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] """Updates _SideInputsContainer after a watermark update and unbloks tasks. 
It traverses the list of side inputs per PTransform and calls @@ -143,6 +171,7 @@ def update_watermarks_for_transform_and_unblock_tasks(self, def _update_watermarks_for_side_input_and_unblock_tasks(self, side_input, watermark): + # type: (...) -> List[Tuple[TransformExecutor, Timestamp]] """Helps update _SideInputsContainer after a watermark update. For each view of the side input, it updates the value of the watermark @@ -210,15 +239,22 @@ class EvaluationContext(object): global watermarks, and executing any callbacks that can be executed. """ - def __init__(self, pipeline_options, bundle_factory, root_transforms, - value_to_consumers, step_names, views, clock): + def __init__(self, + pipeline_options, + bundle_factory, # type: BundleFactory + root_transforms, + value_to_consumers, + step_names, + views, # type: Iterable[pvalue.AsSideInput] + clock + ): self.pipeline_options = pipeline_options self._bundle_factory = bundle_factory self._root_transforms = root_transforms self._value_to_consumers = value_to_consumers self._step_names = step_names self.views = views - self._pcollection_to_views = collections.defaultdict(list) + self._pcollection_to_views = collections.defaultdict(list) # type: DefaultDict[pvalue.PCollection, List[pvalue.AsSideInput]] for view in views: self._pcollection_to_views[view.pvalue].append(view) self._transform_keyed_states = self._initialize_keyed_states( @@ -227,7 +263,7 @@ def __init__(self, pipeline_options, bundle_factory, root_transforms, self._watermark_manager = WatermarkManager( clock, root_transforms, value_to_consumers, self._transform_keyed_states) - self._pending_unblocked_tasks = [] + self._pending_unblocked_tasks = [] # type: List[Tuple[TransformExecutor, Timestamp]] self._counter_factory = counters.CounterFactory() self._metrics = DirectMetrics() @@ -251,10 +287,15 @@ def metrics(self): return self._metrics def is_root_transform(self, applied_ptransform): + # type: (AppliedPTransform) -> bool return applied_ptransform in self._root_transforms - def handle_result( - self, completed_bundle, completed_timers, result): + def handle_result(self, + completed_bundle, # type: _Bundle + completed_timers, + result # type: TransformResult + ): + """Handle the provided result produced after evaluating the input bundle. Handle the provided TransformResult, produced after evaluating @@ -303,7 +344,10 @@ def handle_result( existing_keyed_state[k] = v return committed_bundles - def _update_side_inputs_container(self, committed_bundles, result): + def _update_side_inputs_container(self, + committed_bundles, # type: Iterable[_Bundle] + result # type: TransformResult + ): """Update the side inputs container if we are outputting into a side input. Look at the result, and if it's outputing into a PCollection that we have @@ -330,7 +374,11 @@ def schedule_pending_unblocked_tasks(self, executor_service): executor_service.submit(task) self._pending_unblocked_tasks = [] - def _commit_bundles(self, uncommitted_bundles, unprocessed_bundles): + def _commit_bundles(self, + uncommitted_bundles, # type: Iterable[_Bundle] + unprocessed_bundles # type: Iterable[_Bundle] + ): + # type: (...) 
-> Tuple[Tuple[_Bundle, ...], Tuple[_Bundle, ...]] """Commits bundles and returns a immutable set of committed bundles.""" for in_progress_bundle in uncommitted_bundles: producing_applied_ptransform = in_progress_bundle.pcollection.producer @@ -343,23 +391,28 @@ def _commit_bundles(self, uncommitted_bundles, unprocessed_bundles): return tuple(uncommitted_bundles), tuple(unprocessed_bundles) def get_execution_context(self, applied_ptransform): + # type: (AppliedPTransform) -> _ExecutionContext return _ExecutionContext( self._watermark_manager.get_watermarks(applied_ptransform), self._transform_keyed_states[applied_ptransform]) def create_bundle(self, output_pcollection): + # type: (Union[pvalue.PBegin, pvalue.PCollection]) -> _Bundle """Create an uncommitted bundle for the specified PCollection.""" return self._bundle_factory.create_bundle(output_pcollection) def create_empty_committed_bundle(self, output_pcollection): + # type: (pvalue.PCollection) -> _Bundle """Create empty bundle useful for triggering evaluation.""" return self._bundle_factory.create_empty_committed_bundle( output_pcollection) def extract_all_timers(self): + # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] return self._watermark_manager.extract_all_timers() def is_done(self, transform=None): + # type: (Optional[AppliedPTransform]) -> bool """Checks completion of a step or the pipeline. Args: @@ -378,6 +431,7 @@ def is_done(self, transform=None): return True def _is_transform_done(self, transform): + # type: (AppliedPTransform) -> bool tw = self._watermark_manager.get_watermarks(transform) return tw.output_watermark == WatermarkManager.WATERMARK_POS_INF diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py index 32a6b32ea561..69e873146b88 100644 --- a/sdks/python/apache_beam/runners/direct/executor.py +++ b/sdks/python/apache_beam/runners/direct/executor.py @@ -27,6 +27,12 @@ import traceback from builtins import object from builtins import range +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import FrozenSet +from typing import Optional +from typing import Set from weakref import WeakValueDictionary from future.moves import queue @@ -37,6 +43,12 @@ from apache_beam.transforms import sideinputs from apache_beam.utils import counters +if TYPE_CHECKING: + from apache_beam import pvalue + from apache_beam.runners.direct.bundle_factory import _Bundle + from apache_beam.runners.direct.evaluation_context import EvaluationContext + from apache_beam.runners.direct.transform_evaluator import TransformEvaluatorRegistry + class _ExecutorService(object): """Thread pool for executing tasks in parallel.""" @@ -56,7 +68,10 @@ class _ExecutorServiceWorker(threading.Thread): # Amount to block waiting for getting an item from the queue in seconds. TIMEOUT = 5 - def __init__(self, queue, index): + def __init__(self, + queue, # type: queue.Queue[_ExecutorService.CallableTask] + index + ): super(_ExecutorService._ExecutorServiceWorker, self).__init__() self.queue = queue self._index = index @@ -77,6 +92,7 @@ def _update_name(self, task=None): self._index, name, 'executing' if task else 'idle') def _get_task_or_none(self): + # type: () -> Optional[_ExecutorService.CallableTask] try: # Do not block indefinitely, otherwise we may not act for a requested # shutdown. 
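
For reference, a minimal self-contained sketch (not part of the patch; the names SimpleWorker and task_queue are illustrative) of the queue-polling pattern that the _ExecutorServiceWorker annotations above describe: a worker thread takes tasks from a queue.Queue with a short timeout so that a shutdown request is noticed promptly instead of blocking indefinitely. The sketch assumes the Python 3 queue module rather than future.moves.queue.

import queue
import threading

class SimpleWorker(threading.Thread):
  # Amount to block waiting for an item from the queue, in seconds.
  TIMEOUT = 5

  def __init__(self, task_queue):
    # type: (queue.Queue) -> None
    super(SimpleWorker, self).__init__()
    self.task_queue = task_queue
    self.shutdown_requested = False

  def run(self):
    # type: () -> None
    while not self.shutdown_requested:
      try:
        # Do not block indefinitely, otherwise a requested shutdown
        # would go unnoticed until the next task arrives.
        task = self.task_queue.get(timeout=self.TIMEOUT)
      except queue.Empty:
        continue
      task()
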
@@ -103,12 +119,13 @@ def shutdown(self): self.shutdown_requested = True def __init__(self, num_workers): - self.queue = queue.Queue() + self.queue = queue.Queue() # type: queue.Queue[_ExecutorService.CallableTask] self.workers = [_ExecutorService._ExecutorServiceWorker( self.queue, i) for i in range(num_workers)] self.shutdown_requested = False def submit(self, task): + # type: (_ExecutorService.CallableTask) -> None assert isinstance(task, _ExecutorService.CallableTask) if not self.shutdown_requested: self.queue.put(task) @@ -136,7 +153,10 @@ def shutdown(self): class _TransformEvaluationState(object): - def __init__(self, executor_service, scheduled): + def __init__(self, + executor_service, + scheduled # type: Set[TransformExecutor] + ): self.executor_service = executor_service self.scheduled = scheduled @@ -204,16 +224,19 @@ class _TransformExecutorServices(object): """ def __init__(self, executor_service): + # type: (_ExecutorService) -> None self._executor_service = executor_service - self._scheduled = set() + self._scheduled = set() # type: Set[TransformExecutor] self._parallel = _ParallelEvaluationState( self._executor_service, self._scheduled) - self._serial_cache = WeakValueDictionary() + self._serial_cache = WeakValueDictionary() # type: WeakValueDictionary[Any, _SerialEvaluationState] def parallel(self): + # type: () -> _ParallelEvaluationState return self._parallel def serial(self, step): + # type: (Any) -> _SerialEvaluationState cached = self._serial_cache.get(step) if not cached: cached = _SerialEvaluationState(self._executor_service, self._scheduled) @@ -222,6 +245,7 @@ def serial(self, step): @property def executors(self): + # type: () -> FrozenSet[TransformExecutor] return frozenset(self._scheduled) @@ -233,7 +257,11 @@ class _CompletionCallback(object): or for a source transform. 
""" - def __init__(self, evaluation_context, all_updates, timer_firings=None): + def __init__(self, + evaluation_context, # type: EvaluationContext + all_updates, + timer_firings=None + ): self._evaluation_context = evaluation_context self._all_updates = all_updates self._timer_firings = timer_firings or [] @@ -271,9 +299,15 @@ class TransformExecutor(_ExecutorService.CallableTask): _MAX_RETRY_PER_BUNDLE = 4 - def __init__(self, transform_evaluator_registry, evaluation_context, - input_bundle, fired_timers, applied_ptransform, - completion_callback, transform_evaluation_state): + def __init__(self, + transform_evaluator_registry, # type: TransformEvaluatorRegistry + evaluation_context, # type: EvaluationContext + input_bundle, # type: _Bundle + fired_timers, + applied_ptransform, + completion_callback, + transform_evaluation_state # type: _TransformEvaluationState + ): self._transform_evaluator_registry = transform_evaluator_registry self._evaluation_context = evaluation_context self._input_bundle = input_bundle @@ -289,7 +323,7 @@ def __init__(self, transform_evaluator_registry, evaluation_context, self._applied_ptransform = applied_ptransform self._completion_callback = completion_callback self._transform_evaluation_state = transform_evaluation_state - self._side_input_values = {} + self._side_input_values = {} # type: Dict[pvalue.AsSideInput, Any] self.blocked = False self._call_count = 0 self._retry_count = 0 @@ -408,8 +442,11 @@ class _ExecutorServiceParallelExecutor(object): NUM_WORKERS = 1 - def __init__(self, value_to_consumers, transform_evaluator_registry, - evaluation_context): + def __init__(self, + value_to_consumers, + transform_evaluator_registry, + evaluation_context # type: EvaluationContext + ): self.executor_service = _ExecutorService( _ExecutorServiceParallelExecutor.NUM_WORKERS) self.transform_executor_services = _TransformExecutorServices( @@ -452,6 +489,7 @@ def request_shutdown(self): self.executor_service.shutdown() def schedule_consumers(self, committed_bundle): + # type: (_Bundle) -> None if committed_bundle.pcollection in self.value_to_consumers: consumers = self.value_to_consumers[committed_bundle.pcollection] for applied_ptransform in consumers: @@ -462,8 +500,12 @@ def schedule_unprocessed_bundle(self, applied_ptransform, unprocessed_bundle): self.node_to_pending_bundles[applied_ptransform].append(unprocessed_bundle) - def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, - fired_timers, on_complete): + def schedule_consumption(self, + consumer_applied_ptransform, + committed_bundle, # type: _Bundle + fired_timers, + on_complete + ): """Schedules evaluation of the given bundle with the transform.""" assert consumer_applied_ptransform assert committed_bundle @@ -471,7 +513,7 @@ def schedule_consumption(self, consumer_applied_ptransform, committed_bundle, if self.transform_evaluator_registry.should_execute_serially( consumer_applied_ptransform): transform_executor_service = self.transform_executor_services.serial( - consumer_applied_ptransform) + consumer_applied_ptransform) # type: _TransformEvaluationState else: transform_executor_service = self.transform_executor_services.parallel() @@ -548,6 +590,7 @@ class _MonitorTask(_ExecutorService.CallableTask): """MonitorTask continuously runs to ensure that pipeline makes progress.""" def __init__(self, executor): + # type: (_ExecutorServiceParallelExecutor) -> None self._executor = executor @property @@ -585,6 +628,7 @@ def call(self, state_sampler): 
self._executor.executor_service.submit(self) def _should_shutdown(self): + # type: () -> bool """Checks whether the pipeline is completed and should be shut down. If there is anything in the queue of tasks to do or @@ -646,6 +690,7 @@ def _fire_timers(self): return bool(transform_fired_timers) def _is_executing(self): + # type: () -> bool """Checks whether the job is still executing. Returns: diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py index 307679018a5b..f75372e5936f 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py @@ -23,6 +23,8 @@ from builtins import object from threading import Lock from threading import Timer +from typing import Any +from typing import Iterable import apache_beam as beam from apache_beam import TimeDomain @@ -360,4 +362,5 @@ def __init__(self): self.output_iter = None def process_outputs(self, windowed_input_element, output_iter): + # type: (WindowedValue, Iterable[Any]) -> None self.output_iter = output_iter diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index e1fc3cd1640d..9263f2d8f3df 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -24,8 +24,14 @@ import logging import random import time -import typing from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple +from typing import Type +from typing import Union from future.utils import iteritems @@ -65,6 +71,12 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp +if TYPE_CHECKING: + from apache_beam.io.gcp.pubsub import _PubSubSource + from apache_beam.io.gcp.pubsub import PubsubMessage + from apache_beam.pipeline import AppliedPTransform + from apache_beam.runners.direct.evaluation_context import EvaluationContext + class TransformEvaluatorRegistry(object): """For internal use only; no backwards-compatibility guarantees. @@ -72,9 +84,10 @@ class TransformEvaluatorRegistry(object): Creates instances of TransformEvaluator for the application of a transform. 
""" - _test_evaluators_overrides = {} + _test_evaluators_overrides = {} # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] def __init__(self, evaluation_context): + # type: (EvaluationContext) -> None assert evaluation_context self._evaluation_context = evaluation_context self._evaluators = { @@ -88,7 +101,7 @@ def __init__(self, evaluation_context): _NativeWrite: _NativeWriteEvaluator, TestStream: _TestStreamEvaluator, ProcessElements: _ProcessElementsEvaluator - } + } # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]] self._evaluators.update(self._test_evaluators_overrides) self._root_bundle_providers = { core.PTransform: DefaultRootBundleProvider, @@ -200,8 +213,12 @@ def get_root_bundles(self): class _TransformEvaluator(object): """An evaluator of a specific application of a transform.""" - def __init__(self, evaluation_context, applied_ptransform, - input_committed_bundle, side_inputs): + def __init__(self, + evaluation_context, # type: EvaluationContext + applied_ptransform, # type: AppliedPTransform + input_committed_bundle, + side_inputs + ): self._evaluation_context = evaluation_context self._applied_ptransform = applied_ptransform self._input_committed_bundle = input_committed_bundle @@ -280,6 +297,7 @@ def process_element(self, element): raise NotImplementedError('%s do not process elements.' % type(self)) def finish_bundle(self): + # type: () -> TransformResult """Finishes the bundle and produces output.""" pass @@ -382,7 +400,7 @@ class _PubSubReadEvaluator(_TransformEvaluator): # A mapping of transform to _PubSubSubscriptionWrapper. # TODO(BEAM-7750): Prevents garbage collection of pipeline instances. - _subscription_cache = {} + _subscription_cache = {} # type: Dict[AppliedPTransform, str] def __init__(self, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs): @@ -391,7 +409,7 @@ def __init__(self, evaluation_context, applied_ptransform, evaluation_context, applied_ptransform, input_committed_bundle, side_inputs) - self.source = self._applied_ptransform.transform._source + self.source = self._applied_ptransform.transform._source # type: _PubSubSource if self.source.id_label: raise NotImplementedError( 'DirectRunner: id_label is not supported for PubSub reads') @@ -426,6 +444,7 @@ def process_element(self, element): pass def _read_from_pubsub(self, timestamp_attribute): + # type: (...) 
-> List[Tuple[Timestamp, PubsubMessage]] from apache_beam.io.gcp.pubsub import PubsubMessage from google.cloud import pubsub @@ -465,6 +484,7 @@ def _get_element(message): return results def finish_bundle(self): + # type: () -> TransformResult data = self._read_from_pubsub(self.source.timestamp_attribute) if data: output_pcollection = list(self._outputs)[0] @@ -481,8 +501,9 @@ def finish_bundle(self): else: bundles = [] if self._applied_ptransform.inputs: - input_pvalue = self._applied_ptransform.inputs[0] + input_pvalue = self._applied_ptransform.inputs[0] # type: Union[pvalue.PBegin, pvalue.PCollection] else: + assert self._applied_ptransform.transform.pipeline is not None input_pvalue = pvalue.PBegin(self._applied_ptransform.transform.pipeline) unprocessed_bundle = self._evaluation_context.create_bundle( input_pvalue) @@ -527,6 +548,7 @@ class NullReceiver(common.Receiver): """Ignores undeclared outputs, default execution mode.""" def receive(self, element): + # type: (WindowedValue) -> None pass class _InMemoryReceiver(common.Receiver): @@ -537,6 +559,7 @@ def __init__(self, target, tag): self._tag = tag def receive(self, element): + # type: (WindowedValue) -> None self._target[self._tag].append(element) def __missing__(self, key): @@ -548,9 +571,13 @@ def __missing__(self, key): class _ParDoEvaluator(_TransformEvaluator): """TransformEvaluator for ParDo transform.""" - def __init__(self, evaluation_context, applied_ptransform, - input_committed_bundle, side_inputs, - perform_dofn_pickle_test=True): + def __init__(self, + evaluation_context, # type: EvaluationContext + applied_ptransform, # type: AppliedPTransform + input_committed_bundle, + side_inputs, + perform_dofn_pickle_test=True + ): super(_ParDoEvaluator, self).__init__( evaluation_context, applied_ptransform, input_committed_bundle, side_inputs) @@ -582,11 +609,11 @@ def start_bundle(self): self.user_timer_map = {} if is_stateful_dofn(dofn): kv_type_hint = self._applied_ptransform.inputs[0].element_type - if kv_type_hint and kv_type_hint != typing.Any: + if kv_type_hint and kv_type_hint != Any: coder = coders.registry.get_coder(kv_type_hint) self.key_coder = coder.key_coder() else: - self.key_coder = coders.registry.get_coder(typing.Any) + self.key_coder = coders.registry.get_coder(Any) self.user_state_context = DirectUserStateContext( self._step_context, dofn, self.key_coder) @@ -742,7 +769,7 @@ def start_bundle(self): # The input type of a GroupByKey will be Tuple[Any, Any] or more specific. kv_type_hint = self._applied_ptransform.inputs[0].element_type key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint - else typing.Any) + else Any) self.key_coder = coders.registry.get_coder(key_type_hint) def process_element(self, element): @@ -798,7 +825,7 @@ def start_bundle(self): # GroupAlsoByWindow will be Tuple[Any, Iter[Any]] or more specific. 
kv_type_hint = self._applied_ptransform.outputs[None].element_type key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint - else typing.Any) + else Any) self.key_coder = coders.registry.get_coder(key_type_hint) def process_element(self, element): diff --git a/sdks/python/apache_beam/runners/direct/watermark_manager.py b/sdks/python/apache_beam/runners/direct/watermark_manager.py index 23431f16ddbc..6288a870c284 100644 --- a/sdks/python/apache_beam/runners/direct/watermark_manager.py +++ b/sdks/python/apache_beam/runners/direct/watermark_manager.py @@ -21,6 +21,12 @@ import threading from builtins import object +from typing import TYPE_CHECKING +from typing import Dict +from typing import Iterable +from typing import List +from typing import Set +from typing import Tuple from apache_beam import pipeline from apache_beam import pvalue @@ -29,6 +35,11 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import TIME_GRANULARITY +if TYPE_CHECKING: + from apache_beam.pipeline import AppliedPTransform + from apache_beam.runners.direct.bundle_factory import _Bundle + from apache_beam.utils.timestamp import Timestamp + class WatermarkManager(object): """For internal use only; no backwards-compatibility guarantees. @@ -45,7 +56,7 @@ def __init__(self, clock, root_transforms, value_to_consumers, self._value_to_consumers = value_to_consumers self._transform_keyed_states = transform_keyed_states # AppliedPTransform -> TransformWatermarks - self._transform_to_watermarks = {} + self._transform_to_watermarks = {} # type: Dict[AppliedPTransform, _TransformWatermarks] for root_transform in root_transforms: self._transform_to_watermarks[root_transform] = _TransformWatermarks( @@ -61,6 +72,7 @@ def __init__(self, clock, root_transforms, value_to_consumers, self._update_input_transform_watermarks(consumer) def _update_input_transform_watermarks(self, applied_ptransform): + # type: (AppliedPTransform) -> None assert isinstance(applied_ptransform, pipeline.AppliedPTransform) input_transform_watermarks = [] for input_pvalue in applied_ptransform.inputs: @@ -73,6 +85,7 @@ def _update_input_transform_watermarks(self, applied_ptransform): input_transform_watermarks) def get_watermarks(self, applied_ptransform): + # type: (AppliedPTransform) -> _TransformWatermarks """Gets the input and output watermarks for an AppliedPTransform. 
If the applied_ptransform has not processed any elements, return a @@ -93,9 +106,15 @@ def get_watermarks(self, applied_ptransform): return self._transform_to_watermarks[applied_ptransform] - def update_watermarks(self, completed_committed_bundle, applied_ptransform, - completed_timers, outputs, unprocessed_bundles, - keyed_earliest_holds, side_inputs_container): + def update_watermarks(self, + completed_committed_bundle, # type: _Bundle + applied_ptransform, # type: AppliedPTransform + completed_timers, + outputs, + unprocessed_bundles, + keyed_earliest_holds, + side_inputs_container + ): assert isinstance(applied_ptransform, pipeline.AppliedPTransform) self._update_pending( completed_committed_bundle, applied_ptransform, completed_timers, @@ -104,9 +123,13 @@ def update_watermarks(self, completed_committed_bundle, applied_ptransform, tw.hold(keyed_earliest_holds) return self._refresh_watermarks(applied_ptransform, side_inputs_container) - def _update_pending(self, input_committed_bundle, applied_ptransform, - completed_timers, output_committed_bundles, - unprocessed_bundles): + def _update_pending(self, + input_committed_bundle, + applied_ptransform, # type: AppliedPTransform + completed_timers, + output_committed_bundles, # type: Iterable[_Bundle] + unprocessed_bundles # type: Iterable[_Bundle] + ): """Updated list of pending bundles for the given AppliedPTransform.""" # Update pending elements. Filter out empty bundles. They do not impact @@ -153,9 +176,10 @@ def _refresh_watermarks(self, applied_ptransform, side_inputs_container): return unblocked_tasks def extract_all_timers(self): + # type: () -> Tuple[List[Tuple[AppliedPTransform, List[TimerFiring]]], bool] """Extracts fired timers for all transforms and reports if there are any timers set.""" - all_timers = [] + all_timers = [] # type: List[Tuple[AppliedPTransform, List[TimerFiring]]] has_realtime_timer = False for applied_ptransform, tw in self._transform_to_watermarks.items(): fired_timers, had_realtime_timer = tw.extract_transform_timers() @@ -175,17 +199,19 @@ class _TransformWatermarks(object): def __init__(self, clock, keyed_states, transform): self._clock = clock self._keyed_states = keyed_states - self._input_transform_watermarks = [] + self._input_transform_watermarks = [] # type: List[_TransformWatermarks] self._input_watermark = WatermarkManager.WATERMARK_NEG_INF self._output_watermark = WatermarkManager.WATERMARK_NEG_INF self._keyed_earliest_holds = {} - self._pending = set() # Scheduled bundles targeted for this transform. + # Scheduled bundles targeted for this transform. + self._pending = set() # type: Set[_Bundle] self._fired_timers = set() self._lock = threading.Lock() self._label = str(transform) def update_input_transform_watermarks(self, input_transform_watermarks): + # type: (List[_TransformWatermarks]) -> None with self._lock: self._input_transform_watermarks = input_transform_watermarks @@ -196,11 +222,13 @@ def update_timers(self, completed_timers): @property def input_watermark(self): + # type: () -> Timestamp with self._lock: return self._input_watermark @property def output_watermark(self): + # type: () -> Timestamp with self._lock: return self._output_watermark @@ -213,10 +241,12 @@ def hold(self, keyed_earliest_holds): del self._keyed_earliest_holds[key] def add_pending(self, pending): + # type: (_Bundle) -> None with self._lock: self._pending.add(pending) def remove_pending(self, completed): + # type: (_Bundle) -> None with self._lock: # Ignore repeated removes. 
This will happen if a transform has a repeated # input. @@ -224,6 +254,7 @@ def remove_pending(self, completed): self._pending.remove(completed) def refresh(self): + # type: () -> bool """Refresh the watermark for a given transform. This method looks at the watermark coming from all input PTransforms, and @@ -272,6 +303,7 @@ def synchronized_processing_output_time(self): return self._clock.time() def extract_transform_timers(self): + # type: () -> Tuple[List[TimerFiring], bool] """Extracts fired timers and reports of any timers set per transform.""" with self._lock: fired_timers = [] diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py index 20d84e3be6fc..72491ac56496 100644 --- a/sdks/python/apache_beam/runners/interactive/cache_manager.py +++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py @@ -22,6 +22,7 @@ import collections import datetime import os +import sys import tempfile import urllib @@ -32,10 +33,10 @@ from apache_beam.io import tfrecordio from apache_beam.transforms import combiners -try: # Python 3 +if sys.version_info[0] > 2: unquote_to_bytes = urllib.parse.unquote_to_bytes quote = urllib.parse.quote -except AttributeError: # Python 2 +else: # pylint: disable=deprecated-urllib-function unquote_to_bytes = urllib.unquote quote = urllib.quote diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py index 35f608a503bc..6ddeb36cd6b8 100644 --- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py +++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py @@ -27,7 +27,7 @@ import collections import threading import time -import typing +from typing import TYPE_CHECKING from apache_beam.runners.interactive.display import interactive_pipeline_graph @@ -36,7 +36,7 @@ # _display_progress defines how outputs are printed on the frontend. _display_progress = IPython.display.display - if not typing.TYPE_CHECKING: + if not TYPE_CHECKING: def _formatter(string, pp, cycle): # pylint: disable=unused-argument pp.text(string) diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index 23605b535c33..3f8582e4b678 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -26,6 +26,12 @@ import collections import threading +from typing import DefaultDict +from typing import Dict +from typing import Iterator +from typing import List +from typing import Tuple +from typing import Union import pydot @@ -37,9 +43,10 @@ class PipelineGraph(object): """Creates a DOT representation of the pipeline. Thread-safe.""" def __init__(self, - pipeline, + pipeline, # type: Union[beam_runner_api_pb2.Pipeline, beam.Pipeline] default_vertex_attrs=None, - default_edge_attrs=None): + default_edge_attrs=None + ): """Constructor of PipelineGraph. 
Examples: @@ -57,7 +64,7 @@ def __init__(self, default_edge_attrs: (Dict[str, str]) a dict of default edge attributes """ self._lock = threading.Lock() - self._graph = None + self._graph = None # type: pydot.Dot if isinstance(pipeline, beam_runner_api_pb2.Pipeline): self._pipeline_proto = pipeline @@ -69,9 +76,9 @@ def __init__(self, type(pipeline))) # A dict from PCollection ID to a list of its consuming Transform IDs - self._consumers = collections.defaultdict(list) + self._consumers = collections.defaultdict(list) # type: DefaultDict[str, List[str]] # A dict from PCollection ID to its producing Transform ID - self._producers = {} + self._producers = {} # type: Dict[str, str] for transform_id, transform_proto in self._top_level_transforms(): for pcoll_id in transform_proto.inputs.values(): @@ -93,9 +100,11 @@ def __init__(self, default_edge_attrs) def get_dot(self): + # type: () -> str return self._get_graph().to_string() def _top_level_transforms(self): + # type: () -> Iterator[Tuple[str, beam_runner_api_pb2.PTransform]] """Yields all top level PTransforms (subtransforms of the root PTransform). Yields: (str, PTransform proto) ID, proto pair of top level PTransforms. diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index 2df5c612cd39..47337e02c287 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -27,11 +27,17 @@ import abc import os import subprocess +from typing import TYPE_CHECKING +from typing import Optional +from typing import Type from future.utils import with_metaclass from apache_beam.utils.plugin import BeamPlugin +if TYPE_CHECKING: + from apache_beam.runners.interactive.display.pipeline_graph import PipelineGraph + class PipelineGraphRenderer(with_metaclass(abc.ABCMeta, BeamPlugin)): """Abstract class for renderers, who decide how pipeline graphs are rendered. @@ -40,12 +46,14 @@ class PipelineGraphRenderer(with_metaclass(abc.ABCMeta, BeamPlugin)): @classmethod @abc.abstractmethod def option(cls): + # type: () -> str """The corresponding rendering option for the renderer. """ raise NotImplementedError @abc.abstractmethod def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str """Renders the pipeline graph in HTML-compatible format. Args: @@ -63,9 +71,11 @@ class MuteRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'mute' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return '' @@ -75,9 +85,11 @@ class TextRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'text' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return pipeline_graph.get_dot() @@ -91,13 +103,16 @@ class PydotRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'graph' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return pipeline_graph._get_graph().create_svg().decode("utf-8") # pylint: disable=protected-access def get_renderer(option=None): + # type: (Optional[str]) -> Type[PipelineGraphRenderer] """Get an instance of PipelineGraphRenderer given rendering option. 
Args: diff --git a/sdks/python/apache_beam/runners/job/utils.py b/sdks/python/apache_beam/runners/job/utils.py index 3e347517972a..4c7c965b7c17 100644 --- a/sdks/python/apache_beam/runners/job/utils.py +++ b/sdks/python/apache_beam/runners/job/utils.py @@ -27,8 +27,10 @@ def dict_to_struct(dict_obj): + # type: (dict) -> struct_pb2.Struct return json_format.ParseDict(dict_obj, struct_pb2.Struct()) def struct_to_dict(struct_obj): + # type: (struct_pb2.Struct) -> dict return json.loads(json_format.MessageToJson(struct_obj)) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 913dac69d2b5..774446043bc0 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -23,6 +23,12 @@ from __future__ import absolute_import from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Union from apache_beam import coders from apache_beam import pipeline @@ -33,6 +39,11 @@ from apache_beam.transforms import core from apache_beam.typehints import native_type_compatibility +if TYPE_CHECKING: + from google.protobuf import message # pylint: disable=ungrouped-imports + from apache_beam.coders.coder_impl import IterableStateReader + from apache_beam.coders.coder_impl import IterableStateWriter + class Environment(object): """A wrapper around the environment proto. @@ -40,13 +51,16 @@ class Environment(object): Provides consistency with how the other componentes are accessed. """ def __init__(self, proto): + # type: (beam_runner_api_pb2.Environment) -> None self.proto = proto def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.Environment return self.proto @staticmethod def from_runner_api(proto, context): + # type: (beam_runner_api_pb2.Environment, PipelineContext) -> Environment return Environment(proto) @@ -56,16 +70,22 @@ class _PipelineContextMap(object): Under the hood it encodes and decodes these objects into runner API representations. 
""" - def __init__(self, context, obj_type, namespace, proto_map=None): + def __init__(self, + context, + obj_type, + namespace, # type: str + proto_map=None # type: Optional[Mapping[str, message.Message]] + ): self._pipeline_context = context self._obj_type = obj_type self._namespace = namespace - self._obj_to_id = {} - self._id_to_obj = {} + self._obj_to_id = {} # type: Dict[Any, str] + self._id_to_obj = {} # type: Dict[str, Any] self._id_to_proto = dict(proto_map) if proto_map else {} self._counter = 0 def _unique_ref(self, obj=None, label=None): + # type: (Optional[Any], Optional[str]) -> str self._counter += 1 return "%s_%s_%s_%d" % ( self._namespace, @@ -74,10 +94,12 @@ def _unique_ref(self, obj=None, label=None): self._counter) def populate_map(self, proto_map): + # type: (Mapping[str, message.Message]) -> None for id, proto in self._id_to_proto.items(): proto_map[id].CopyFrom(proto) def get_id(self, obj, label=None): + # type: (Any, Optional[str]) -> str if obj not in self._obj_to_id: id = self._unique_ref(obj, label) self._id_to_obj[id] = obj @@ -86,15 +108,18 @@ def get_id(self, obj, label=None): return self._obj_to_id[obj] def get_proto(self, obj, label=None): + # type: (Any, Optional[str]) -> message.Message return self._id_to_proto[self.get_id(obj, label)] def get_by_id(self, id): + # type: (str) -> Any if id not in self._id_to_obj: self._id_to_obj[id] = self._obj_type.from_runner_api( self._id_to_proto[id], self._pipeline_context) return self._id_to_obj[id] def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): + # type: (message.Message, Optional[str], bool) -> str if deduplicate: for id, proto in self._id_to_proto.items(): if proto == maybe_new_proto: @@ -102,18 +127,22 @@ def get_by_proto(self, maybe_new_proto, label=None, deduplicate=False): return self.put_proto(self._unique_ref(label), maybe_new_proto) def get_id_to_proto_map(self): + # type: () -> Dict[str, message.Message] return self._id_to_proto def put_proto(self, id, proto): + # type: (str, message.Message) -> str if id in self._id_to_proto: raise ValueError("Id '%s' is already taken." % id) self._id_to_proto[id] = proto return id def __getitem__(self, id): + # type: (str) -> Any return self.get_by_id(id) def __contains__(self, id): + # type: (str) -> bool return id in self._id_to_proto @@ -146,7 +175,8 @@ def __init__( self, cls, namespace, getattr(proto, name, None))) if default_environment: self._default_environment_id = self.environments.get_id( - Environment(default_environment), label='default_environment') + Environment(default_environment), + label='default_environment') # type: Optional[str] else: self._default_environment_id = None self.use_fake_coders = use_fake_coders @@ -159,12 +189,14 @@ def __init__( # as well as performing a round-trip through protos. # TODO(BEAM-2717): Remove once this is no longer needed. 
def coder_id_from_element_type(self, element_type): + # type: (Any) -> str if self.use_fake_coders: return pickler.dumps(element_type) else: return self.coders.get_id(coders.registry.get_coder(element_type)) def element_type_from_coder_id(self, coder_id): + # type: (str) -> Any if self.use_fake_coders or coder_id not in self.coders: return pickler.loads(coder_id) else: @@ -173,13 +205,16 @@ def element_type_from_coder_id(self, coder_id): @staticmethod def from_runner_api(proto): + # type: (beam_runner_api_pb2.Components) -> PipelineContext return PipelineContext(proto) def to_runner_api(self): + # type: () -> beam_runner_api_pb2.Components context_proto = beam_runner_api_pb2.Components() for name in self._COMPONENT_TYPES: getattr(self, name).populate_map(getattr(context_proto, name)) return context_proto def default_environment_id(self): + # type: () -> Optional[str] return self._default_environment_id diff --git a/sdks/python/apache_beam/runners/portability/abstract_job_service.py b/sdks/python/apache_beam/runners/portability/abstract_job_service.py index 982fad1682bc..3361f554a99d 100644 --- a/sdks/python/apache_beam/runners/portability/abstract_job_service.py +++ b/sdks/python/apache_beam/runners/portability/abstract_job_service.py @@ -19,10 +19,20 @@ import logging import uuid from builtins import object +from typing import TYPE_CHECKING +from typing import Dict +from typing import Iterator +from typing import Optional +from typing import Union from apache_beam.portability.api import beam_job_api_pb2 from apache_beam.portability.api import beam_job_api_pb2_grpc +if TYPE_CHECKING: + from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.portability.api import endpoints_pb2 + TERMINAL_STATES = [ beam_job_api_pb2.JobState.DONE, beam_job_api_pb2.JobState.STOPPED, @@ -37,13 +47,24 @@ class AbstractJobServiceServicer(beam_job_api_pb2_grpc.JobServiceServicer): Servicer for the Beam Job API. """ def __init__(self): - self._jobs = {} - - def create_beam_job(self, preparation_id, job_name, pipeline, options): + self._jobs = {} # type: Dict[str, AbstractBeamJob] + + def create_beam_job(self, + preparation_id, # stype: str + job_name, # type: str + pipeline, # type: beam_runner_api_pb2.Pipeline + options # type: struct_pb2.Struct + ): + # type: (...) -> AbstractBeamJob """Returns an instance of AbstractBeamJob specific to this servicer.""" raise NotImplementedError(type(self)) - def Prepare(self, request, context=None, timeout=None): + def Prepare(self, + request, # type: beam_job_api_pb2.PrepareJobRequest + context=None, + timeout=None + ): + # type: (...) -> beam_job_api_pb2.PrepareJobResponse logging.debug('Got Prepare request.') preparation_id = '%s-%s' % (request.job_name, uuid.uuid4()) self._jobs[preparation_id] = self.create_beam_job( @@ -59,31 +80,56 @@ def Prepare(self, request, context=None, timeout=None): preparation_id].artifact_staging_endpoint(), staging_session_token=preparation_id) - def Run(self, request, context=None, timeout=None): + def Run(self, + request, # type: beam_job_api_pb2.RunJobRequest + context=None, + timeout=None + ): + # type: (...) -> beam_job_api_pb2.RunJobResponse # For now, just use the preparation id as the job id. 
job_id = request.preparation_id logging.info("Running job '%s'", job_id) self._jobs[job_id].run() return beam_job_api_pb2.RunJobResponse(job_id=job_id) - def GetJobs(self, request, context=None, timeout=None): + def GetJobs(self, + request, # type: beam_job_api_pb2.GetJobsRequest + context=None, + timeout=None + ): + # type: (...) -> beam_job_api_pb2.GetJobsResponse return beam_job_api_pb2.GetJobsResponse( [job.to_runner_api() for job in self._jobs.values()]) - def GetState(self, request, context=None): + def GetState(self, + request, # type: beam_job_api_pb2.GetJobStateRequest + context=None + ): + # type: (...) -> beam_job_api_pb2.GetJobStateResponse return beam_job_api_pb2.GetJobStateResponse( state=self._jobs[request.job_id].get_state()) - def GetPipeline(self, request, context=None, timeout=None): + def GetPipeline(self, + request, # type: beam_job_api_pb2.GetJobPipelineRequest + context=None, + timeout=None + ): + # type: (...) -> beam_job_api_pb2.GetJobPipelineResponse return beam_job_api_pb2.GetJobPipelineResponse( pipeline=self._jobs[request.job_id].get_pipeline()) - def Cancel(self, request, context=None, timeout=None): + def Cancel(self, + request, # type: beam_job_api_pb2.CancelJobRequest + context=None, + timeout=None + ): + # type: (...) -> beam_job_api_pb2.CancelJobResponse self._jobs[request.job_id].cancel() return beam_job_api_pb2.CancelJobRequest( state=self._jobs[request.job_id].get_state()) def GetStateStream(self, request, context=None, timeout=None): + # type: (...) -> Iterator[beam_job_api_pb2.GetJobStateResponse] """Yields state transitions since the stream started. """ if request.job_id not in self._jobs: @@ -94,6 +140,7 @@ def GetStateStream(self, request, context=None, timeout=None): yield beam_job_api_pb2.GetJobStateResponse(state=state) def GetMessageStream(self, request, context=None, timeout=None): + # type: (...) -> Iterator[beam_job_api_pb2.JobMessagesResponse] """Yields messages since the stream started. """ if request.job_id not in self._jobs: @@ -109,29 +156,59 @@ def GetMessageStream(self, request, context=None, timeout=None): yield resp def DescribePipelineOptions(self, request, context=None, timeout=None): + # type: (...) 
-> beam_job_api_pb2.DescribePipelineOptionsResponse return beam_job_api_pb2.DescribePipelineOptionsResponse() class AbstractBeamJob(object): """Abstract baseclass for managing a single Beam job.""" - def __init__(self, job_id, job_name, pipeline, options): + def __init__(self, + job_id, # type: str + job_name, # type: str + pipeline, # type: beam_runner_api_pb2.Pipeline + options # type: struct_pb2.Struct + ): self._job_id = job_id self._job_name = job_name self._pipeline_proto = pipeline self._pipeline_options = options - def _to_implement(self): + def prepare(self): + # type: () -> None + """Called immediately after this class is instantiated""" + raise NotImplementedError(self) + + def run(self): + # type: () -> None raise NotImplementedError(self) - prepare = run = cancel = _to_implement - artifact_staging_endpoint = _to_implement - get_state = get_state_stream = get_message_stream = _to_implement + def cancel(self): + # type: () -> Optional[beam_job_api_pb2.JobState.Enum] + raise NotImplementedError(self) + + def artifact_staging_endpoint(self): + # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] + raise NotImplementedError(self) + + def get_state(self): + # type: () -> Optional[beam_job_api_pb2.JobState.Enum] + raise NotImplementedError(self) + + def get_state_stream(self): + # type: () -> Iterator[Optional[beam_job_api_pb2.JobState.Enum]] + raise NotImplementedError(self) + + def get_message_stream(self): + # type: () -> Iterator[Union[int, Optional[beam_job_api_pb2.JobMessage]]] + raise NotImplementedError(self) def get_pipeline(self): + # type: () -> beam_runner_api_pb2.Pipeline return self._pipeline_proto def to_runner_api(self): + # type: () -> beam_job_api_pb2.JobInfo return beam_job_api_pb2.JobInfo( job_id=self._job_id, job_name=self._job_name, diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py b/sdks/python/apache_beam/runners/portability/artifact_service.py index 100eca5788e6..98df1404c8b5 100644 --- a/sdks/python/apache_beam/runners/portability/artifact_service.py +++ b/sdks/python/apache_beam/runners/portability/artifact_service.py @@ -26,6 +26,7 @@ import sys import threading import zipfile +from typing import Iterator from google.protobuf import json_format @@ -48,39 +49,49 @@ def _sha256(self, string): return hashlib.sha256(string.encode('utf-8')).hexdigest() def _join(self, *args): + # type: (*str) -> str raise NotImplementedError(type(self)) def _dirname(self, path): + # type: (str) -> str raise NotImplementedError(type(self)) def _temp_path(self, path): + # type: (str) -> str return path + '.tmp' def _open(self, path, mode): raise NotImplementedError(type(self)) def _rename(self, src, dest): + # type: (str, str) -> None raise NotImplementedError(type(self)) def _delete(self, path): + # type: (str) -> None raise NotImplementedError(type(self)) def _artifact_path(self, retrieval_token, name): + # type: (str, str) -> str return self._join(self._dirname(retrieval_token), self._sha256(name)) def _manifest_path(self, retrieval_token): + # type: (str) -> str return retrieval_token def _get_manifest_proxy(self, retrieval_token): + # type: (str) -> beam_artifact_api_pb2.ProxyManifest with self._open(self._manifest_path(retrieval_token), 'r') as fin: return json_format.Parse( fin.read().decode('utf-8'), beam_artifact_api_pb2.ProxyManifest()) def retrieval_token(self, staging_session_token): + # type: (str) -> str return self._join( self._root, self._sha256(staging_session_token), 'MANIFEST') def PutArtifact(self, request_iterator, 
context=None): + # type: (...) -> beam_artifact_api_pb2.PutArtifactResponse first = True for request in request_iterator: if first: @@ -104,7 +115,10 @@ def PutArtifact(self, request_iterator, context=None): self._rename(temp_path, artifact_path) return beam_artifact_api_pb2.PutArtifactResponse() - def CommitManifest(self, request, context=None): + def CommitManifest(self, + request, # type: beam_artifact_api_pb2.CommitManifestRequest + context=None): + # type: (...) -> beam_artifact_api_pb2.CommitManifestResponse retrieval_token = self.retrieval_token(request.staging_session_token) proxy_manifest = beam_artifact_api_pb2.ProxyManifest( manifest=request.manifest, @@ -118,11 +132,17 @@ def CommitManifest(self, request, context=None): return beam_artifact_api_pb2.CommitManifestResponse( retrieval_token=retrieval_token) - def GetManifest(self, request, context=None): + def GetManifest(self, + request, # type: beam_artifact_api_pb2.GetManifestRequest + context=None): + # type: (...) -> beam_artifact_api_pb2.GetManifestResponse return beam_artifact_api_pb2.GetManifestResponse( manifest=self._get_manifest_proxy(request.retrieval_token).manifest) - def GetArtifact(self, request, context=None): + def GetArtifact(self, + request, # type: beam_artifact_api_pb2.GetArtifactRequest + context=None): + # type: (...) -> Iterator[beam_artifact_api_pb2.ArtifactChunk] for artifact in self._get_manifest_proxy(request.retrieval_token).location: if artifact.name == request.name: with self._open(artifact.uri, 'r') as fin: @@ -156,18 +176,23 @@ def __init__(self, path, chunk_size=None): self._lock = threading.Lock() def _join(self, *args): + # type: (*str) -> str return '/'.join(args) def _dirname(self, path): + # type: (str) -> str return path.rsplit('/', 1)[0] def _temp_path(self, path): + # type: (str) -> str return path # ZipFile offers no move operation. 
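
For illustration, a rough sketch (not part of the patch; artifact_path and the example paths are hypothetical) of the hashed layout that the artifact-service path helpers above annotate: artifact contents are stored next to the MANIFEST file under a name derived from the sha256 of the artifact name.

import hashlib

def artifact_path(retrieval_token, name):
  # type: (str, str) -> str
  """Returns a sibling of the MANIFEST path, keyed by the hash of `name`."""
  digest = hashlib.sha256(name.encode('utf-8')).hexdigest()
  manifest_dir = retrieval_token.rsplit('/', 1)[0]
  return '/'.join([manifest_dir, digest])

# Usage (with a made-up retrieval token):
#   artifact_path('root/1a2b3c/MANIFEST', 'workflow.tar.gz')
#   -> 'root/1a2b3c/<sha256 hex digest of "workflow.tar.gz">'
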
def _rename(self, src, dest): + # type: (str, str) -> None assert src == dest def _delete(self, path): + # type: (str) -> None # ZipFile offers no delete operation: https://bugs.python.org/issue6818 pass @@ -205,15 +230,19 @@ def close(self): class BeamFilesystemArtifactService(AbstractArtifactService): def _join(self, *args): + # type: (*str) -> str return filesystems.FileSystems.join(*args) def _dirname(self, path): + # type: (str) -> str return filesystems.FileSystems.split(path)[0] def _rename(self, src, dest): + # type: (str, str) -> None filesystems.FileSystems.rename([src], [dest]) def _delete(self, path): + # type: (str) -> None filesystems.FileSystems.delete([path]) def _open(self, path, mode='r'): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index 05690a3b9f69..ef059a788649 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -34,6 +34,21 @@ import uuid from builtins import object from concurrent import futures +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import grpc @@ -77,6 +92,26 @@ from apache_beam.utils import profiler from apache_beam.utils import proto_utils +if TYPE_CHECKING: + from google.protobuf import message # pylint: disable=ungrouped-imports + from apache_beam.pipeline import Pipeline + from apache_beam.coders.coder_impl import CoderImpl + from apache_beam.coders.coder_impl import WindowedValueCoderImpl + from apache_beam.portability.api import metrics_pb2 + from apache_beam.transforms.window import BoundedWindow + +T = TypeVar('T') +ConstructorFn = Callable[ + [Union['message.Message', bytes], + 'FnApiRunner.StateServicer', + Optional['ExtendedProvisionInfo'], + 'GrpcServer'], + 'WorkerHandler'] +DataSideInput = Dict[Tuple[str, str], + Tuple[bytes, beam_runner_api_pb2.FunctionSpec]] +DataOutput = Dict[str, bytes] +BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse, List[beam_fn_api_pb2.ProcessBundleSplitResponse]] + # This module is experimental. No backwards-compatibility guarantees. ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder( @@ -96,9 +131,9 @@ class ControlConnection(object): _lock = threading.Lock() def __init__(self): - self._push_queue = queue.Queue() - self._input = None - self._futures_by_id = dict() + self._push_queue = queue.Queue() # type: queue.Queue[beam_fn_api_pb2.InstructionRequest] + self._input = None # type: Optional[Iterable[beam_fn_api_pb2.InstructionResponse]] + self._futures_by_id = dict() # type: Dict[str, ControlFuture] self._read_thread = threading.Thread( name='beam_control_read', target=self._read) self._state = BeamFnControlServicer.UNSTARTED_STATE @@ -108,6 +143,7 @@ def _read(self): self._futures_by_id.pop(data.instruction_id).set(data) def push(self, req): + # type: (...) 
-> Optional[ControlFuture] if req == BeamFnControlServicer._DONE_MARKER: self._push_queue.put(req) return None @@ -121,9 +157,11 @@ def push(self, req): return future def get_req(self): + # type: () -> beam_fn_api_pb2.InstructionRequest return self._push_queue.get() def set_input(self, input): + # type: (Iterable[beam_fn_api_pb2.InstructionResponse]) -> None with ControlConnection._lock: if self._input: raise RuntimeError('input is already set.') @@ -132,6 +170,7 @@ def set_input(self, input): self._state = BeamFnControlServicer.STARTED_STATE def close(self): + # type: () -> None with ControlConnection._lock: if self._state == BeamFnControlServicer.STARTED_STATE: self.push(BeamFnControlServicer._DONE_MARKER) @@ -157,13 +196,18 @@ def __init__(self): self._req_sent = collections.defaultdict(int) self._req_worker_mapping = {} self._log_req = logging.getLogger().getEffectiveLevel() <= logging.DEBUG - self._connections_by_worker_id = collections.defaultdict(ControlConnection) + self._connections_by_worker_id = collections.defaultdict(ControlConnection) # type: DefaultDict[str, ControlConnection] def get_conn_by_worker_id(self, worker_id): + # type: (str) -> ControlConnection with self._lock: return self._connections_by_worker_id[worker_id] - def Control(self, iterator, context): + def Control(self, + iterator, # type: Iterable[beam_fn_api_pb2.InstructionResponse] + context + ): + # type: (...) -> Iterator[beam_fn_api_pb2.InstructionRequest] with self._lock: if self._state == self.DONE_STATE: return @@ -198,20 +242,27 @@ def done(self): class _ListBuffer(list): """Used to support partitioning of a list.""" def partition(self, n): + # type: (int) -> List[List[bytes]] return [self[k::n] for k in range(n)] class _GroupingBuffer(object): """Used to accumulate grouped (shuffled) results.""" - def __init__(self, pre_grouped_coder, post_grouped_coder, windowing): + def __init__(self, + pre_grouped_coder, # type: coders.Coder + post_grouped_coder, # type: coders.Coder + windowing + ): + # type: (...) -> None self._key_coder = pre_grouped_coder.key_coder() self._pre_grouped_coder = pre_grouped_coder self._post_grouped_coder = post_grouped_coder - self._table = collections.defaultdict(list) + self._table = collections.defaultdict(list) # type: Optional[DefaultDict[bytes, List[Any]]] self._windowing = windowing - self._grouped_output = None + self._grouped_output = None # type: Optional[List[List[bytes]]] def append(self, elements_data): + # type: (bytes) -> None if self._grouped_output: raise RuntimeError('Grouping table append after read.') input_stream = create_InputStream(elements_data) @@ -228,6 +279,7 @@ def append(self, elements_data): else windowed_key_value.with_value(value)) def partition(self, n): + # type: (int) -> List[List[bytes]] """ It is used to partition _GroupingBuffer to N parts. Once it is partitioned, it would not be re-partitioned with diff N. Re-partition is not supported now. @@ -260,6 +312,7 @@ def partition(self, n): return self._grouped_output def __iter__(self): + # type: () -> Iterator[bytes] """ Since partition() returns a list of lists, add this __iter__ to return a list to simplify code when we need to iterate through ALL elements of _GroupingBuffer. @@ -269,12 +322,16 @@ def __iter__(self): class _WindowGroupingBuffer(object): """Used to partition windowed side inputs.""" - def __init__(self, access_pattern, coder): + def __init__(self, + access_pattern, + coder # type: coders.WindowedValueCoder + ): + # type: (...)
-> None # Here's where we would use a different type of partitioning # (e.g. also by key) for a different access pattern. if access_pattern.urn == common_urns.side_inputs.ITERABLE.urn: self._kv_extrator = lambda value: ('', value) - self._key_coder = coders.SingletonCoder('') + self._key_coder = coders.SingletonCoder('') # type: coders.Coder self._value_coder = coder.wrapped_value_coder elif access_pattern.urn == common_urns.side_inputs.MULTIMAP.urn: self._kv_extrator = lambda value: value @@ -286,18 +343,21 @@ def __init__(self, access_pattern, coder): "Unknown access pattern: '%s'" % access_pattern.urn) self._windowed_value_coder = coder self._window_coder = coder.window_coder - self._values_by_window = collections.defaultdict(list) + self._values_by_window = collections.defaultdict(list) # type: DefaultDict[Tuple[str, BoundedWindow], List[Any]] def append(self, elements_data): + # type: (bytes) -> None input_stream = create_InputStream(elements_data) while input_stream.size() > 0: - windowed_value = self._windowed_value_coder.get_impl( - ).decode_from_stream(input_stream, True) + windowed_val_coder_impl = self._windowed_value_coder.get_impl() # type: WindowedValueCoderImpl + windowed_value = windowed_val_coder_impl.decode_from_stream( + input_stream, True) key, value = self._kv_extrator(windowed_value.value) for window in windowed_value.windows: self._values_by_window[key, window].append(value) def encoded_items(self): + # type: () -> Iterator[Tuple[bytes, bytes, bytes]] value_coder_impl = self._value_coder.get_impl() key_coder_impl = self._key_coder.get_impl() for (key, window), values in self._values_by_window.items(): @@ -313,11 +373,12 @@ class FnApiRunner(runner.PipelineRunner): def __init__( self, - default_environment=None, + default_environment=None, # type: Optional[beam_runner_api_pb2.Environment] bundle_repeat=0, use_state_iterables=False, - provision_info=None, + provision_info=None, # type: Optional[ExtendedProvisionInfo] progress_request_frequency=None): + # type: (...) -> None """Creates a new Fn API Runner. Args: @@ -338,7 +399,7 @@ def __init__( self._bundle_repeat = bundle_repeat self._num_workers = 1 self._progress_frequency = progress_request_frequency - self._profiler_factory = None + self._profiler_factory = None # type: Optional[Callable[..., profiler.Profile]] self._use_state_iterables = use_state_iterables self._provision_info = provision_info or ExtendedProvisionInfo( beam_provision_api_pb2.ProvisionInfo( @@ -350,7 +411,11 @@ def _next_uid(self): self._last_uid += 1 return str(self._last_uid) - def run_pipeline(self, pipeline, options): + def run_pipeline(self, + pipeline, # type: Pipeline + options # type: pipeline_options.PipelineOptions + ): + # type: (...) -> RunnerResult MetricsEnvironment.set_metrics_supported(False) RuntimeValueProvider.set_runtime_options({}) @@ -382,6 +447,7 @@ def run_pipeline(self, pipeline, options): return self._latest_run_result def run_via_runner_api(self, pipeline_proto): + # type: (beam_runner_api_pb2.Pipeline) -> RunnerResult stage_context, stages = self.create_stages(pipeline_proto) # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access to # the teststream (if any), and all the stages). @@ -426,7 +492,10 @@ def maybe_profile(self): # Empty context. yield - def create_stages(self, pipeline_proto): + def create_stages(self, + pipeline_proto # type: beam_runner_api_pb2.Pipeline + ): + # type: (...) 
-> Tuple[fn_api_runner_transforms.TransformContext, List[fn_api_runner_transforms.Stage]] return fn_api_runner_transforms.create_and_optimize_stages( copy.deepcopy(pipeline_proto), phases=[fn_api_runner_transforms.annotate_downstream_side_inputs, @@ -446,7 +515,11 @@ def create_stages(self, pipeline_proto): common_urns.primitives.GROUP_BY_KEY.urn]), use_state_iterables=self._use_state_iterables) - def run_stages(self, stage_context, stages): + def run_stages(self, + stage_context, # type: fn_api_runner_transforms.TransformContext + stages # type: List[fn_api_runner_transforms.Stage] + ): + # type: (...) -> RunnerResult """Run a list of topologically-sorted stages in batch mode. Args: @@ -460,7 +533,7 @@ def run_stages(self, stage_context, stages): try: with self.maybe_profile(): - pcoll_buffers = collections.defaultdict(_ListBuffer) + pcoll_buffers = collections.defaultdict(_ListBuffer) # type: DefaultDict[bytes, _ListBuffer] for stage in stages: stage_results = self._run_stage( worker_handler_manager.get_worker_handlers, @@ -477,11 +550,11 @@ def run_stages(self, stage_context, stages): runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage) def _store_side_inputs_in_state(self, - worker_handler, - context, - pipeline_components, - data_side_input, - pcoll_buffers, + worker_handler, # type: WorkerHandler + context, # type: pipeline_context.PipelineContext + pipeline_components, # type: beam_runner_api_pb2.Components + data_side_input, # type: DataSideInput + pcoll_buffers, # type: Mapping[bytes, _ListBuffer] safe_coders): for (transform_id, tag), (buffer_id, si) in data_side_input.items(): _, pcoll_id = split_buffer_id(buffer_id) @@ -500,9 +573,18 @@ def _store_side_inputs_in_state(self, worker_handler.state.append_raw(state_key, elements_data) def _run_bundle_multiple_times_for_testing( - self, worker_handler_list, process_bundle_descriptor, data_input, - data_output, get_input_coder_callable, cache_token_generator): - + self, + worker_handler_list, # type: Sequence[WorkerHandler] + process_bundle_descriptor, + data_input, + data_output, # type: DataOutput + get_input_coder_callable, + cache_token_generator + ): + # type: (...) -> None + """ + If bundle_repeat > 0, replay every bundle for profiling and debugging. + """ # all workers share state, so use any worker_handler. 
worker_handler = worker_handler_list[0] for k in range(self._bundle_repeat): @@ -519,12 +601,14 @@ def _run_bundle_multiple_times_for_testing( finally: worker_handler.state.restore() - def _collect_written_timers_and_add_to_deferred_inputs(self, - context, - pipeline_components, - stage, - get_buffer_callable, - deferred_inputs): + def _collect_written_timers_and_add_to_deferred_inputs( + self, + context, # type: pipeline_context.PipelineContext + pipeline_components, # type: beam_runner_api_pb2.Components + stage, # type: fn_api_runner_transforms.Stage + get_buffer_callable, + deferred_inputs # type: DefaultDict[str, _ListBuffer] + ): for transform_id, timer_writes in stage.timer_pcollections: @@ -554,9 +638,15 @@ def _collect_written_timers_and_add_to_deferred_inputs(self, written_timers[:] = [] def _add_residuals_and_channel_splits_to_deferred_inputs( - self, splits, get_input_coder_callable, - input_for_callable, last_sent, deferred_inputs): - prev_stops = {} + self, + splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] + get_input_coder_callable, + input_for_callable, + last_sent, + deferred_inputs # type: DefaultDict[str, _ListBuffer] + ): + + prev_stops = {} # type: Dict[str, int] for split in splits: for delayed_application in split.residual_roots: deferred_inputs[ @@ -592,12 +682,16 @@ def _add_residuals_and_channel_splits_to_deferred_inputs( @staticmethod def _extract_stage_data_endpoints( - stage, pipeline_components, data_api_service_descriptor, pcoll_buffers): + stage, # type: fn_api_runner_transforms.Stage + pipeline_components, # type: beam_runner_api_pb2.Components + data_api_service_descriptor, + pcoll_buffers # type: DefaultDict[bytes, _ListBuffer] + ): # Returns maps of transform names to PCollection identifiers. # Also mutates IO stages to point to the data ApiServiceDescriptor. data_input = {} - data_side_input = {} - data_output = {} + data_side_input = {} # type: DataSideInput + data_output = {} # type: Dict[Tuple[str, str], bytes] for transform in stage.transforms: if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, bundle_processor.DATA_OUTPUT_URN): @@ -631,16 +725,18 @@ def _extract_stage_data_endpoints( return data_input, data_side_input, data_output def _run_stage(self, - worker_handler_factory, - pipeline_components, - stage, - pcoll_buffers, - safe_coders): + worker_handler_factory, # type: Callable[[Optional[str], int], List[WorkerHandler]] + pipeline_components, # type: beam_runner_api_pb2.Components + stage, # type: fn_api_runner_transforms.Stage + pcoll_buffers, # type: DefaultDict[bytes, _ListBuffer] + safe_coders + ): + # type: (...) -> beam_fn_api_pb2.InstructionResponse """Run an individual stage. Args: - worker_handler_factory: A ``callable`` that takes in an environment, and - returns a ``WorkerHandler`` class. + worker_handler_factory: A ``callable`` that takes in an environment id + and a number of workers, and returns a list of ``WorkerHandler``s. pipeline_components (beam_runner_api_pb2.Components): TODO stage (fn_api_runner_transforms.Stage) pcoll_buffers (collections.defaultdict of str: list): Mapping of @@ -649,6 +745,7 @@ def _run_stage(self, safe_coders (dict): TODO """ def iterable_state_write(values, element_coder_impl): + # type: (...) 
-> bytes token = unique_name(None, 'iter').encode('ascii') out = create_OutputStream() for element in values: @@ -683,9 +780,10 @@ def iterable_state_write(values, element_coder_impl): pipeline_components.windowing_strategies.items()), environments=dict(pipeline_components.environments.items())) - if worker_handler.state_api_service_descriptor(): + state_api_service_descriptor = worker_handler.state_api_service_descriptor() + if state_api_service_descriptor: process_bundle_descriptor.state_api_service_descriptor.url = ( - worker_handler.state_api_service_descriptor().url) + state_api_service_descriptor.url) # Store the required side inputs into state so it is accessible for the # worker when it runs this bundle. @@ -753,6 +851,7 @@ def get_input_coder_impl(transform_id): result, splits = bundle_manager.process_bundle(data_input, data_output) def input_for(transform_id, input_id): + # type: (str, str) -> str input_pcoll = process_bundle_descriptor.transforms[ transform_id].inputs[input_id] for read_id, proto in process_bundle_descriptor.transforms.items(): @@ -766,7 +865,7 @@ def input_for(transform_id, input_id): last_sent = data_input while True: - deferred_inputs = collections.defaultdict(_ListBuffer) + deferred_inputs = collections.defaultdict(_ListBuffer) # type: DefaultDict[str, _ListBuffer] self._collect_written_timers_and_add_to_deferred_inputs( context, pipeline_components, stage, get_buffer, deferred_inputs) @@ -810,10 +909,12 @@ def input_for(transform_id, input_id): return result @staticmethod - def _extract_endpoints(stage, - pipeline_components, - data_api_service_descriptor, - pcoll_buffers): + def _extract_endpoints(stage, # type: fn_api_runner_transforms.Stage + pipeline_components, # type: beam_runner_api_pb2.Components + data_api_service_descriptor, # type: Optional[endpoints_pb2.ApiServiceDescriptor] + pcoll_buffers # type: DefaultDict[bytes, _ListBuffer] + ): + # type: (...) -> Tuple[Dict[str, _ListBuffer], Dict[Tuple[str, str], Tuple[bytes, beam_runner_api_pb2.FunctionSpec]], Dict[str, bytes]] """Returns maps of transform names to PCollection identifiers. Also mutates IO stages to point to the data ApiServiceDescriptor. @@ -832,9 +933,9 @@ def _extract_endpoints(stage, PCollection buffer; `data_output` is a dictionary mapping (transform_name, output_name) to a PCollection ID. 
""" - data_input = {} - data_side_input = {} - data_output = {} + data_input = {} # type: Dict[str, _ListBuffer] + data_side_input = {} # type: DataSideInput + data_output = {} # type: DataOutput for transform in stage.transforms: if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, bundle_processor.DATA_OUTPUT_URN): @@ -883,6 +984,7 @@ def __getitem__(self, key): self._underlying, self._overlay, key) def __delitem__(self, key): + # type: (bytes) -> None self._overlay[key] = [] def commit(self): @@ -896,19 +998,21 @@ def __init__(self, underlying, overlay, key): self._key = key def __iter__(self): + # type: () -> Iterator[bytes] if self._key in self._overlay: return iter(self._overlay[self._key]) else: return iter(self._underlying[self._key]) def append(self, item): + # type: (bytes) -> None if self._key not in self._overlay: self._overlay[self._key] = list(self._underlying[self._key]) self._overlay[self._key].append(item) def __init__(self): self._lock = threading.Lock() - self._state = collections.defaultdict(list) + self._state = collections.defaultdict(list) # type: DefaultDict[bytes, List[bytes]] self._checkpoint = None self._use_continuation_tokens = False self._continuations = {} @@ -931,7 +1035,11 @@ def restore(self): def process_instruction_id(self, unused_instruction_id): yield - def get_raw(self, state_key, continuation_token=None): + def get_raw(self, + state_key, # type: beam_fn_api_pb2.StateKey + continuation_token=None # type: Optional[bytes] + ): + # type: (...) -> Tuple[bytes, Optional[bytes]] with self._lock: full_state = self._state[self._to_key(state_key)] if self._use_continuation_tokens: @@ -952,12 +1060,17 @@ def get_raw(self, state_key, continuation_token=None): assert not continuation_token return b''.join(full_state), None - def append_raw(self, state_key, data): + def append_raw(self, + state_key, # type: beam_fn_api_pb2.StateKey + data # type: bytes + ): + # type: (...) -> _Future with self._lock: self._state[self._to_key(state_key)].append(data) return _Future.done() def clear(self, state_key): + # type: (beam_fn_api_pb2.StateKey) -> _Future with self._lock: try: del self._state[self._to_key(state_key)] @@ -971,13 +1084,19 @@ def clear(self, state_key): @staticmethod def _to_key(state_key): + # type: (beam_fn_api_pb2.StateKey) -> bytes return state_key.SerializeToString() class GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer): def __init__(self, state): + # type: (FnApiRunner.StateServicer) -> None self._state = state - def State(self, request_stream, context=None): + def State(self, + request_stream, # type: Iterable[beam_fn_api_pb2.StateRequest] + context=None + ): + # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse] # Note that this eagerly mutates state, assuming any failures are fatal. # Thus it is safe to ignore instruction_id. for request in request_stream: @@ -1006,13 +1125,16 @@ class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory): """A singleton cache for a StateServicer.""" def __init__(self, state_handler): + # type: (sdk_worker.StateHandler) -> None self._state_handler = state_handler def create_state_handler(self, api_service_descriptor): + # type: (endpoints_pb2.ApiServiceDescriptor) -> sdk_worker.StateHandler """Returns the singleton state handler.""" return self._state_handler def close(self): + # type: (...) -> None """Does nothing.""" pass @@ -1065,12 +1187,21 @@ class WorkerHandler(object): it. 
""" - _registered_environments = {} + _registered_environments = {} # type: Dict[str, Tuple[ConstructorFn, type]] _worker_id_counter = -1 _lock = threading.Lock() - def __init__( - self, control_handler, data_plane_handler, state, provision_info): + # FIXME: add a Protocol for these + control_conn = None # type: ControlConnection + data_conn = None # type: data_plane._GrpcDataChannel + + def __init__(self, + control_handler, + data_plane_handler, + state, # type: FnApiRunner.StateServicer + provision_info # type: Optional[ExtendedProvisionInfo] + ): + # type: (...) -> None """Initialize a WorkerHandler. Args: @@ -1089,32 +1220,48 @@ def __init__( self.worker_id = 'worker_%s' % WorkerHandler._worker_id_counter def close(self): + # type: () -> None self.stop_worker() def start_worker(self): + # type: () -> None raise NotImplementedError def stop_worker(self): + # type: () -> None raise NotImplementedError def data_api_service_descriptor(self): + # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] raise NotImplementedError def state_api_service_descriptor(self): + # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] raise NotImplementedError def logging_api_service_descriptor(self): + # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor] raise NotImplementedError @classmethod - def register_environment(cls, urn, payload_type): + def register_environment(cls, + urn, # type: str + payload_type # type: Optional[Type[T]] + ): + # type: (...) -> Callable[[Callable[[T, FnApiRunner.StateServicer, Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]], Callable[[T, FnApiRunner.StateServicer, Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]] def wrapper(constructor): cls._registered_environments[urn] = constructor, payload_type return constructor return wrapper @classmethod - def create(cls, environment, state, provision_info, grpc_server): + def create(cls, + environment, # type: beam_runner_api_pb2.Environment + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) -> WorkerHandler constructor, payload_type = cls._registered_environments[environment.urn] return constructor( proto_utils.parse_Bytes(environment.payload, payload_type), @@ -1127,11 +1274,16 @@ def create(cls, environment, state, provision_info, grpc_server): class EmbeddedWorkerHandler(WorkerHandler): """An in-memory worker_handler for fn API control, state and data planes.""" - def __init__(self, unused_payload, state, provision_info, - unused_grpc_server=None): + def __init__(self, + unused_payload, # type: None + state, + provision_info, # type: Optional[ExtendedProvisionInfo] + unused_grpc_server=None + ): + # type: (...) 
-> None super(EmbeddedWorkerHandler, self).__init__( self, data_plane.InMemoryDataChannel(), state, provision_info) - self.control_conn = self + self.control_conn = self # type: ignore # need Protocol to describe this self.data_conn = self.data_plane_handler self.worker = sdk_worker.SdkWorker( sdk_worker.BundleProcessorCache( @@ -1151,21 +1303,27 @@ def push(self, request): return ControlFuture(request.instruction_id, response) def start_worker(self): + # type: () -> None pass def stop_worker(self): + # type: () -> None self.worker.stop() def done(self): + # type: () -> None pass def data_api_service_descriptor(self): + # type: () -> None return None def state_api_service_descriptor(self): + # type: () -> None return None def logging_api_service_descriptor(self): + # type: () -> None return None @@ -1193,9 +1351,11 @@ class BasicProvisionService( beam_provision_api_pb2_grpc.ProvisionServiceServicer): def __init__(self, info): + # type: (Optional[beam_provision_api_pb2.ProvisionInfo]) -> None self._info = info def GetProvisionInfo(self, request, context=None): + # type: (...) -> beam_provision_api_pb2.GetProvisionInfoResponse return beam_provision_api_pb2.GetProvisionInfoResponse( info=self._info) @@ -1215,7 +1375,12 @@ class GrpcServer(object): _DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5 - def __init__(self, state, provision_info, max_workers): + def __init__(self, + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + max_workers # type: int + ): + # type: (...) -> None self.state = state self.provision_info = provision_info self.max_workers = max_workers @@ -1256,7 +1421,8 @@ def __init__(self, state, provision_info, max_workers): if self.provision_info.artifact_staging_dir: service = artifact_service.BeamFilesystemArtifactService( - self.provision_info.artifact_staging_dir) + self.provision_info.artifact_staging_dir + ) # type: beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer else: service = EmptyArtifactRetrievalService() beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server( @@ -1302,7 +1468,12 @@ def close(self): class GrpcWorkerHandler(WorkerHandler): """An grpc based worker_handler for fn API control, state and data planes.""" - def __init__(self, state, provision_info, grpc_server): + def __init__(self, + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) 
-> None self._grpc_server = grpc_server super(GrpcWorkerHandler, self).__init__( self._grpc_server.control_handler, self._grpc_server.data_plane_handler, @@ -1317,14 +1488,17 @@ def __init__(self, state, provision_info, grpc_server): self.worker_id) def data_api_service_descriptor(self): + # type: () -> endpoints_pb2.ApiServiceDescriptor return endpoints_pb2.ApiServiceDescriptor( url=self.port_from_worker(self._grpc_server.data_port)) def state_api_service_descriptor(self): + # type: () -> endpoints_pb2.ApiServiceDescriptor return endpoints_pb2.ApiServiceDescriptor( url=self.port_from_worker(self._grpc_server.state_port)) def logging_api_service_descriptor(self): + # type: () -> endpoints_pb2.ApiServiceDescriptor return endpoints_pb2.ApiServiceDescriptor( url=self.port_from_worker(self._grpc_server.logging_port)) @@ -1343,12 +1517,19 @@ def host_from_worker(self): @WorkerHandler.register_environment( common_urns.environments.EXTERNAL.urn, beam_runner_api_pb2.ExternalPayload) class ExternalWorkerHandler(GrpcWorkerHandler): - def __init__(self, external_payload, state, provision_info, grpc_server): + def __init__(self, + external_payload, # type: beam_runner_api_pb2.ExternalPayload + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) -> None super(ExternalWorkerHandler, self).__init__(state, provision_info, grpc_server) self._external_payload = external_payload def start_worker(self): + # type: () -> None stub = beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolStub( GRPCChannelFactory.insecure_channel( self._external_payload.endpoint.url)) @@ -1363,6 +1544,7 @@ def start_worker(self): raise RuntimeError("Error starting worker: %s" % response.error) def stop_worker(self): + # type: () -> None pass def host_from_worker(self): @@ -1372,7 +1554,13 @@ def host_from_worker(self): @WorkerHandler.register_environment(python_urns.EMBEDDED_PYTHON_GRPC, bytes) class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler): - def __init__(self, payload, state, provision_info, grpc_server): + def __init__(self, + payload, # type: bytes + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) -> None super(EmbeddedGrpcWorkerHandler, self).__init__(state, provision_info, grpc_server) if payload: @@ -1384,6 +1572,7 @@ def __init__(self, payload, state, provision_info, grpc_server): self._state_cache_size = STATE_CACHE_SIZE def start_worker(self): + # type: () -> None self.worker = sdk_worker.SdkHarness( self.control_address, worker_count=self._num_threads, state_cache_size=self._state_cache_size, worker_id=self.worker_id) @@ -1393,6 +1582,7 @@ def start_worker(self): self.worker_thread.start() def stop_worker(self): + # type: () -> None self.worker_thread.join() @@ -1403,12 +1593,19 @@ def stop_worker(self): @WorkerHandler.register_environment(python_urns.SUBPROCESS_SDK, bytes) class SubprocessSdkWorkerHandler(GrpcWorkerHandler): - def __init__(self, worker_command_line, state, provision_info, grpc_server): + def __init__(self, + worker_command_line, # type: bytes + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) 
-> None super(SubprocessSdkWorkerHandler, self).__init__(state, provision_info, grpc_server) self._worker_command_line = worker_command_line def start_worker(self): + # type: () -> None from apache_beam.runners.portability import local_job_service self.worker = local_job_service.SubprocessSdkWorker( self._worker_command_line, self.control_address, self.worker_id) @@ -1417,17 +1614,24 @@ def start_worker(self): self.worker_thread.start() def stop_worker(self): + # type: () -> None self.worker_thread.join() @WorkerHandler.register_environment(common_urns.environments.DOCKER.urn, beam_runner_api_pb2.DockerPayload) class DockerSdkWorkerHandler(GrpcWorkerHandler): - def __init__(self, payload, state, provision_info, grpc_server): + def __init__(self, + payload, # type: beam_runner_api_pb2.DockerPayload + state, # type: FnApiRunner.StateServicer + provision_info, # type: Optional[ExtendedProvisionInfo] + grpc_server # type: GrpcServer + ): + # type: (...) -> None super(DockerSdkWorkerHandler, self).__init__(state, provision_info, grpc_server) self._container_image = payload.container_image - self._container_id = None + self._container_id = None # type: Optional[bytes] def host_from_worker(self): if sys.platform == "darwin": @@ -1437,6 +1641,7 @@ def host_from_worker(self): return super(DockerSdkWorkerHandler, self).host_from_worker() def start_worker(self): + # type: () -> None with SUBPROCESS_LOCK: try: subprocess.check_call(['docker', 'pull', self._container_image]) @@ -1478,6 +1683,7 @@ def start_worker(self): time.sleep(1) def stop_worker(self): + # type: () -> None if self._container_id: with SUBPROCESS_LOCK: subprocess.call([ @@ -1487,14 +1693,27 @@ def stop_worker(self): class WorkerHandlerManager(object): - def __init__(self, environments, job_provision_info): + """ + Manages creation of ``WorkerHandler``s. + + Caches ``WorkerHandler``s based on environment id. + """ + def __init__(self, + environments, # type: Mapping[str, beam_runner_api_pb2.Environment] + job_provision_info # type: Optional[ExtendedProvisionInfo] + ): + # type: (...) -> None self._environments = environments self._job_provision_info = job_provision_info - self._cached_handlers = collections.defaultdict(list) + self._cached_handlers = collections.defaultdict(list) # type: DefaultDict[str, List[WorkerHandler]] self._state = FnApiRunner.StateServicer() # rename? - self._grpc_server = None + self._grpc_server = None # type: Optional[GrpcServer] - def get_worker_handlers(self, environment_id, num_workers): + def get_worker_handlers(self, + environment_id, # type: Optional[str] + num_workers # type: int + ): + # type: (...) -> List[WorkerHandler] if environment_id is None: # Any environment will do, pick one arbitrarily. 
environment_id = next(iter(self._environments.keys())) @@ -1545,7 +1764,10 @@ def close_all(self): class ExtendedProvisionInfo(object): - def __init__(self, provision_info=None, artifact_staging_dir=None): + def __init__(self, + provision_info=None, # type: Optional[beam_provision_api_pb2.ProvisionInfo] + artifact_staging_dir=None + ): self.provision_info = ( provision_info or beam_provision_api_pb2.ProvisionInfo()) self.artifact_staging_dir = artifact_staging_dir @@ -1587,10 +1809,15 @@ class BundleManager(object): _uid_counter = 0 _lock = threading.Lock() - def __init__( - self, worker_handler_list, get_buffer, get_input_coder_impl, - bundle_descriptor, progress_frequency=None, skip_registration=False, - cache_token_generator=FnApiRunner.get_cache_token_generator()): + def __init__(self, + worker_handler_list, # type: Sequence[WorkerHandler] + get_buffer, # type: Callable[[bytes], list] + get_input_coder_impl, # type: Callable[[str], CoderImpl] + bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + progress_frequency=None, + skip_registration=False, + cache_token_generator=FnApiRunner.get_cache_token_generator() + ): """Set up a bundle manager. Args: @@ -1607,13 +1834,14 @@ def __init__( self._bundle_descriptor = bundle_descriptor self._registered = skip_registration self._progress_frequency = progress_frequency - self._worker_handler = None + self._worker_handler = None # type: Optional[WorkerHandler] self._cache_token_generator = cache_token_generator def _send_input_to_worker(self, - process_bundle_id, - read_transform_id, - byte_streams): + process_bundle_id, # type: str + read_transform_id, # type: str + byte_streams + ): data_out = self._worker_handler.data_conn.output_stream( process_bundle_id, read_transform_id) for byte_stream in byte_streams: @@ -1621,6 +1849,7 @@ def _send_input_to_worker(self, data_out.close() def _register_bundle_descriptor(self): + # type: () -> Optional[ControlFuture] if self._registered: registration_future = None else: @@ -1649,9 +1878,10 @@ def _select_split_manager(self): def _generate_splits_for_testing(self, split_manager, - inputs, + inputs, # type: Mapping[str, _ListBuffer] process_bundle_id): - split_results = [] + # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse] + split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] read_transform_id, buffer_data = only_element(inputs.items()) byte_stream = b''.join(buffer_data) @@ -1670,6 +1900,8 @@ def _generate_splits_for_testing(self, self._send_input_to_worker( process_bundle_id, read_transform_id, [byte_stream]) + assert self._worker_handler is not None + # Execute the requested splits. while not done: if split_fraction is None: @@ -1686,7 +1918,7 @@ def _generate_splits_for_testing(self, estimated_input_elements=num_elements) })) split_response = self._worker_handler.control_conn.push( - split_request).get() + split_request).get() # type: beam_fn_api_pb2.InstructionResponse for t in (0.05, 0.1, 0.2): waiting = ('Instruction not running', 'not yet scheduled') if any(msg in split_response.error for msg in waiting): @@ -1707,7 +1939,11 @@ def _generate_splits_for_testing(self, break return split_results - def process_bundle(self, inputs, expected_outputs): + def process_bundle(self, + inputs, # type: Mapping[str, _ListBuffer] + expected_outputs # type: DataOutput + ): + # type: (...) -> BundleProcessResult # Unique id for the instruction processing this bundle. 
with BundleManager._lock: BundleManager._uid_counter += 1 @@ -1736,7 +1972,7 @@ def process_bundle(self, inputs, expected_outputs): cache_tokens=[next(self._cache_token_generator)])) result_future = self._worker_handler.control_conn.push(process_bundle_req) - split_results = [] + split_results = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] with ProgressRequester( self._worker_handler, process_bundle_id, self._progress_frequency): @@ -1756,7 +1992,7 @@ def process_bundle(self, inputs, expected_outputs): expected_outputs[output.transform_id]).append(output.data) logging.debug('Wait for the bundle %s to finish.' % process_bundle_id) - result = result_future.get() + result = result_future.get() # type: beam_fn_api_pb2.InstructionResponse if result.error: raise RuntimeError(result.error) @@ -1775,23 +2011,34 @@ def process_bundle(self, inputs, expected_outputs): class ParallelBundleManager(BundleManager): def __init__( - self, worker_handler_list, get_buffer, get_input_coder_impl, - bundle_descriptor, progress_frequency=None, skip_registration=False, - cache_token_generator=None, **kwargs): + self, + worker_handler_list, # type: Sequence[WorkerHandler] + get_buffer, # type: Callable[[bytes], list] + get_input_coder_impl, # type: Callable[[str], CoderImpl] + bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + progress_frequency=None, + skip_registration=False, + cache_token_generator=None, + **kwargs): + # type: (...) -> None super(ParallelBundleManager, self).__init__( worker_handler_list, get_buffer, get_input_coder_impl, bundle_descriptor, progress_frequency, skip_registration, cache_token_generator=cache_token_generator) self._num_workers = kwargs.pop('num_workers', 1) - def process_bundle(self, inputs, expected_outputs): - part_inputs = [{} for _ in range(self._num_workers)] + def process_bundle(self, + inputs, # type: Mapping[str, _ListBuffer] + expected_outputs # type: DataOutput + ): + # type: (...) -> BundleProcessResult + part_inputs = [{} for _ in range(self._num_workers)] # type: List[Dict[str, List[bytes]]] for name, input in inputs.items(): for ix, part in enumerate(input.partition(self._num_workers)): part_inputs[ix][name] = part - merged_result = None - split_result_list = [] + merged_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] + split_result_list = [] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] with futures.ThreadPoolExecutor(max_workers=self._num_workers) as executor: for result, split_result in executor.map(lambda part: BundleManager( self._worker_handler_list, self._get_buffer, @@ -1811,6 +2058,7 @@ def process_bundle(self, inputs, expected_outputs): result.process_bundle.monitoring_infos, merged_result.process_bundle.monitoring_infos))), error=result.error or merged_result.error) + assert merged_result is not None return merged_result, split_result_list @@ -1821,7 +2069,13 @@ class ProgressRequester(threading.Thread): A callback can be passed to call with progress updates. """ - def __init__(self, worker_handler, instruction_id, frequency, callback=None): + def __init__(self, + worker_handler, # type: WorkerHandler + instruction_id, + frequency, + callback=None + ): + # type: (...) 
-> None super(ProgressRequester, self).__init__() self._worker_handler = worker_handler self._instruction_id = instruction_id @@ -1938,6 +2192,7 @@ def query(self, filter=None): self.GAUGES: gauges} def monitoring_infos(self): + # type: () -> List[metrics_pb2.MonitoringInfo] return [item for sublist in self._monitoring_infos.values() for item in sublist] diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 25cb618aa60c..04660bf6a5dd 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -31,6 +31,7 @@ import unittest import uuid from builtins import range +from typing import Dict # patches unittest.TestCase to be python3 compatible import future.tests.base # pylint: disable=unused-import @@ -1522,7 +1523,7 @@ def __reduce__(self): return _unpickle_element_counter, (name,) -_pickled_element_counters = {} +_pickled_element_counters = {} # type: Dict[str, ElementCounter] def _unpickle_element_counter(name): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py index 31ad86557b33..f647b0f89e96 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py @@ -24,6 +24,17 @@ import functools import logging from builtins import object +from typing import Container +from typing import DefaultDict +from typing import Dict +from typing import FrozenSet +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TypeVar from past.builtins import unicode @@ -34,6 +45,8 @@ from apache_beam.runners.worker import bundle_processor from apache_beam.utils import proto_utils +T = TypeVar('T') + # This module is experimental. No backwards-compatibility guarantees. 
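(Aside, not part of the patch: a minimal sketch of the two PEP 484 comment-annotation forms this change uses throughout, shown on hypothetical helpers rather than actual Beam code.)

from typing import List


def join_names(prefix,  # type: str
               names    # type: List[str]
              ):
  # type: (...) -> str
  # Per-argument comments plus a '(...)' return comment; mypy reads the
  # combined signature as (str, List[str]) -> str.
  return prefix + '/' + '+'.join(names)


def count_names(names):
  # type: (List[str]) -> int
  """Single-comment form: the whole signature in one type comment."""
  return len(names)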
@@ -61,14 +74,20 @@ class Stage(object): """A set of Transforms that can be sent to the worker for processing.""" - def __init__(self, name, transforms, - downstream_side_inputs=None, must_follow=frozenset(), - parent=None, environment=None, forced_root=False): + def __init__(self, + name, # type: str + transforms, # type: List[beam_runner_api_pb2.PTransform] + downstream_side_inputs=None, # type: Optional[FrozenSet[str]] + must_follow=frozenset(), # type: FrozenSet[Stage] + parent=None, # type: Optional[Stage] + environment=None, # type: Optional[str] + forced_root=False + ): self.name = name self.transforms = transforms self.downstream_side_inputs = downstream_side_inputs self.must_follow = must_follow - self.timer_pcollections = [] + self.timer_pcollections = [] # type: List[Tuple[str, str]] self.parent = parent if environment is None: environment = functools.reduce( @@ -93,6 +112,7 @@ def __repr__(self): @staticmethod def _extract_environment(transform): + # type: (beam_runner_api_pb2.PTransform) -> Optional[str] if transform.spec.urn in PAR_DO_URNS: pardo_payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) @@ -106,6 +126,7 @@ def _extract_environment(transform): @staticmethod def _merge_environments(env1, env2): + # type: (Optional[str], Optional[str]) -> Optional[str] if env1 is None: return env2 elif env2 is None: @@ -118,6 +139,7 @@ def _merge_environments(env1, env2): return env1 def can_fuse(self, consumer, context): + # type: (Stage, TransformContext) -> bool try: self._merge_environments(self.environment, consumer.environment) except ValueError: @@ -134,6 +156,7 @@ def no_overlap(a, b): and no_overlap(self.downstream_side_inputs, consumer.side_inputs())) def fuse(self, other): + # type: (Stage) -> Stage return Stage( "(%s)+(%s)" % (self.name, other.name), self.transforms + other.transforms, @@ -145,10 +168,12 @@ def fuse(self, other): forced_root=self.forced_root or other.forced_root) def is_runner_urn(self, context): + # type: (TransformContext) -> bool return any(transform.spec.urn in context.known_runner_urns for transform in self.transforms) def side_inputs(self): + # type: () -> Iterator[str] for transform in self.transforms: if transform.spec.urn in PAR_DO_URNS: payload = proto_utils.parse_Bytes( @@ -169,7 +194,8 @@ def has_as_main_input(self, pcoll): return True def deduplicate_read(self): - seen_pcolls = set() + # type: () -> None + seen_pcolls = set() # type: Set[str] new_transforms = [] for transform in self.transforms: if transform.spec.urn == bundle_processor.DATA_INPUT_URN: @@ -180,8 +206,12 @@ def deduplicate_read(self): new_transforms.append(transform) self.transforms = new_transforms - def executable_stage_transform( - self, known_runner_urns, all_consumers, components): + def executable_stage_transform(self, + known_runner_urns, # type: FrozenSet[str] + all_consumers, + components # type: beam_runner_api_pb2.Components + ): + # type: (...) -> beam_runner_api_pb2.PTransform if (len(self.transforms) == 1 and self.transforms[0].spec.urn in known_runner_urns): return self.transforms[0] @@ -208,7 +238,7 @@ def executable_stage_transform( # Only keep the transforms in this stage. # Also gather up payload data as we iterate over the transforms. 
stage_components.transforms.clear() - main_inputs = set() + main_inputs = set() # type: Set[str] side_inputs = [] user_states = [] timers = [] @@ -288,7 +318,11 @@ class TransformContext(object): _KNOWN_CODER_URNS = set( value.urn for value in common_urns.coders.__dict__.values()) - def __init__(self, components, known_runner_urns, use_state_iterables=False): + def __init__(self, + components, # type: beam_runner_api_pb2.Components + known_runner_urns, # type: FrozenSet[str] + use_state_iterables=False + ): self.components = components self.known_runner_urns = known_runner_urns self.use_state_iterables = use_state_iterables @@ -297,7 +331,11 @@ def __init__(self, components, known_runner_urns, use_state_iterables=False): self.bytes_coder_id = self.add_or_get_coder_id(coder_proto, 'bytes_coder') self.safe_coders = {self.bytes_coder_id: self.bytes_coder_id} - def add_or_get_coder_id(self, coder_proto, coder_prefix='coder'): + def add_or_get_coder_id(self, + coder_proto, # type: beam_runner_api_pb2.Coder + coder_prefix='coder' + ): + # type: (...) -> str for coder_id, coder in self.components.coders.items(): if coder == coder_proto: return coder_id @@ -307,6 +345,7 @@ def add_or_get_coder_id(self, coder_proto, coder_prefix='coder'): @memoize_on_instance def with_state_iterables(self, coder_id): + # type: (str) -> str coder = self.components.coders[coder_id] if coder.spec.urn == common_urns.coders.ITERABLE.urn: new_coder_id = unique_name( @@ -334,6 +373,7 @@ def with_state_iterables(self, coder_id): @memoize_on_instance def length_prefixed_coder(self, coder_id): + # type: (str) -> str if coder_id in self.safe_coders: return coder_id length_prefixed_id, safe_id = self.length_prefixed_and_safe_coder(coder_id) @@ -342,6 +382,7 @@ def length_prefixed_coder(self, coder_id): @memoize_on_instance def length_prefixed_and_safe_coder(self, coder_id): + # type: (str) -> Tuple[str, str] coder = self.components.coders[coder_id] if coder.spec.urn == common_urns.coders.LENGTH_PREFIX.urn: return coder_id, self.bytes_coder_id @@ -379,6 +420,7 @@ def length_prefixed_and_safe_coder(self, coder_id): return new_coder_id, self.bytes_coder_id def length_prefix_pcoll_coders(self, pcoll_id): + # type: (str) -> None self.components.pcollections[pcoll_id].coder_id = ( self.length_prefixed_coder( self.components.pcollections[pcoll_id].coder_id)) @@ -386,6 +428,7 @@ def length_prefix_pcoll_coders(self, pcoll_id): def leaf_transform_stages( root_ids, components, parent=None, known_composites=KNOWN_COMPOSITES): + # type: (...) -> Iterator[Stage] for root_id in root_ids: root = components.transforms[root_id] if root.spec.urn in known_composites: @@ -400,8 +443,12 @@ def leaf_transform_stages( yield stage -def pipeline_from_stages( - pipeline_proto, stages, known_runner_urns, partial): +def pipeline_from_stages(pipeline_proto, # type: beam_runner_api_pb2.Pipeline + stages, # type: Iterable[Stage] + known_runner_urns, # type: FrozenSet[str] + partial # type: bool + ): + # type: (...) -> beam_runner_api_pb2.Pipeline # In case it was a generator that mutates components as it # produces outputs (as is the case with most transformations). 
@@ -433,7 +480,7 @@ def add_parent(child, parent): add_parent(parent, parents.get(parent)) components.transforms[parent].subtransforms.append(child) - all_consumers = collections.defaultdict(set) + all_consumers = collections.defaultdict(set) # type: DefaultDict[str, Set[int]] for stage in stages: for transform in stage.transforms: for pcoll in transform.inputs.values(): @@ -455,10 +502,12 @@ def add_parent(child, parent): return new_proto -def create_and_optimize_stages(pipeline_proto, +def create_and_optimize_stages(pipeline_proto, # type: beam_runner_api_pb2.Pipeline phases, - known_runner_urns, - use_state_iterables=False): + known_runner_urns, # type: FrozenSet[str] + use_state_iterables=False + ): + # type: (...) -> Tuple[TransformContext, List[Stage]] """Create a set of stages given a pipeline proto, and set of optimizations. Args: @@ -496,9 +545,9 @@ def create_and_optimize_stages(pipeline_proto, def optimize_pipeline( - pipeline_proto, + pipeline_proto, # type: beam_runner_api_pb2.Pipeline phases, - known_runner_urns, + known_runner_urns, # type: FrozenSet[str] partial=False, **kwargs): unused_context, stages = create_and_optimize_stages( @@ -514,6 +563,7 @@ def optimize_pipeline( def annotate_downstream_side_inputs(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterable[Stage] """Annotate each stage with fusion-prohibiting information. Each stage is annotated with the (transitive) set of pcollections that @@ -528,9 +578,11 @@ def annotate_downstream_side_inputs(stages, pipeline_context): This representation is also amenable to simple recomputation on fusion. """ - consumers = collections.defaultdict(list) + consumers = collections.defaultdict(list) # type: DefaultDict[str, List[Stage]] + def get_all_side_inputs(): - all_side_inputs = set() + # type: () -> Set[str] + all_side_inputs = set() # type: Set[str] for stage in stages: for transform in stage.transforms: for input in transform.inputs.values(): @@ -541,11 +593,12 @@ def get_all_side_inputs(): all_side_inputs = frozenset(get_all_side_inputs()) - downstream_side_inputs_by_stage = {} + downstream_side_inputs_by_stage = {} # type: Dict[Stage, FrozenSet[str]] def compute_downstream_side_inputs(stage): + # type: (Stage) -> FrozenSet[str] if stage not in downstream_side_inputs_by_stage: - downstream_side_inputs = frozenset() + downstream_side_inputs = frozenset() # type: FrozenSet[str] for transform in stage.transforms: for output in transform.outputs.values(): if output in all_side_inputs: @@ -564,6 +617,7 @@ def compute_downstream_side_inputs(stage): def annotate_stateful_dofns_as_roots(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterable[Stage] for stage in stages: for transform in stage.transforms: if transform.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -575,6 +629,7 @@ def annotate_stateful_dofns_as_roots(stages, pipeline_context): def fix_side_input_pcoll_coders(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterable[Stage] """Length prefix side input PCollection coders. """ for stage in stages: @@ -584,6 +639,7 @@ def fix_side_input_pcoll_coders(stages, pipeline_context): def lift_combiners(stages, context): + # type: (List[Stage], TransformContext) -> Iterator[Stage] """Expands CombinePerKey into pre- and post-grouping stages. ... -> CombinePerKey -> ... 
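(Aside, not part of the patch: the hunks above also rely on PEP 484 variable type comments, which tell mypy the element types of otherwise-empty containers; a small standalone sketch with hypothetical names.)

import collections
from typing import DefaultDict
from typing import Dict
from typing import List

consumers_by_stage = collections.defaultdict(list)  # type: DefaultDict[str, List[str]]
seen_ids = {}  # type: Dict[str, int]

consumers_by_stage['stage_1'].append('transform_a')
seen_ids['stage_1'] = 1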
@@ -658,6 +714,7 @@ def lift_combiners(stages, context): is_bounded=output_pcoll.is_bounded)) def make_stage(base_stage, transform): + # type: (Stage, beam_runner_api_pb2.PTransform) -> Stage return Stage( transform.unique_name, [transform], @@ -713,6 +770,7 @@ def make_stage(base_stage, transform): def expand_sdf(stages, context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Transforms splittable DoFns into pair+split+read.""" for stage in stages: assert len(stage.transforms) == 1 @@ -753,6 +811,7 @@ def copy_like(protos, original, suffix='_copy', **kwargs): return new_id def make_stage(base_stage, transform_id, extra_must_follow=()): + # type: (Stage, str, Iterable[Stage]) -> Stage transform = context.components.transforms[transform_id] return Stage( transform.unique_name, @@ -782,7 +841,9 @@ def make_stage(base_stage, transform_id, extra_must_follow=()): component_coder_ids=[ paired_coder_id, context.add_or_get_coder_id( - coders.FloatCoder().to_runner_api(None), + # context can be None here only because FloatCoder does + # not have components + coders.FloatCoder().to_runner_api(None), # type: ignore 'doubles_coder') ])) @@ -854,6 +915,7 @@ def make_stage(base_stage, transform_id, extra_must_follow=()): def expand_gbk(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Transforms each GBK into a write followed by a read. """ for stage in stages: @@ -903,6 +965,7 @@ def expand_gbk(stages, pipeline_context): def fix_flatten_coders(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Ensures that the inputs of Flatten have the same coders as the output. """ pcollections = pipeline_context.components.pcollections @@ -943,6 +1006,7 @@ def fix_flatten_coders(stages, pipeline_context): def sink_flattens(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Sink flattens and remove them from the graph. A flatten that cannot be sunk/fused away becomes multiple writes (to the @@ -955,7 +1019,7 @@ def sink_flattens(stages, pipeline_context): if transform.spec.urn == common_urns.primitives.FLATTEN.urn: # This is used later to correlate the read and writes. buffer_id = create_buffer_id(transform.unique_name) - flatten_writes = [] + flatten_writes = [] # type: List[Stage] for local_in, pcoll_in in transform.inputs.items(): flatten_write = Stage( transform.unique_name + '/Write/' + local_in, @@ -1074,6 +1138,7 @@ def fuse(producer, consumer): def read_to_impulse(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Translates Read operations into Impulse operations.""" for stage in stages: # First map Reads, if any, to Impulse + triggered read op.
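(Aside, not part of the patch: the annotations above give every optimization phase the same shape, an (Iterable[Stage], TransformContext) -> Iterator[Stage] generator, so phases can be chained in order; the simplified sketch below uses plain strings as stand-in stages and a placeholder context, purely for illustration.)

from typing import Iterable
from typing import Iterator
from typing import List


def drop_empty_stages(stages, unused_context):
  # type: (Iterable[str], object) -> Iterator[str]
  for stage in stages:
    if stage:
      yield stage


def upper_case_stages(stages, unused_context):
  # type: (Iterable[str], object) -> Iterator[str]
  for stage in stages:
    yield stage.upper()


def apply_phases(stages, phases, context):
  # type: (List[str], list, object) -> List[str]
  # Each phase consumes the stages produced by the previous one.
  for phase in phases:
    stages = list(phase(stages, context))
  return stages

# apply_phases(['a', '', 'b'], [drop_empty_stages, upper_case_stages], None)
# returns ['A', 'B']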
@@ -1111,6 +1176,7 @@ def read_to_impulse(stages, pipeline_context): def impulse_to_input(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Translates Impulse operations into GRPC reads.""" for stage in stages: for transform in list(stage.transforms): @@ -1127,6 +1193,7 @@ def impulse_to_input(stages, pipeline_context): def extract_impulse_stages(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Splits fused Impulse operations into their own stage.""" for stage in stages: for transform in list(stage.transforms): @@ -1144,6 +1211,7 @@ def extract_impulse_stages(stages, pipeline_context): def remove_data_plane_ops(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] for stage in stages: for transform in list(stage.transforms): if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, @@ -1155,6 +1223,7 @@ def remove_data_plane_ops(stages, pipeline_context): def inject_timer_pcollections(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] """Create PCollections for fired timers and to-be-set timers. At execution time, fired timers and timers-to-set are represented as @@ -1228,10 +1297,11 @@ def inject_timer_pcollections(stages, pipeline_context): def sort_stages(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> List[Stage] """Order stages suitable for sequential execution. """ all_stages = set(stages) - seen = set() + seen = set() # type: Set[Stage] ordered = [] def process(stage): @@ -1248,6 +1318,7 @@ def process(stage): def window_pcollection_coders(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterable[Stage] """Wrap all PCollection coders as windowed value coders. This is required as some SDK workers require windowed coders for their @@ -1291,6 +1362,7 @@ def union(a, b): def unique_name(existing, prefix): + # type: (Optional[Container[str]], str) -> str if existing is None: global _global_counter _global_counter += 1 @@ -1307,14 +1379,18 @@ def unique_name(existing, prefix): def only_element(iterable): + # type: (Iterable[T]) -> T element, = iterable return element def create_buffer_id(name, kind='materialize'): + # type: (str, str) -> bytes return ('%s:%s' % (kind, name)).encode('utf-8') def split_buffer_id(buffer_id): + # type: (bytes) -> Tuple[str, str] """A buffer id is "kind:pcollection_id". Split into (kind, pcoll_id). 
""" - return buffer_id.decode('utf-8').split(':', 1) + kind, pcoll_id = buffer_id.decode('utf-8').split(':', 1) + return kind, pcoll_id diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index 7cf8d4321937..0d42a9150ad8 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -51,6 +51,7 @@ def __init__(self, endpoint, timeout=None): self._timeout = timeout def start(self): + # type: () -> beam_job_api_pb2_grpc.JobServiceStub channel = grpc.insecure_channel(self._endpoint) grpc.channel_ready_future(channel).result(timeout=self._timeout) return beam_job_api_pb2_grpc.JobServiceStub(channel) @@ -61,6 +62,7 @@ def stop(self): class EmbeddedJobServer(JobServer): def start(self): + # type: () -> local_job_service.LocalJobServicer return local_job_service.LocalJobServicer() def stop(self): diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 1567bfac6245..d064cbed7047 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -27,6 +27,9 @@ import traceback from builtins import object from concurrent import futures +from typing import TYPE_CHECKING +from typing import List +from typing import Optional import grpc from google.protobuf import text_format # type: ignore # not in typeshed @@ -42,6 +45,10 @@ from apache_beam.runners.portability import artifact_service from apache_beam.runners.portability import fn_api_runner +if TYPE_CHECKING: + from google.protobuf import struct_pb2 # pylint: disable=ungrouped-imports + from apache_beam.portability.api import beam_runner_api_pb2 + class LocalJobServicer(abstract_job_service.AbstractJobServiceServicer): """Manages one or more pipelines, possibly concurrently. @@ -63,9 +70,15 @@ def __init__(self, staging_dir=None): self._staging_dir = staging_dir or tempfile.mkdtemp() self._artifact_service = artifact_service.BeamFilesystemArtifactService( self._staging_dir) - self._artifact_staging_endpoint = None - - def create_beam_job(self, preparation_id, job_name, pipeline, options): + self._artifact_staging_endpoint = None # type: Optional[endpoints_pb2.ApiServiceDescriptor] + + def create_beam_job(self, + preparation_id, # stype: str + job_name, # type: str + pipeline, # type: beam_runner_api_pb2.Pipeline + options # type: struct_pb2.Struct + ): + # type: (...) -> BeamJob # TODO(angoenka): Pass an appropriate staging_session_token. The token can # be obtained in PutArtifactResponse from JobService if not self._artifact_staging_endpoint: @@ -112,7 +125,11 @@ class SubprocessSdkWorker(object): """Manages a SDK worker implemented as a subprocess communicating over grpc. 
""" - def __init__(self, worker_command_line, control_address, worker_id=None): + def __init__(self, + worker_command_line, # type: bytes + control_address, + worker_id=None + ): self._worker_command_line = worker_command_line self._control_address = control_address self._worker_id = worker_id @@ -162,18 +179,19 @@ class BeamJob(abstract_job_service.AbstractBeamJob): """ def __init__(self, - job_id, + job_id, # type: str pipeline, options, - provision_info, - artifact_staging_endpoint): + provision_info, # type: fn_api_runner.ExtendedProvisionInfo + artifact_staging_endpoint # type: Optional[endpoints_pb2.ApiServiceDescriptor] + ): super(BeamJob, self).__init__( job_id, provision_info.provision_info.job_name, pipeline, options) self._provision_info = provision_info self._artifact_staging_endpoint = artifact_staging_endpoint self._state = None - self._state_queues = [] - self._log_queues = [] + self._state_queues = [] # type: List[queue.Queue] + self._log_queues = [] # type: List[queue.Queue] self.state = beam_job_api_pb2.JobState.STARTING self.daemon = True diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index 16c6eba6c731..d41f9238933c 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -24,6 +24,8 @@ import sys import threading import time +from typing import TYPE_CHECKING +from typing import Optional import grpc @@ -45,6 +47,10 @@ from apache_beam.runners.worker import sdk_worker_main from apache_beam.runners.worker import worker_pool_main +if TYPE_CHECKING: + from apache_beam.options.pipeline_options import PipelineOptions + from apache_beam.pipeline import Pipeline + __all__ = ['PortableRunner'] MESSAGE_LOG_LEVELS = { @@ -74,7 +80,7 @@ class PortableRunner(runner.PipelineRunner): running and managing the job lies with the job service used. """ def __init__(self): - self._dockerized_job_server = None + self._dockerized_job_server = None # type: Optional[job_server.JobServer] @staticmethod def default_docker_image(): @@ -94,6 +100,7 @@ def default_docker_image(): @staticmethod def _create_environment(options): + # type: (PipelineOptions) -> beam_runner_api_pb2.Environment portable_options = options.view_as(PortableOptions) environment_urn = common_urns.environments.DOCKER.urn if portable_options.environment_type == 'DOCKER': @@ -156,6 +163,7 @@ def looks_like_json(environment_config): if portable_options.environment_config else None)) def default_job_server(self, portable_options): + # type: (...) 
-> job_server.JobServer # TODO Provide a way to specify a container Docker URL # https://issues.apache.org/jira/browse/BEAM-6328 if not self._dockerized_job_server: @@ -176,6 +184,7 @@ def create_job_service(self, options): return server.start() def run_pipeline(self, pipeline, options): + # type: (Pipeline, PipelineOptions) -> PipelineResult portable_options = options.view_as(PortableOptions) # TODO: https://issues.apache.org/jira/browse/BEAM-5525 @@ -261,6 +270,7 @@ def run_pipeline(self, pipeline, options): # fetch runner options from job service # retries in case the channel is not ready def send_options_request(max_retries=5): + # type: (int) -> beam_job_api_pb2.DescribePipelineOptionsResponse num_retries = 0 while True: try: diff --git a/sdks/python/apache_beam/runners/portability/portable_stager.py b/sdks/python/apache_beam/runners/portability/portable_stager.py index 09ff18fd4565..328025e0ae09 100644 --- a/sdks/python/apache_beam/runners/portability/portable_stager.py +++ b/sdks/python/apache_beam/runners/portability/portable_stager.py @@ -22,6 +22,8 @@ import hashlib import os +from typing import Iterator +from typing import List from apache_beam.portability.api import beam_artifact_api_pb2 from apache_beam.portability.api import beam_artifact_api_pb2_grpc @@ -54,9 +56,10 @@ def __init__(self, artifact_service_channel, staging_session_token): self._artifact_staging_stub = beam_artifact_api_pb2_grpc.\ ArtifactStagingServiceStub(channel=artifact_service_channel) self._staging_session_token = staging_session_token - self._artifacts = [] + self._artifacts = [] # type: List[beam_artifact_api_pb2.ArtifactMetadata] def stage_artifact(self, local_path_to_artifact, artifact_name): + # type: (str, str) -> None """Stage a file to ArtifactStagingService. 
Args: @@ -69,6 +72,7 @@ def stage_artifact(self, local_path_to_artifact, artifact_name): .format(local_path_to_artifact)) def artifact_request_generator(): + # type: () -> Iterator[beam_artifact_api_pb2.PutArtifactRequest] artifact_metadata = beam_artifact_api_pb2.ArtifactMetadata( name=artifact_name, sha256=_get_file_hash(local_path_to_artifact), diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index eb4c92a5a217..774042b34a84 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -52,6 +52,8 @@ import shutil import sys import tempfile +from typing import List +from typing import Optional import pkg_resources @@ -59,6 +61,7 @@ from apache_beam.internal.http_client import get_new_http from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import DebugOptions +from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import WorkerOptions # TODO(angoenka): Remove reference to dataflow internal names @@ -91,6 +94,7 @@ class Stager(object): """ def stage_artifact(self, local_path_to_artifact, artifact_name): + # type: (str, str) -> None """ Stages the artifact to Stager._staging_location and adds artifact_name to the manifest of artifacts that have been staged.""" raise NotImplementedError @@ -106,11 +110,12 @@ def get_sdk_package_name(): return names.BEAM_PACKAGE_NAME def stage_job_resources(self, - options, - build_setup_args=None, - temp_dir=None, - populate_requirements_cache=None, - staging_location=None): + options, # type: PipelineOptions + build_setup_args=None, # type: Optional[List[str]] + temp_dir=None, # type: Optional[str] + populate_requirements_cache=None, # type: Optional[str] + staging_location=None # type: Optional[str] + ): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) and stages job resources to staging_location. @@ -138,7 +143,7 @@ def stage_job_resources(self, while trying to create the resources (e.g., build a setup package). """ temp_dir = temp_dir or tempfile.mkdtemp() - resources = [] + resources = [] # type: List[str] setup_options = options.view_as(SetupOptions) # Make sure that all required options are specified. @@ -319,6 +324,7 @@ def _is_remote_path(path): return path.find('://') != -1 def _stage_jar_packages(self, jar_packages, staging_location, temp_dir): + # type: (...) -> List[str] """Stages a list of local jar packages for Java SDK Harness. :param jar_packages: Ordered list of local paths to jar packages to be @@ -331,9 +337,9 @@ def _stage_jar_packages(self, jar_packages, staging_location, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. """ - resources = [] + resources = [] # type: List[str] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] + local_packages = [] # type: List[str] for package in jar_packages: if not os.path.basename(package).endswith('.jar'): raise RuntimeError( @@ -369,6 +375,7 @@ def _stage_jar_packages(self, jar_packages, staging_location, temp_dir): return resources def _stage_extra_packages(self, extra_packages, staging_location, temp_dir): + # type: (...) -> List[str] """Stages a list of local extra packages. 
Args: @@ -387,9 +394,9 @@ def _stage_extra_packages(self, extra_packages, staging_location, temp_dir): RuntimeError: If files specified are not found or do not have expected name patterns. """ - resources = [] + resources = [] # type: List[str] staging_temp_dir = tempfile.mkdtemp(dir=temp_dir) - local_packages = [] + local_packages = [] # type: List[str] for package in extra_packages: if not (os.path.basename(package).endswith('.tar') or os.path.basename(package).endswith('.tar.gz') or @@ -487,7 +494,11 @@ def _populate_requirements_cache(requirements_file, cache_dir): processes.check_output(cmd_args, stderr=processes.STDOUT) @staticmethod - def _build_setup_package(setup_file, temp_dir, build_setup_args=None): + def _build_setup_package(setup_file, # type: str + temp_dir, # type: str + build_setup_args=None # type: Optional[List[str]] + ): + # type: (...) -> str saved_current_directory = os.getcwd() try: os.chdir(os.path.dirname(setup_file)) @@ -508,6 +519,7 @@ def _build_setup_package(setup_file, temp_dir, build_setup_args=None): @staticmethod def _desired_sdk_filename_in_staging_location(sdk_location): + # type: (...) -> str """Returns the name that SDK file should have in the staging location. Args: sdk_location: Full path to SDK file. @@ -522,6 +534,7 @@ def _desired_sdk_filename_in_staging_location(sdk_location): return DATAFLOW_SDK_TARBALL_FILE def _stage_beam_sdk(self, sdk_remote_location, staging_location, temp_dir): + # type: (...) -> List[str] """Stages a Beam SDK file with the appropriate version. Args: diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py index e7fc5f1a8590..2d9193bdeea5 100644 --- a/sdks/python/apache_beam/runners/portability/stager_test.py +++ b/sdks/python/apache_beam/runners/portability/stager_test.py @@ -24,6 +24,7 @@ import sys import tempfile import unittest +from typing import List import mock @@ -66,7 +67,7 @@ def create_temp_file(self, path, contents): def is_remote_path(self, path): return path.startswith('/tmp/remote/') - remote_copied_files = [] + remote_copied_files = [] # type: List[str] def file_copy(self, from_path, to_path): if self.is_remote_path(from_path): diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index b7a2d304d6b8..7a188c98c1c7 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -26,6 +26,16 @@ import shutil import tempfile from builtins import object +from typing import TYPE_CHECKING +from typing import Optional + +if TYPE_CHECKING: + from apache_beam import pvalue + from apache_beam import PTransform + from apache_beam.options.pipeline_options import PipelineOptions + from apache_beam.pipeline import AppliedPTransform + from apache_beam.pipeline import Pipeline + from apache_beam.pipeline import PipelineVisitor __all__ = ['PipelineRunner', 'PipelineState', 'PipelineResult'] @@ -53,6 +63,7 @@ def create_runner(runner_name): + # type: (str) -> PipelineRunner """For internal use only; no backwards-compatibility guarantees. Creates a runner instance from a runner class name. @@ -105,7 +116,11 @@ class PipelineRunner(object): materialized values in order to reduce footprint. """ - def run(self, transform, options=None): + def run(self, + transform, # type: PTransform + options=None # type: Optional[PipelineOptions] + ): + # type: (...) -> PipelineResult """Run the given transform or callable with this runner. Blocks until the pipeline is complete. 
See also `PipelineRunner.run_async`. @@ -114,7 +129,11 @@ def run(self, transform, options=None): result.wait_until_finish() return result - def run_async(self, transform, options=None): + def run_async(self, + transform, # type: PTransform + options=None # type: Optional[PipelineOptions] + ): + # type: (...) -> PipelineResult """Run the given transform or callable with this runner. May return immediately, executing the pipeline in the background. @@ -133,7 +152,10 @@ def run_async(self, transform, options=None): transform(PBegin(p)) return p.run() - def run_pipeline(self, pipeline, options): + def run_pipeline(self, + pipeline, # type: Pipeline + options # type: PipelineOptions + ): """Execute the entire pipeline or the sub-DAG reachable from a node. Runners should override this method. @@ -146,6 +168,7 @@ def run_pipeline(self, pipeline, options): class RunVisitor(PipelineVisitor): def __init__(self, runner): + # type: (PipelineRunner) -> None self.runner = runner def visit_transform(self, transform_node): @@ -157,7 +180,11 @@ def visit_transform(self, transform_node): pipeline.visit(RunVisitor(self)) - def apply(self, transform, input, options): + def apply(self, + transform, # type: PTransform + input, # type: pvalue.PCollection + options # type: PipelineOptions + ): """Runner callback for a pipeline.apply call. Args: @@ -180,7 +207,10 @@ def apply_PTransform(self, transform, input, options): # The base case of apply is to call the transform's expand. return transform.expand(input) - def run_transform(self, transform_node, options): + def run_transform(self, + transform_node, # type: AppliedPTransform + options # type: PipelineOptions + ): """Runner callback for a pipeline.run call. Args: @@ -298,6 +328,7 @@ def key(self, pobj): return self.to_cache_key(pobj.real_producer, pobj.tag) +# FIXME: replace with PipelineState(str, enum.Enum) class PipelineState(object): """State of the Pipeline, as returned by :attr:`PipelineResult.state`. diff --git a/sdks/python/apache_beam/runners/sdf_common.py b/sdks/python/apache_beam/runners/sdf_common.py index 072d3dc74492..ff95177d27a4 100644 --- a/sdks/python/apache_beam/runners/sdf_common.py +++ b/sdks/python/apache_beam/runners/sdf_common.py @@ -21,6 +21,8 @@ import uuid from builtins import object +from typing import Any +from typing import Iterator import apache_beam as beam from apache_beam import pvalue @@ -61,6 +63,7 @@ class SplittableParDo(PTransform): """A transform that processes a PCollection using a Splittable DoFn.""" def __init__(self, ptransform): + # type: (ParDo) -> None assert isinstance(ptransform, ParDo) self._ptransform = ptransform @@ -96,6 +99,7 @@ class PairWithRestrictionFn(beam.DoFn): """A transform that pairs each element with a restriction.""" def __init__(self, do_fn): + # type: (beam.DoFn) -> None self._do_fn = do_fn def start_bundle(self): @@ -104,6 +108,7 @@ def start_bundle(self): signature, process_invocation=False) def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs): + # type: (...) 
-> Iterator[ElementAndRestriction] initial_restriction = self._invoker.invoke_initial_restriction(element) yield ElementAndRestriction(element, initial_restriction) @@ -112,6 +117,7 @@ class SplitRestrictionFn(beam.DoFn): """A transform that perform initial splitting of Splittable DoFn inputs.""" def __init__(self, do_fn): + # type: (beam.DoFn) -> None self._do_fn = do_fn def start_bundle(self): @@ -120,6 +126,7 @@ def start_bundle(self): signature, process_invocation=False) def process(self, element_and_restriction, *args, **kwargs): + # type: (ElementAndRestriction, Any, Any) -> Iterator[ElementAndRestriction] element = element_and_restriction.element restriction = element_and_restriction.restriction restriction_parts = self._invoker.invoke_split( diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index fe0ab37332ac..b92d16c551ff 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -30,6 +30,24 @@ import threading from builtins import next from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Container +from typing import DefaultDict +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import cast from future.utils import itervalues from google.protobuf import timestamp_pb2 @@ -57,8 +75,25 @@ from apache_beam.utils import timestamp from apache_beam.utils import windowed_value -# This module is experimental. No backwards-compatibility guarantees. +if TYPE_CHECKING: + from google.protobuf import message # pylint: disable=ungrouped-imports + from apache_beam import pvalue + from apache_beam.portability.api import metrics_pb2 + from apache_beam.runners.worker import data_plane + from apache_beam.runners.worker import sdk_worker + from apache_beam.transforms import window + from apache_beam.utils.timestamp import Timestamp +# This module is experimental. No backwards-compatibility guarantees. +T = TypeVar('T') +ConstructorFn = Callable[ + ['BeamTransformFactory', + Any, + beam_runner_api_pb2.PTransform, + Union['message.Message', bytes], + Dict[str, List[operations.Operation]]], + operations.Operation] +OperationT = TypeVar('OperationT', bound=operations.Operation) DATA_INPUT_URN = 'beam:source:runner:0.1' DATA_OUTPUT_URN = 'beam:sink:runner:0.1' @@ -73,8 +108,17 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" - def __init__(self, name_context, step_name, consumers, counter_factory, - state_sampler, windowed_coder, transform_id, data_channel): + def __init__(self, + name_context, # type: Union[str, common.NameContext] + step_name, + consumers, # type: Mapping[Any, Iterable[operations.Operation]] + counter_factory, + state_sampler, + windowed_coder, # type: coders.WindowedValueCoder + transform_id, # type: str + data_channel # type: data_plane.GrpcClientDataChannel + ): + # type: (...) 
-> None super(RunnerIOOperation, self).__init__( name_context, None, counter_factory, state_sampler) self.windowed_coder = windowed_coder @@ -93,14 +137,17 @@ class DataOutputOperation(RunnerIOOperation): """ def set_output_stream(self, output_stream): + # type: (data_plane.ClosableOutputStream) -> None self.output_stream = output_stream def process(self, windowed_value): + # type: (windowed_value.WindowedValue) -> None self.windowed_coder_impl.encode_to_stream( windowed_value, self.output_stream, True) self.output_stream.maybe_flush() def finish(self): + # type: () -> None self.output_stream.close() super(DataOutputOperation, self).finish() @@ -108,8 +155,17 @@ def finish(self): class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" - def __init__(self, operation_name, step_name, consumers, counter_factory, - state_sampler, windowed_coder, transform_id, data_channel): + def __init__(self, + operation_name, # type: str + step_name, + consumers, # type: Mapping[Any, Iterable[operations.Operation]] + counter_factory, + state_sampler, + windowed_coder, # type: coders.WindowedValueCoder + transform_id, + data_channel # type: data_plane.GrpcClientDataChannel + ): + # type: (...) -> None super(DataInputOperation, self).__init__( operation_name, step_name, consumers, counter_factory, state_sampler, windowed_coder, transform_id=transform_id, data_channel=data_channel) @@ -122,6 +178,7 @@ def __init__(self, operation_name, step_name, consumers, counter_factory, self.started = False def start(self): + # type: () -> None super(DataInputOperation, self).start() with self.splitting_lock: self.index = -1 @@ -129,9 +186,11 @@ def start(self): self.started = True def process(self, windowed_value): + # type: (windowed_value.WindowedValue) -> None self.output(windowed_value) def process_encoded(self, encoded_windowed_values): + # type: (bytes) -> None input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: with self.splitting_lock: @@ -185,6 +244,7 @@ def try_split(self, fraction_of_remainder, total_buffer_size): return self.stop - 1, None, None, self.stop def progress_metrics(self): + # type: () -> beam_fn_api_pb2.Metrics.PTransform with self.splitting_lock: metrics = super(DataInputOperation, self).progress_metrics() current_element_progress = self.receivers[0].current_element_progress() @@ -194,13 +254,19 @@ def progress_metrics(self): return metrics def finish(self): + # type: () -> None with self.splitting_lock: self.started = False class _StateBackedIterable(object): - def __init__(self, state_handler, state_key, coder_or_impl, - is_cached=False): + def __init__(self, + state_handler, + state_key, # type: beam_fn_api_pb2.StateKey + coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl] + is_cached=False + ): + # type: (...) 
-> None self._state_handler = state_handler self._state_key = state_key if isinstance(coder_or_impl, coders.Coder): @@ -210,6 +276,7 @@ def __init__(self, state_handler, state_key, coder_or_impl, self._is_cached = is_cached def __iter__(self): + # type: () -> Iterator[Any] return self._state_handler.blocking_get( self._state_key, self._coder_impl, is_cached=self._is_cached) @@ -222,7 +289,14 @@ def __reduce__(self): class StateBackedSideInputMap(object): - def __init__(self, state_handler, transform_id, tag, side_input_data, coder): + def __init__(self, + state_handler, + transform_id, # type: str + tag, # type: Optional[str] + side_input_data, # type: pvalue.SideInputData + coder # type: WindowedValueCoder + ): + # type: (...) -> None self._state_handler = state_handler self._transform_id = transform_id self._tag = tag @@ -230,7 +304,7 @@ def __init__(self, state_handler, transform_id, tag, side_input_data, coder): self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. - self._cache = {} + self._cache = {} # type: Dict[window.BoundedWindow, Any] def __getitem__(self, window): target_window = self._side_input_data.window_mapping_fn(window) @@ -278,10 +352,12 @@ def __reduce__(self): return self._cache[target_window] def is_globally_windowed(self): + # type: () -> bool return (self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) def reset(self): + # type: () -> None # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. self._cache = {} @@ -328,10 +404,12 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ def __init__(self, first, second): + # type: (Iterable[Any], Iterable[Any]) -> None self.first = first self.second = second def __iter__(self): + # type: () -> Iterator[Any] for elem in self.first: yield elem for elem in self.second: @@ -343,18 +421,25 @@ def __iter__(self): class SynchronousBagRuntimeState(userstate.BagRuntimeState): - def __init__(self, state_handler, state_key, value_coder): + def __init__(self, + state_handler, + state_key, # type: beam_fn_api_pb2.StateKey + value_coder # type: coders.Coder + ): + # type: (...) -> None self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = [] + self._added_elements = [] # type: List[Any] def read(self): + # type: () -> Iterable[Any] return _ConcatIterable( - [] if self._cleared else _StateBackedIterable( + [] if self._cleared + else cast('Iterable[Any]', _StateBackedIterable( self._state_handler, self._state_key, self._value_coder, - is_cached=True), + is_cached=True)), self._added_elements) def add(self, value): @@ -381,12 +466,17 @@ def _commit(self): class SynchronousSetRuntimeState(userstate.SetRuntimeState): - def __init__(self, state_handler, state_key, value_coder): + def __init__(self, + state_handler, + state_key, # type: beam_fn_api_pb2.StateKey + value_coder # type: coders.Coder + ): + # type: (...) 
-> None self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = set() + self._added_elements = set() # type: Set[Any] def _compact_data(self, rewrite=True): accumulator = set(_ConcatIterable( @@ -410,6 +500,7 @@ def _compact_data(self, rewrite=True): return accumulator def read(self): + # type: () -> Set[Any] return self._compact_data(rewrite=False) def add(self, value): @@ -423,10 +514,12 @@ def add(self, value): self._compact_data() def clear(self): + # type: () -> None self._cleared = True self._added_elements = set() def _commit(self): + # type: () -> None if self._cleared: self._state_handler.clear(self._state_key, is_cached=True).get() if self._added_elements: @@ -438,7 +531,11 @@ def _commit(self): class OutputTimer(object): - def __init__(self, key, window, receiver): + def __init__(self, + key, + window, # type: windowed_value.BoundedWindow + receiver # type: operations.ConsumerSet + ): self._key = key self._window = window self._receiver = receiver @@ -450,6 +547,7 @@ def set(self, ts): (self._key, dict(timestamp=ts)), ts, (self._window,))) def clear(self): + # type: () -> None dummy_millis = int(common_urns.constants.MAX_TIMESTAMP_MILLIS.constant) + 1 clear_ts = timestamp.Timestamp(micros=dummy_millis * 1000) self._receiver.receive( @@ -460,8 +558,14 @@ def clear(self): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" - def __init__( - self, state_handler, transform_id, key_coder, window_coder, timer_specs): + def __init__(self, + state_handler, + transform_id, # type: str + key_coder, # type: coders.Coder + window_coder, # type: coders.Coder + timer_specs # type: MutableMapping[str, beam_runner_api_pb2.TimerSpec] + ): + # type: (...) -> None """Initialize a ``FnApiUserStateContext``. Args: @@ -477,16 +581,22 @@ def __init__( self._key_coder = key_coder self._window_coder = window_coder self._timer_specs = timer_specs - self._timer_receivers = None - self._all_states = {} + self._timer_receivers = None # type: Optional[Dict[str, operations.ConsumerSet]] + self._all_states = {} # type: Dict[tuple, Union[SynchronousBagRuntimeState, SynchronousSetRuntimeState, CombiningValueRuntimeState]] def update_timer_receivers(self, receivers): + # type: (operations._TaggedReceivers) -> None """TODO""" self._timer_receivers = {} for tag in self._timer_specs: self._timer_receivers[tag] = receivers.pop(tag) - def get_timer(self, timer_spec, key, window): + def get_timer(self, + timer_spec, + key, + window # type: windowed_value.BoundedWindow + ): + # type: (...) -> OutputTimer return OutputTimer( key, window, self._timer_receivers[timer_spec.name]) @@ -496,7 +606,11 @@ def get_state(self, *args): state_handle = self._all_states[args] = self._create_state(*args) return state_handle - def _create_state(self, state_spec, key, window): + def _create_state(self, + state_spec, # type: userstate.StateSpec + key, + window # type: windowed_value.BoundedWindow + ): if isinstance(state_spec, (userstate.BagStateSpec, userstate.CombiningValueStateSpec)): bag_state = SynchronousBagRuntimeState( @@ -528,10 +642,12 @@ def _create_state(self, state_spec, key, window): raise NotImplementedError(state_spec) def commit(self): + # type: () -> None for state in self._all_states.values(): state._commit() def reset(self): + # type: () -> None # TODO(BEAM-5428): Implement cross-bundle state caching. 
self._all_states = {} @@ -549,6 +665,7 @@ def wrapper(*args): def only_element(iterable): + # type: (Iterable[T]) -> T element, = iterable return element @@ -556,8 +673,12 @@ def only_element(iterable): class BundleProcessor(object): """ A class for processing bundles of elements. """ - def __init__( - self, process_bundle_descriptor, state_handler, data_channel_factory): + def __init__(self, + process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + state_handler, # type: Union[FnApiRunner.StateServicer, GrpcStateHandler] + data_channel_factory # type: data_plane.DataChannelFactory + ): + # type: (...) -> None """Initialize a bundle processor. Args: @@ -580,7 +701,10 @@ def __init__( op.setup() self.splitting_lock = threading.Lock() - def create_execution_tree(self, descriptor): + def create_execution_tree(self, + descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor + ): + # type: (...) -> collections.OrderedDict[str, operations.Operation] transform_factory = BeamTransformFactory( descriptor, self.data_channel_factory, self.counter_factory, self.state_sampler, self.state_handler) @@ -591,7 +715,7 @@ def is_side_input(transform_proto, tag): transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs - pcoll_consumers = collections.defaultdict(list) + pcoll_consumers = collections.defaultdict(list) # type: DefaultDict[str, List[str]] for transform_id, transform_proto in descriptor.transforms.items(): for tag, pcoll_id in transform_proto.inputs.items(): if not is_side_input(transform_proto, tag): @@ -599,6 +723,7 @@ def is_side_input(transform_proto, tag): @memoize def get_operation(transform_id): + # type: (str) -> operations.Operation transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] for tag, pcoll_id @@ -610,6 +735,7 @@ def get_operation(transform_id): # Operations must be started (hence returned) in order. @memoize def topological_height(transform_id): + # type: (str) -> int return 1 + max( [0] + [topological_height(consumer) @@ -622,6 +748,7 @@ def topological_height(transform_id): descriptor.transforms, key=topological_height, reverse=True)]) def reset(self): + # type: () -> None self.counter_factory.reset() self.state_sampler.reset() # Side input caches. @@ -629,6 +756,7 @@ def reset(self): op.reset() def process_bundle(self, instruction_id): + # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] expected_inputs = [] for op in self.ops.values(): if isinstance(op, DataOutputOperation): @@ -650,7 +778,7 @@ def process_bundle(self, instruction_id): op.start() # Inject inputs from data plane. - data_channels = collections.defaultdict(list) + data_channels = collections.defaultdict(list) # type: DefaultDict[data_plane.GrpcClientDataChannel, List[str]] input_op_by_transform_id = {} for input_op in expected_inputs: data_channels[input_op.data_channel].append(input_op.transform_id) @@ -678,14 +806,17 @@ def process_bundle(self, instruction_id): self.state_sampler.stop_if_still_running() def finalize_bundle(self): + # type: () -> beam_fn_api_pb2.FinalizeBundleResponse for op in self.ops.values(): op.finalize_bundle() return beam_fn_api_pb2.FinalizeBundleResponse() def requires_finalization(self): + # type: () -> bool return any(op.needs_finalization() for op in self.ops.values()) def try_split(self, bundle_split_request): + # type: (...) 
-> beam_fn_api_pb2.ProcessBundleSplitResponse split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() with self.splitting_lock: for op in self.ops.values(): @@ -713,14 +844,20 @@ def try_split(self, bundle_split_request): return split_response - def delayed_bundle_application(self, op, deferred_remainder): + def delayed_bundle_application(self, + op, # type: operations.DoOperation + deferred_remainder # type: Tuple[windowed_value.WindowedValue, Timestamp] + ): + # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication transform_id, main_input_tag, main_input_coder, outputs = op.input_info # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. element_and_restriction, watermark = deferred_remainder if watermark: proto_watermark = timestamp_pb2.Timestamp() proto_watermark.FromMicroseconds(watermark.micros) - output_watermarks = {output: proto_watermark for output in outputs} + output_watermarks = { + output: proto_watermark for output in outputs + } # type: Optional[Dict[str, timestamp_pb2.Timestamp]] else: output_watermarks = None return beam_fn_api_pb2.DelayedBundleApplication( @@ -732,6 +869,7 @@ def delayed_bundle_application(self, op, deferred_remainder): element_and_restriction))) def metrics(self): + # type: () -> beam_fn_api_pb2.Metrics # DEPRECATED return beam_fn_api_pb2.Metrics( # TODO(robertwb): Rename to progress? @@ -763,6 +901,7 @@ def fix_only_output_tag(actual_output_tag, mapping): return metrics def monitoring_infos(self): + # type: () -> List[metrics_pb2.MonitoringInfo] """Returns the list of MonitoringInfos collected processing this bundle.""" # Construct a new dict first to remove duplciates. all_monitoring_infos_dict = {} @@ -811,6 +950,7 @@ def inject_pcollection(monitoring_info): return infos_list def _fix_output_tags_monitoring_info(self, transform_id, monitoring_info): + # type: (str, metrics_pb2.MonitoringInfo) -> metrics_pb2.MonitoringInfo actual_output_tags = list( self.process_bundle_descriptor.transforms[transform_id].outputs.keys()) if ('TAG' in monitoring_info.labels and @@ -820,19 +960,25 @@ def _fix_output_tags_monitoring_info(self, transform_id, monitoring_info): return monitoring_info def shutdown(self): + # type: () -> None for op in self.ops.values(): op.teardown() class ExecutionContext(object): def __init__(self): - self.delayed_applications = [] + self.delayed_applications = [] # type: List[Tuple[operations.DoOperation, Tuple[windowed_value.WindowedValue, Timestamp]]] class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" - def __init__(self, descriptor, data_channel_factory, counter_factory, - state_sampler, state_handler): + def __init__(self, + descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor + data_channel_factory, # type: data_plane.DataChannelFactory + counter_factory, + state_sampler, # type: statesampler.StateSampler + state_handler + ): self.descriptor = descriptor self.data_channel_factory = data_channel_factory self.counter_factory = counter_factory @@ -847,16 +993,24 @@ def __init__(self, descriptor, data_channel_factory, counter_factory, runner=beam_fn_api_pb2.StateKey.Runner(key=token)), element_coder_impl)) - _known_urns = {} + _known_urns = {} # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]] @classmethod - def register_urn(cls, urn, parameter_type): + def register_urn(cls, + urn, # type: str + parameter_type # type: Optional[Type[T]] + ): + # type: (...) 
-> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]] def wrapper(func): cls._known_urns[urn] = func, parameter_type return func return wrapper - def create_operation(self, transform_id, consumers): + def create_operation(self, + transform_id, # type: str + consumers # type: Dict[str, List[operations.Operation]] + ): + # type: (...) -> operations.Operation transform_proto = self.descriptor.transforms[transform_id] if not transform_proto.unique_name: logging.debug("No unique name set for transform %s" % transform_id) @@ -867,6 +1021,7 @@ def create_operation(self, transform_id, consumers): return creator(self, transform_id, transform_proto, payload, consumers) def get_coder(self, coder_id): + # type: (str) -> coders.Coder if coder_id not in self.descriptor.coders: raise KeyError("No such coder: %s" % coder_id) coder_proto = self.descriptor.coders[coder_id] @@ -878,6 +1033,7 @@ def get_coder(self, coder_id): json.loads(coder_proto.spec.payload.decode('utf-8'))) def get_windowed_coder(self, pcoll_id): + # type: (str) -> WindowedValueCoder coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) # TODO(robertwb): Remove this condition once all runners are consistent. if not isinstance(coder, WindowedValueCoder): @@ -889,26 +1045,32 @@ def get_windowed_coder(self, pcoll_id): return coder def get_output_coders(self, transform_proto): + # type: (beam_runner_api_pb2.PTransform) -> Dict[str, WindowedValueCoder] return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.outputs.items() } def get_only_output_coder(self, transform_proto): + # type: (beam_runner_api_pb2.PTransform) -> WindowedValueCoder return only_element(self.get_output_coders(transform_proto).values()) def get_input_coders(self, transform_proto): + # type: (beam_runner_api_pb2.PTransform) -> Dict[str, WindowedValueCoder] return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.inputs.items() } def get_only_input_coder(self, transform_proto): + # type: (beam_runner_api_pb2.PTransform) -> WindowedValueCoder return only_element(list(self.get_input_coders(transform_proto).values())) # TODO(robertwb): Update all operations to take these in the constructor. @staticmethod - def augment_oldstyle_op(op, step_name, consumers, tag_list=None): + def augment_oldstyle_op(op, # type: OperationT + step_name, consumers, tag_list=None): + # type: (...) -> OperationT op.step_name = step_name for tag, op_consumers in consumers.items(): for consumer in op_consumers: @@ -922,12 +1084,20 @@ def __init__(self, timer_tag, do_op): self._do_op = do_op def process(self, windowed_value): + # type: (windowed_value.WindowedValue) -> None self._do_op.process_timer(self._timer_tag, windowed_value) @BeamTransformFactory.register_urn( DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create_source_runner(factory, transform_id, transform_proto, grpc_port, consumers): +def create_source_runner( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) 
-> DataInputOperation # Timers are the one special case where we don't want to call the # (unlabeled) operation.process() method, which we detect here. # TODO(robertwb): Consider generalizing if there are any more cases. @@ -942,7 +1112,8 @@ def create_source_runner(factory, transform_id, transform_proto, grpc_port, cons break if grpc_port.coder_id: - output_coder = factory.get_coder(grpc_port.coder_id) + output_coder = cast(coders.WindowedValueCoder, + factory.get_coder(grpc_port.coder_id)) else: logging.info( 'Missing required coder_id on grpc_port for %s; ' @@ -962,9 +1133,17 @@ def create_source_runner(factory, transform_id, transform_proto, grpc_port, cons @BeamTransformFactory.register_urn( DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create_sink_runner(factory, transform_id, transform_proto, grpc_port, consumers): +def create_sink_runner( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) -> DataOutputOperation if grpc_port.coder_id: - output_coder = factory.get_coder(grpc_port.coder_id) + output_coder = cast(coders.WindowedValueCoder, + factory.get_coder(grpc_port.coder_id)) else: logging.info( 'Missing required coder_id on grpc_port for %s; ' @@ -983,7 +1162,14 @@ def create_sink_runner(factory, transform_id, transform_proto, grpc_port, consum @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) -def create_source_java(factory, transform_id, transform_proto, parameter, consumers): +def create_source_java( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) -> operations.ReadOperation # The Dataflow runner harness strips the base64 encoding. source = pickler.loads(base64.b64encode(parameter)) spec = operation_specs.WorkerRead( @@ -1001,7 +1187,14 @@ def create_source_java(factory, transform_id, transform_proto, parameter, consum @BeamTransformFactory.register_urn( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) -def create(factory, transform_id, transform_proto, parameter, consumers): +def create_deprecated_read( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, # type: beam_runner_api_pb2.ReadPayload + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) -> operations.ReadOperation source = iobase.SourceBase.from_runner_api(parameter.source, factory.context) spec = operation_specs.WorkerRead( iobase.SourceBundle(1.0, source, None, None), @@ -1018,7 +1211,14 @@ def create(factory, transform_id, transform_proto, parameter, consumers): @BeamTransformFactory.register_urn( python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) -def create_read_from_impulse_python(factory, transform_id, transform_proto, parameter, consumers): +def create_read_from_impulse_python( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, # type: beam_runner_api_pb2.ReadPayload + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) 
-> operations.ImpulseReadOperation return operations.ImpulseReadOperation( transform_proto.unique_name, factory.counter_factory, @@ -1030,7 +1230,13 @@ def create_read_from_impulse_python(factory, transform_id, transform_proto, para @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) -def create_dofn_javasdk(factory, transform_id, transform_proto, serialized_fn, consumers): +def create_dofn_javasdk( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + serialized_fn, + consumers # type: Dict[str, List[operations.Operation]] +): return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1078,7 +1284,12 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_process_sized_elements_and_restrictions( - factory, transform_id, transform_proto, parameter, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, # type: beam_runner_api_pb2.ParDoPayload + consumers # type: Dict[str, List[operations.Operation]] +): assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO serialized_fn = parameter.do_fn.spec.payload return _create_pardo_operation( @@ -1103,7 +1314,14 @@ def _create_sdf_operation( @BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) -def create_par_do(factory, transform_id, transform_proto, parameter, consumers): +def create_par_do( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, # type: beam_runner_api_pb2.ParDoPayload + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) -> operations.DoOperation assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO serialized_fn = parameter.do_fn.spec.payload return _create_pardo_operation( @@ -1112,8 +1330,14 @@ def create_par_do(factory, transform_id, transform_proto, parameter, consumers): def _create_pardo_operation( - factory, transform_id, transform_proto, consumers, - serialized_fn, pardo_proto=None, operation_cls=operations.DoOperation): + factory, + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + consumers, + serialized_fn, + pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload] + operation_cls=operations.DoOperation +): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -1152,7 +1376,8 @@ def mutate_tag(tag): # Windowing not set. 
if pardo_proto: other_input_tags = set.union( - set(pardo_proto.side_inputs), set(pardo_proto.timer_specs)) + set(pardo_proto.side_inputs), + set(pardo_proto.timer_specs)) # type: Container[str] else: other_input_tags = () pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() @@ -1161,9 +1386,10 @@ def mutate_tag(tag): factory.descriptor.pcollections[pcoll_id].windowing_strategy_id) serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,)) + timer_inputs = None # type: Optional[Dict[str, str]] if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs or pardo_proto.splittable): - main_input_coder = None + main_input_coder = None # type: Optional[WindowedValueCoder] timer_inputs = {} for tag, pcoll_id in transform_proto.inputs.items(): if tag in pardo_proto.timer_specs: @@ -1183,12 +1409,11 @@ def mutate_tag(tag): transform_id, main_input_coder.key_coder(), main_input_coder.window_coder, - timer_specs=pardo_proto.timer_specs) + timer_specs=pardo_proto.timer_specs) # type: Optional[FnApiUserStateContext] else: user_state_context = None else: user_state_context = None - timer_inputs = None output_coders = factory.get_output_coders(transform_proto) spec = operation_specs.WorkerDoFn( @@ -1217,8 +1442,12 @@ def mutate_tag(tag): return result -def _create_simple_pardo_operation( - factory, transform_id, transform_proto, consumers, dofn): +def _create_simple_pardo_operation(factory, # type: BeamTransformFactory + transform_id, + transform_proto, + consumers, + dofn, # type: beam.DoFn + ): serialized_fn = pickler.dumps((dofn, (), {}, [], None)) return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1228,7 +1457,12 @@ def _create_simple_pardo_operation( common_urns.primitives.ASSIGN_WINDOWS.urn, beam_runner_api_pb2.WindowingStrategy) def create_assign_windows( - factory, transform_id, transform_proto, parameter, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, # type: beam_runner_api_pb2.WindowingStrategy + consumers # type: Dict[str, List[operations.Operation]] +): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): self.windowing = windowing @@ -1248,7 +1482,13 @@ def process(self, element, timestamp=beam.DoFn.TimestampParam, @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) def create_identity_dofn( - factory, transform_id, transform_proto, unused_parameter, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + parameter, + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) -> operations.FlattenOperation return factory.augment_oldstyle_op( operations.FlattenOperation( transform_proto.unique_name, @@ -1264,7 +1504,13 @@ def create_identity_dofn( common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_precombine( - factory, transform_id, transform_proto, payload, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, # type: beam_runner_api_pb2.CombinePayload + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) 
-> operations.PGBKCVOperation serialized_combine_fn = pickler.dumps( (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -1285,7 +1531,12 @@ def create_combine_per_key_precombine( common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combbine_per_key_merge_accumulators( - factory, transform_id, transform_proto, payload, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, # type: beam_runner_api_pb2.CombinePayload + consumers # type: Dict[str, List[operations.Operation]] +): return _create_combine_phase_operation( factory, transform_proto, payload, consumers, 'merge') @@ -1294,7 +1545,12 @@ def create_combbine_per_key_merge_accumulators( common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_extract_outputs( - factory, transform_id, transform_proto, payload, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, # type: beam_runner_api_pb2.CombinePayload + consumers # type: Dict[str, List[operations.Operation]] +): return _create_combine_phase_operation( factory, transform_proto, payload, consumers, 'extract') @@ -1303,13 +1559,19 @@ def create_combine_per_key_extract_outputs( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) def create_combine_grouped_values( - factory, transform_id, transform_proto, payload, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, # type: beam_runner_api_pb2.CombinePayload + consumers # type: Dict[str, List[operations.Operation]] +): return _create_combine_phase_operation( factory, transform_proto, payload, consumers, 'all') def _create_combine_phase_operation( factory, transform_proto, payload, consumers, phase): + # type: (...) -> operations.CombineOperation serialized_combine_fn = pickler.dumps( (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -1329,7 +1591,13 @@ def _create_combine_phase_operation( @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) def create_flatten( - factory, transform_id, transform_proto, unused_parameter, consumers): + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + payload, + consumers # type: Dict[str, List[operations.Operation]] +): + # type: (...) 
-> operations.FlattenOperation return factory.augment_oldstyle_op( operations.FlattenOperation( transform_proto.unique_name, @@ -1345,7 +1613,13 @@ def create_flatten( @BeamTransformFactory.register_urn( common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.SdkFunctionSpec) -def create_map_windows(factory, transform_id, transform_proto, mapping_fn_spec, consumers): +def create_map_windows( + factory, # type: BeamTransformFactory + transform_id, # type: str + transform_proto, # type: beam_runner_api_pb2.PTransform + mapping_fn_spec, # type: beam_runner_api_pb2.SdkFunctionSpec + consumers # type: Dict[str, List[operations.Operation]] +): assert mapping_fn_spec.spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN window_mapping_fn = pickler.loads(mapping_fn_spec.spec.payload) diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 8324e6bbf30e..bf76b39e7a36 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -29,6 +29,14 @@ import threading from builtins import object from builtins import range +from typing import TYPE_CHECKING +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional import grpc from future.utils import raise_ @@ -40,18 +48,29 @@ from apache_beam.runners.worker.channel_factory import GRPCChannelFactory from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor +if TYPE_CHECKING: + # TODO: remove from TYPE_CHECKING scope when we drop support for python < 3.6 + from typing import Collection + # This module is experimental. No backwards-compatibility guarantees. _DEFAULT_FLUSH_THRESHOLD = 10 << 20 # 10MB -class ClosableOutputStream(type(coder_impl.create_OutputStream())): +if TYPE_CHECKING: + import apache_beam.coders.slow_stream + OutputStream = apache_beam.coders.slow_stream.OutputStream +else: + OutputStream = type(coder_impl.create_OutputStream()) + + +class ClosableOutputStream(OutputStream): """A Outputstream for use with CoderImpls that has a close() method.""" def __init__(self, - close_callback=None, - flush_callback=None, + close_callback=None, # type: Optional[Callable[[bytes], None]] + flush_callback=None, # type: Optional[Callable[[bytes], None]] flush_threshold=_DEFAULT_FLUSH_THRESHOLD): super(ClosableOutputStream, self).__init__() self._close_callback = close_callback @@ -90,8 +109,12 @@ class DataChannel(with_metaclass(abc.ABCMeta, object)): """ @abc.abstractmethod - def input_elements( - self, instruction_id, expected_transforms, abort_callback=None): + def input_elements(self, + instruction_id, # type: str + expected_transforms, # type: Collection[str] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data] """Returns an iterable of all Element.Data bundles for instruction_id. This iterable terminates only once the full set of data has been recieved @@ -106,7 +129,11 @@ def input_elements( raise NotImplementedError(type(self)) @abc.abstractmethod - def output_stream(self, instruction_id, transform_id): + def output_stream(self, + instruction_id, # type: str + transform_id # type: str + ): + # type: (...) -> ClosableOutputStream """Returns an output stream writing elements to transform_id. 
Args: @@ -117,6 +144,7 @@ def output_stream(self, instruction_id, transform_id): @abc.abstractmethod def close(self): + # type: () -> None """Closes this channel, indicating that all data has been written. Data can continue to be read. @@ -135,14 +163,20 @@ class InMemoryDataChannel(DataChannel): """ def __init__(self, inverse=None): - self._inputs = [] + # type: (Optional[InMemoryDataChannel]) -> None + self._inputs = [] # type: List[beam_fn_api_pb2.Elements.Data] self._inverse = inverse or InMemoryDataChannel(self) def inverse(self): + # type: () -> InMemoryDataChannel return self._inverse - def input_elements(self, instruction_id, unused_expected_transforms=None, - abort_callback=None): + def input_elements(self, + instruction_id, # type: str + unused_expected_transforms=None, # type: Optional[Collection[str]] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) -> Iterator[beam_fn_api_pb2.Elements.Data] other_inputs = [] for data in self._inputs: if data.instruction_id == instruction_id: @@ -153,6 +187,7 @@ def input_elements(self, instruction_id, unused_expected_transforms=None, self._inputs = other_inputs def output_stream(self, instruction_id, transform_id): + # type: (str, str) -> ClosableOutputStream def add_to_inverse_output(data): self._inverse._inputs.append( # pylint: disable=protected-access beam_fn_api_pb2.Elements.Data( @@ -172,8 +207,8 @@ class _GrpcDataChannel(DataChannel): _WRITES_FINISHED = object() def __init__(self): - self._to_send = queue.Queue() - self._received = collections.defaultdict(queue.Queue) + self._to_send = queue.Queue() # type: queue.Queue[beam_fn_api_pb2.Elements.Data] + self._received = collections.defaultdict(queue.Queue) # type: DefaultDict[str, queue.Queue[beam_fn_api_pb2.Elements.Data]] self._receive_lock = threading.Lock() self._reads_finished = threading.Event() self._closed = False @@ -187,15 +222,21 @@ def wait(self, timeout=None): self._reads_finished.wait(timeout) def _receiving_queue(self, instruction_id): + # type: (str) -> queue.Queue[beam_fn_api_pb2.Elements.Data] with self._receive_lock: return self._received[instruction_id] def _clean_receiving_queue(self, instruction_id): + # type: (str) -> None with self._receive_lock: self._received.pop(instruction_id) - def input_elements(self, instruction_id, expected_transforms, - abort_callback=None): + def input_elements(self, + instruction_id, # type: str + expected_transforms, # type: Collection[str] + abort_callback=None # type: Optional[Callable[[], bool]] + ): + # type: (...) 
-> Iterator[beam_fn_api_pb2.Elements.Data] """ Generator to retrieve elements for an instruction_id input_elements should be called only once for an instruction_id @@ -205,7 +246,7 @@ def input_elements(self, instruction_id, expected_transforms, expected_transforms(collection): expected transforms """ received = self._receiving_queue(instruction_id) - done_transforms = [] + done_transforms = [] # type: List[str] abort_callback = abort_callback or (lambda: False) try: while len(done_transforms) < len(expected_transforms): @@ -231,7 +272,9 @@ def input_elements(self, instruction_id, expected_transforms, self._clean_receiving_queue(instruction_id) def output_stream(self, instruction_id, transform_id): + # type: (str, str) -> ClosableOutputStream def add_to_send_queue(data): + # type: (bytes) -> None if data: self._to_send.put( beam_fn_api_pb2.Elements.Data( @@ -240,6 +283,7 @@ def add_to_send_queue(data): data=data)) def close_callback(data): + # type: (bytes) -> None add_to_send_queue(data) # End of stream marker. self._to_send.put( @@ -251,6 +295,7 @@ def close_callback(data): close_callback, flush_callback=add_to_send_queue) def _write_outputs(self): + # type: () -> Iterator[beam_fn_api_pb2.Elements] done = False while not done: data = [self._to_send.get()] @@ -267,6 +312,7 @@ def _write_outputs(self): yield beam_fn_api_pb2.Elements(data=data) def _read_inputs(self, elements_iterator): + # type: (Iterable[beam_fn_api_pb2.Elements]) -> None # TODO(robertwb): Pushback/throttling to avoid unbounded buffering. try: for elements in elements_iterator: @@ -282,6 +328,7 @@ def _read_inputs(self, elements_iterator): self._reads_finished.set() def set_inputs(self, elements_iterator): + # type: (Iterable[beam_fn_api_pb2.Elements]) -> None reader = threading.Thread( target=lambda: self._read_inputs(elements_iterator), name='read_grpc_client_inputs') @@ -292,7 +339,10 @@ def set_inputs(self, elements_iterator): class GrpcClientDataChannel(_GrpcDataChannel): """A DataChannel wrapping the client side of a BeamFnData connection.""" - def __init__(self, data_stub): + def __init__(self, + data_stub # type: beam_fn_api_pb2_grpc.BeamFnDataStub + ): + # type: (...) -> None super(GrpcClientDataChannel, self).__init__() self.set_inputs(data_stub.Data(self._write_outputs())) @@ -303,14 +353,19 @@ class BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer): def __init__(self): self._lock = threading.Lock() self._connections_by_worker_id = collections.defaultdict( - _GrpcDataChannel) + _GrpcDataChannel) # type: DefaultDict[str, _GrpcDataChannel] def get_conn_by_worker_id(self, worker_id): + # type: (str) -> _GrpcDataChannel with self._lock: return self._connections_by_worker_id[worker_id] - def Data(self, elements_iterator, context): - worker_id = dict(context.invocation_metadata()).get('worker_id') + def Data(self, + elements_iterator, # type: Iterable[beam_fn_api_pb2.Elements] + context + ): + # type: (...) 
-> Iterator[beam_fn_api_pb2.Elements] + worker_id = dict(context.invocation_metadata())['worker_id'] data_conn = self.get_conn_by_worker_id(worker_id) data_conn.set_inputs(elements_iterator) for elements in data_conn._write_outputs(): @@ -322,11 +377,13 @@ class DataChannelFactory(with_metaclass(abc.ABCMeta, object)): @abc.abstractmethod def create_data_channel(self, remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel """Returns a ``DataChannel`` from the given RemoteGrpcPort.""" raise NotImplementedError(type(self)) @abc.abstractmethod def close(self): + # type: () -> None """Close all channels that this factory owns.""" raise NotImplementedError(type(self)) @@ -337,8 +394,12 @@ class GrpcClientDataChannelFactory(DataChannelFactory): Caches the created channels by ``data descriptor url``. """ - def __init__(self, credentials=None, worker_id=None): - self._data_channel_cache = {} + def __init__(self, + credentials=None, + worker_id=None # type: Optional[str] + ): + # type: (...) -> None + self._data_channel_cache = {} # type: Dict[str, GrpcClientDataChannel] self._lock = threading.Lock() self._credentials = None self._worker_id = worker_id @@ -347,6 +408,7 @@ def __init__(self, credentials=None, worker_id=None): self._credentials = credentials def create_data_channel(self, remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel url = remote_grpc_port.api_service_descriptor.url if url not in self._data_channel_cache: with self._lock: @@ -373,6 +435,7 @@ def create_data_channel(self, remote_grpc_port): return self._data_channel_cache[url] def close(self): + # type: () -> None logging.info('Closing all cached grpc data channels.') for _, channel in self._data_channel_cache.items(): channel.close() @@ -383,10 +446,13 @@ class InMemoryDataChannelFactory(DataChannelFactory): """A singleton factory for ``InMemoryDataChannel``.""" def __init__(self, in_memory_data_channel): + # type: (GrpcClientDataChannel) -> None self._in_memory_data_channel = in_memory_data_channel def create_data_channel(self, unused_remote_grpc_port): + # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel return self._in_memory_data_channel def close(self): + # type: () -> None pass diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index 08dac3a8923e..143409ad5c43 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -62,7 +62,7 @@ def __init__(self, log_service_descriptor): self._alive = True self._dropped_logs = 0 - self._log_entry_queue = queue.Queue(maxsize=self._QUEUE_SIZE) + self._log_entry_queue = queue.Queue(maxsize=self._QUEUE_SIZE) # type: queue.Queue[beam_fn_api_pb2.LogEntry] ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url) # Make sure the channel is ready to avoid [BEAM-4649] @@ -82,6 +82,7 @@ def connect(self): return self._logging_stub.Logging(self._write_log_entries()) def emit(self, record): + # type: (logging.LogRecord) -> None log_entry = beam_fn_api_pb2.LogEntry() log_entry.severity = self.LOG_LEVEL_MAP[record.levelno] log_entry.message = self.format(record) diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py index ae9cdd3ac751..72f24ff885fd 100644 --- a/sdks/python/apache_beam/runners/worker/logger.py +++ b/sdks/python/apache_beam/runners/worker/logger.py @@ -25,6 +25,8 @@ import logging import 
threading import traceback +from typing import Any +from typing import Dict from apache_beam.runners.worker import statesampler @@ -115,7 +117,7 @@ def format(self, record): Python thread object. Nevertheless having this value can allow to filter log statement from only one specific thread. """ - output = {} + output = {} # type: Dict[str, Any] output['timestamp'] = { 'seconds': int(record.created), 'nanos': int(record.msecs * 1000000)} diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index ae36a6b0bf28..80528d390056 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -27,11 +27,17 @@ import random from builtins import hex from builtins import object +from typing import TYPE_CHECKING +from typing import Optional from apache_beam.utils import counters from apache_beam.utils.counters import Counter from apache_beam.utils.counters import CounterName +if TYPE_CHECKING: + from apache_beam.utils import windowed_value + from apache_beam.runners.worker.statesampler import StateSampler + # This module is experimental. No backwards-compatibility guarantees. @@ -122,8 +128,12 @@ class SideInputReadCounter(TransformIOCounter): not be the only step that spends time reading from this side input. """ - def __init__(self, counter_factory, state_sampler, declaring_step, - input_index): + def __init__(self, + counter_factory, + state_sampler, # type: StateSampler + declaring_step, + input_index + ): """Create a side input read counter. Args: @@ -177,7 +187,12 @@ def value(self): class OperationCounters(object): """The set of basic counters to attach to an Operation.""" - def __init__(self, counter_factory, step_name, coder, output_index): + def __init__(self, + counter_factory, + step_name, # type: str + coder, + output_index + ): self._counter_factory = counter_factory self.element_counter = counter_factory.get_counter( '%s-out%s-ElementCount' % (step_name, output_index), Counter.SUM) @@ -185,12 +200,13 @@ def __init__(self, counter_factory, step_name, coder, output_index): '%s-out%s-MeanByteCount' % (step_name, output_index), Counter.BEAM_DISTRIBUTION) self.coder_impl = coder.get_impl() if coder else None - self.active_accumulator = None - self.current_size = None + self.active_accumulator = None # type: Optional[SumAccumulator] + self.current_size = None # type: Optional[int] self._sample_counter = 0 self._next_sample = 0 def update_from(self, windowed_value): + # type: (windowed_value.WindowedValue) -> None """Add one value to this counter.""" if self._should_sample(): self.do_sample(windowed_value) @@ -210,6 +226,7 @@ def _observable_callback_inner(value, is_encoded=False): return _observable_callback_inner def do_sample(self, windowed_value): + # type: (windowed_value.WindowedValue) -> None size, observables = ( self.coder_impl.get_estimated_size_and_observables(windowed_value)) if not observables: diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 7e01e7ff3f5f..afae859eba63 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -29,6 +29,17 @@ from builtins import filter from builtins import object from builtins import zip +from typing import TYPE_CHECKING +from typing import Any +from typing import DefaultDict +from typing import Dict +from typing import FrozenSet +from typing import Hashable +from typing 
import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from apache_beam import pvalue from apache_beam.internal import pickler @@ -52,6 +63,10 @@ from apache_beam.transforms.window import GlobalWindows from apache_beam.utils.windowed_value import WindowedValue +if TYPE_CHECKING: + from apache_beam.runners.worker.bundle_processor import ExecutionContext + from apache_beam.runners.worker.statesampler import StateSampler + # Allow some "pure mode" declarations. try: import cython @@ -76,7 +91,13 @@ class ConsumerSet(Receiver): ConsumerSet are attached to the outputting Operation. """ @staticmethod - def create(counter_factory, step_name, output_index, consumers, coder): + def create(counter_factory, + step_name, # type: str + output_index, + consumers, # type: List[Operation] + coder + ): + # type: (...) -> ConsumerSet if len(consumers) == 1: return SingletonConsumerSet( counter_factory, step_name, output_index, consumers, coder) @@ -84,8 +105,13 @@ def create(counter_factory, step_name, output_index, consumers, coder): return ConsumerSet( counter_factory, step_name, output_index, consumers, coder) - def __init__( - self, counter_factory, step_name, output_index, consumers, coder): + def __init__(self, + counter_factory, + step_name, # type: str + output_index, + consumers, # type: List[Operation] + coder + ): self.consumers = consumers self.opcounter = opcounters.OperationCounters( counter_factory, step_name, coder, output_index) @@ -95,6 +121,7 @@ def __init__( self.coder = coder def receive(self, windowed_value): + # type: (WindowedValue) -> None self.update_counters_start(windowed_value) for consumer in self.consumers: cython.cast(Operation, consumer).process(windowed_value) @@ -109,6 +136,7 @@ def try_split(self, fraction_of_remainder): return None def current_element_progress(self): + # type: () -> Optional[iobase.RestrictionProgress] """Returns the progress of the current element. This progress should be an instance of @@ -119,9 +147,11 @@ def current_element_progress(self): return None def update_counters_start(self, windowed_value): + # type: (WindowedValue) -> None self.opcounter.update_from(windowed_value) def update_counters_finish(self): + # type: () -> None self.opcounter.update_collect() def __repr__(self): @@ -139,6 +169,7 @@ def __init__( self.consumer = consumers[0] def receive(self, windowed_value): + # type: (WindowedValue) -> None self.update_counters_start(windowed_value) self.consumer.process(windowed_value) self.update_counters_finish() @@ -157,7 +188,12 @@ class Operation(object): one or more receiver operations that will take that as input. """ - def __init__(self, name_context, spec, counter_factory, state_sampler): + def __init__(self, + name_context, # type: Union[str, common.NameContext] + spec, + counter_factory, + state_sampler # type: StateSampler + ): """Initializes a worker operation instance. Args: @@ -177,8 +213,8 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): self.spec = spec self.counter_factory = counter_factory - self.execution_context = None - self.consumers = collections.defaultdict(list) + self.execution_context = None # type: Optional[ExecutionContext] + self.consumers = collections.defaultdict(list) # type: DefaultDict[int, List[Operation]] # These are overwritten in the legacy harness. 
self.metrics_container = MetricsContainer(self.name_context.metrics_name()) @@ -192,12 +228,13 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): self.name_context, 'finish', metrics_container=self.metrics_container) # TODO(ccy): the '-abort' state can be added when the abort is supported in # Operations. - self.receivers = [] + self.receivers = [] # type: List[ConsumerSet] # Legacy workers cannot call setup() until after setting additional state # on the operation. self.setup_done = False def setup(self): + # type: () -> None """Set up operation. This must be called before any other methods of the operation.""" @@ -218,16 +255,19 @@ def setup(self): self.setup_done = True def start(self): + # type: () -> None """Start operation.""" if not self.setup_done: # For legacy workers. self.setup() def process(self, o): + # type: (WindowedValue) -> None """Process element in operation.""" pass def finalize_bundle(self): + # type: () -> None pass def needs_finalization(self): @@ -240,26 +280,32 @@ def current_element_progress(self): return None def finish(self): + # type: () -> None """Finish operation.""" pass def teardown(self): + # type: () -> None """Tear down operation. No other methods of this operation should be called after this.""" pass def reset(self): + # type: () -> None self.metrics_container.reset() def output(self, windowed_value, output_index=0): + # type: (WindowedValue, int) -> None cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value) def add_receiver(self, operation, output_index=0): + # type: (Operation, int) -> None """Adds a receiver operation for the specified output.""" self.consumers[output_index].append(operation) def progress_metrics(self): + # type: () -> beam_fn_api_pb2.Metrics.PTransform return beam_fn_api_pb2.Metrics.PTransform( processed_elements=beam_fn_api_pb2.Metrics.PTransform.ProcessedElements( measured=beam_fn_api_pb2.Metrics.PTransform.Measured( @@ -279,6 +325,7 @@ def progress_metrics(self): user=self.metrics_container.to_runner_api()) def monitoring_infos(self, transform_id): + # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] """Returns the list of MonitoringInfos collected by this operation.""" all_monitoring_infos = self.execution_time_monitoring_infos(transform_id) all_monitoring_infos.update( @@ -328,6 +375,7 @@ def user_monitoring_infos(self, transform_id): return self.metrics_container.to_runner_api_monitoring_infos(transform_id) def execution_time_monitoring_infos(self, transform_id): + # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] total_time_spent_msecs = ( self.scoped_start_state.sampled_msecs_int() + self.scoped_process_state.sampled_msecs_int() @@ -425,6 +473,7 @@ def __init__(self, name_context, counter_factory, state_sampler, next(iter(consumers.values())), output_coder)] def process(self, unused_impulse): + # type: (WindowedValue) -> None with self.scoped_process_state: range_tracker = self.source.get_range_tracker(None, None) for value in self.source.read(range_tracker): @@ -439,6 +488,7 @@ class InMemoryWriteOperation(Operation): """A write operation that will write to an in-memory sink.""" def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: if self.debug_logging_enabled: logging.debug('Processing [%s] in %s', o, self) @@ -461,17 +511,24 @@ def __missing__(self, tag): class DoOperation(Operation): """A Do operation that will execute a custom DoFn for each input element.""" - def __init__( - self, name, spec, counter_factory, sampler, 
side_input_maps=None, - user_state_context=None, timer_inputs=None): + def __init__(self, + name, # type: common.NameContext + spec, # operation_specs.WorkerDoFn # need to fix this type + counter_factory, + sampler, + side_input_maps=None, + user_state_context=None, + timer_inputs=None + ): super(DoOperation, self).__init__(name, spec, counter_factory, sampler) self.side_input_maps = side_input_maps self.user_state_context = user_state_context - self.tagged_receivers = None + self.tagged_receivers = None # type: Optional[_TaggedReceivers] # A mapping of timer tags to the input "PCollections" they come in on. self.timer_inputs = timer_inputs or {} def _read_side_inputs(self, tags_and_types): + # type: (...) -> Iterator[apache_sideinputs.SideInputMap] """Generator reading side inputs in the order prescribed by tags_and_types. Args: @@ -532,6 +589,7 @@ def _read_side_inputs(self, tags_and_types): view_class, view_options, sideinputs.EmulatedIterable(iterator_fn)) def setup(self): + # type: () -> None with self.scoped_start_state: super(DoOperation, self).setup() @@ -551,7 +609,7 @@ def setup(self): output_tag_prefix = PropertyNames.OUT + '_' for index, tag in enumerate(self.spec.output_tags): if tag == PropertyNames.OUT: - original_tag = None + original_tag = None # type: Optional[str] elif tag.startswith(output_tag_prefix): original_tag = tag[len(output_tag_prefix):] else: @@ -585,11 +643,13 @@ def setup(self): else DoFnRunnerReceiver(self.dofn_runner)) def start(self): + # type: () -> None with self.scoped_start_state: super(DoOperation, self).start() self.dofn_runner.start() def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: delayed_application = self.dofn_receiver.receive(o) if delayed_application: @@ -597,9 +657,11 @@ def process(self, o): (self, delayed_application)) def finalize_bundle(self): + # type: () -> None self.dofn_receiver.finalize() def needs_finalization(self): + # type: () -> bool return self.dofn_receiver.bundle_finalizer_param.has_callbacks() def process_timer(self, tag, windowed_timer): @@ -609,16 +671,19 @@ def process_timer(self, tag, windowed_timer): timer_spec, key, windowed_timer.windows[0], timer_data['timestamp']) def finish(self): + # type: () -> None with self.scoped_finish_state: self.dofn_runner.finish() if self.user_state_context: self.user_state_context.commit() def teardown(self): + # type: () -> None with self.scoped_finish_state: self.dofn_runner.teardown() def reset(self): + # type: () -> None super(DoOperation, self).reset() for side_input_map in self.side_input_maps: side_input_map.reset() @@ -627,6 +692,7 @@ def reset(self): self.dofn_receiver.bundle_finalizer_param.reset() def progress_metrics(self): + # type: () -> beam_fn_api_pb2.Metrics.PTransform metrics = super(DoOperation, self).progress_metrics() if self.tagged_receivers: metrics.processed_elements.measured.output_element_counts.clear() @@ -636,6 +702,7 @@ def progress_metrics(self): return metrics def monitoring_infos(self, transform_id): + # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] infos = super(DoOperation, self).monitoring_infos(transform_id) if self.tagged_receivers: for tag, receiver in self.tagged_receivers.items(): @@ -676,6 +743,7 @@ def __init__(self, *args, **kwargs): self.element_start_output_bytes = None def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: try: with self.lock: @@ -707,6 +775,7 @@ def current_element_progress(self): self._total_output_bytes() - self.element_start_output_bytes) 
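For reviewers unfamiliar with the comment-based syntax, here is a minimal, hypothetical sketch (stand-in names, not Beam code) of the idioms the hunks above rely on: the single-line signature comment, the per-argument variant used for long parameter lists, the attribute comment that types an instance variable at first assignment, and the `if TYPE_CHECKING:` guard that keeps checker-only imports out of the runtime import graph.

from typing import TYPE_CHECKING
from typing import List
from typing import Optional

if TYPE_CHECKING:
  # Only seen by the type checker; at runtime this import never executes,
  # which is how annotation-only names can be imported without creating
  # circular imports.  Decimal is an arbitrary stand-in here.
  from decimal import Decimal

class Consumer(object):
  def process(self, value):
    # type: (int) -> None
    print(value)

class Op(object):
  def __init__(self,
               step_name,   # type: str
               consumers,   # type: List[Consumer]
               coder=None   # type: Optional[object]
              ):
    # type: (...) -> None
    self.step_name = step_name
    self.consumers = consumers
    # Attribute comment: gives the instance variable a type even though the
    # initial value is None.
    self.budget = None  # type: Optional[Decimal]

  def output(self, value, output_index=0):
    # type: (int, int) -> None
    self.consumers[output_index].process(value)

Op('read', [Consumer()]).output(42)

Because type comments are never evaluated at runtime, this style works unchanged on Python 2 and on runtime classes (queue.Queue, collections.defaultdict) that cannot be parameterized with [] at runtime.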
def progress_metrics(self): + # type: () -> beam_fn_api_pb2.Metrics.PTransform with self.lock: metrics = super(SdfProcessSizedElements, self).progress_metrics() current_element_progress = self.current_element_progress() @@ -733,6 +802,7 @@ def __init__(self, dofn_runner): self.dofn_runner = dofn_runner def receive(self, windowed_value): + # type: (WindowedValue) -> None self.dofn_runner.process(windowed_value) @@ -750,6 +820,7 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs)) def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: if self.debug_logging_enabled: logging.debug('Processing [%s] in %s', o, self) @@ -786,6 +857,7 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): self.max_size = 10 * 1000 def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: # TODO(robertwb): Structural (hashable) values. key = o.value[0], tuple(o.windows) @@ -843,12 +915,13 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): self.table = {} def process(self, wkv): + # type: (WindowedValue) -> None with self.scoped_process_state: key, value = wkv.value # pylint: disable=unidiomatic-typecheck # Optimization for the global window case. if len(wkv.windows) == 1 and type(wkv.windows[0]) is _global_window_type: - wkey = 0, key + wkey = 0, key # type: Tuple[Hashable, Any] else: wkey = tuple(wkv.windows), key entry = self.table.get(wkey, None) @@ -898,6 +971,7 @@ class FlattenOperation(Operation): """ def process(self, o): + # type: (WindowedValue) -> None with self.scoped_process_state: if self.debug_logging_enabled: logging.debug('Processing [%s] in %s', o, self) @@ -907,6 +981,7 @@ def process(self, o): def create_operation(name_context, spec, counter_factory, step_name=None, state_sampler=None, test_shuffle_source=None, test_shuffle_sink=None, is_streaming=False): + # type: (...) -> Operation """Create Operation object for given operation specification.""" # TODO(pabloem): Document arguments to this function call. 
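One more recurring pattern worth calling out: when a variable is assigned in several branches (wkey above, and op in the create_operation hunk that follows), the checker infers its type from the first assignment, so the patch widens that first assignment with a trailing comment. A small, hypothetical sketch, assuming nothing beyond the standard library:

from typing import Any
from typing import Hashable
from typing import Tuple

def make_window_key(key, windows):
  # type: (Any, Tuple[Any, ...]) -> Tuple[Hashable, Any]
  if len(windows) == 1:
    # Widened here so the tuple-of-windows assignment below also type-checks.
    wkey = 0, key  # type: Tuple[Hashable, Any]
  else:
    wkey = windows, key
  return wkey

print(make_window_key('k', ('w1', 'w2')))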
@@ -916,7 +991,7 @@ def create_operation(name_context, spec, counter_factory, step_name=None, if isinstance(spec, operation_specs.WorkerRead): if isinstance(spec.source, iobase.SourceBundle): op = ReadOperation( - name_context, spec, counter_factory, state_sampler) + name_context, spec, counter_factory, state_sampler) # type: Operation else: from dataflow_worker.native_operations import NativeReadOperation op = NativeReadOperation( @@ -998,12 +1073,13 @@ def __init__( self._map_task = map_task self._counter_factory = counter_factory - self._ops = [] + self._ops = [] # type: List[Operation] self._state_sampler = state_sampler self._test_shuffle_source = test_shuffle_source self._test_shuffle_sink = test_shuffle_sink def operations(self): + # type: () -> List[Operation] return self._ops[:] def execute(self): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 1016ae8275fd..b9bbb6ba901b 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -32,6 +32,14 @@ from builtins import object from builtins import range from concurrent import futures +from typing import TYPE_CHECKING +from typing import Callable +from typing import DefaultDict +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple import grpc from future.utils import raise_ @@ -46,6 +54,10 @@ from apache_beam.runners.worker.statecache import StateCache from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor +if TYPE_CHECKING: + from apache_beam.portability.api import endpoints_pb2 + from apache_beam.utils.profiler import Profile + # This SDK harness will (by default), log a "lull" in processing if it sees no # transitions in over 5 minutes. # 5 minutes * 60 seconds * 1020 millis * 1000 micros * 1000 nanoseconds @@ -56,13 +68,15 @@ class SdkHarness(object): REQUEST_METHOD_PREFIX = '_request_' SCHEDULING_DELAY_THRESHOLD_SEC = 5*60 # 5 Minutes - def __init__( - self, control_address, worker_count, - credentials=None, - worker_id=None, - # Caching is disabled by default - state_cache_size=0, - profiler_factory=None): + def __init__(self, + control_address, # type: str + worker_count, # type: int + credentials=None, + worker_id=None, # type: Optional[str] + # Caching is disabled by default + state_cache_size=0, + profiler_factory=None # type: Optional[Callable[..., Profile]] + ): self._alive = True self._worker_count = worker_count self._worker_index = 0 @@ -85,14 +99,14 @@ def __init__( self._state_handler_factory = GrpcStateHandlerFactory(state_cache_size, credentials) self._profiler_factory = profiler_factory - self._fns = {} + self._fns = {} # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor] # BundleProcessor cache across all workers. self._bundle_processor_cache = BundleProcessorCache( state_handler_factory=self._state_handler_factory, data_channel_factory=self._data_channel_factory, fns=self._fns) # workers for process/finalize bundle. - self.workers = queue.Queue() + self.workers = queue.Queue() # type: queue.Queue[SdkWorker] # one worker for progress/split request. self.progress_worker = SdkWorker(self._bundle_processor_cache, profiler_factory=self._profiler_factory) @@ -105,9 +119,9 @@ def __init__( # finalize and process share one thread pool. 
self._process_thread_pool = futures.ThreadPoolExecutor( max_workers=self._worker_count) - self._responses = queue.Queue() - self._process_bundle_queue = queue.Queue() - self._unscheduled_process_bundle = {} + self._responses = queue.Queue() # type: queue.Queue[beam_fn_api_pb2.InstructionResponse] + self._process_bundle_queue = queue.Queue() # type: queue.Queue[beam_fn_api_pb2.InstructionRequest] + self._unscheduled_process_bundle = {} # type: Dict[str, float] logging.info('Initializing SDKHarness with %s workers.', self._worker_count) def run(self): @@ -129,6 +143,7 @@ def run(self): profiler_factory=self._profiler_factory)) def get_responses(): + # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse] while True: response = self._responses.get() if response is no_more_work: @@ -165,7 +180,11 @@ def get_responses(): self._state_handler_factory.close() logging.info('Done consuming work.') - def _execute(self, task, request): + def _execute(self, + task, # type: Callable[[], beam_fn_api_pb2.InstructionResponse] + request # type: beam_fn_api_pb2.InstructionRequest + ): + # type: (...) -> None try: response = task() except Exception: # pylint: disable=broad-except @@ -179,6 +198,7 @@ def _execute(self, task, request): self._responses.put(response) def _request_register(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None def task(): for process_bundle_descriptor in getattr( @@ -192,6 +212,7 @@ def task(): self._execute(task, request) def _request_process_bundle(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None def task(): # Take the free worker. Wait till a worker is free. @@ -213,12 +234,15 @@ def task(): len(self._process_thread_pool._threads)) # type: ignore # private attr not exposed def _request_process_bundle_split(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None self._request_process_bundle_action(request) def _request_process_bundle_progress(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None self._request_process_bundle_action(request) def _request_process_bundle_action(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None def task(): instruction_id = getattr( @@ -239,6 +263,7 @@ def task(): self._progress_thread_pool.submit(task) def _request_finalize_bundle(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None def task(): # Get one available worker. @@ -253,6 +278,7 @@ def task(): self._process_thread_pool.submit(task) def _monitor_process_bundle(self): + # type: () -> None """ Monitor the unscheduled bundles and log if a bundle is not scheduled for more than SCHEDULING_DELAY_THRESHOLD_SEC. @@ -296,18 +322,29 @@ class BundleProcessorCache(object): performing processing. 
""" - def __init__(self, state_handler_factory, data_channel_factory, fns): + def __init__(self, + state_handler_factory, # type: StateHandlerFactory + data_channel_factory, # type: data_plane.DataChannelFactory + fns # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor] + ): self.fns = fns self.state_handler_factory = state_handler_factory self.data_channel_factory = data_channel_factory - self.active_bundle_processors = {} - self.cached_bundle_processors = collections.defaultdict(list) + self.active_bundle_processors = {} # type: Dict[str, Tuple[str, bundle_processor.BundleProcessor]] + self.cached_bundle_processors = collections.defaultdict(list) # type: DefaultDict[str, List[bundle_processor.BundleProcessor]] def register(self, bundle_descriptor): + # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id.""" self.fns[bundle_descriptor.id] = bundle_descriptor def get(self, instruction_id, bundle_descriptor_id): + # type: (str, str) -> bundle_processor.BundleProcessor + """ + Return the requested ``BundleProcessor``, creating it if necessary. + + Moves the ``BundleProcessor`` from the inactive to the active cache. + """ try: # pop() is threadsafe processor = self.cached_bundle_processors[bundle_descriptor_id].pop() @@ -322,18 +359,36 @@ def get(self, instruction_id, bundle_descriptor_id): return processor def lookup(self, instruction_id): + # type: (str) -> Optional[bundle_processor.BundleProcessor] + """ + Return the requested ``BundleProcessor`` from the cache. + """ return self.active_bundle_processors.get(instruction_id, (None, None))[-1] def discard(self, instruction_id): + # type: (str) -> None + """ + Remove the ``BundleProcessor`` from the cache. + """ self.active_bundle_processors[instruction_id][1].shutdown() del self.active_bundle_processors[instruction_id] def release(self, instruction_id): + # type: (str) -> None + """ + Release the requested ``BundleProcessor``. + + Resets the ``BundleProcessor`` and moves it from the active to the + inactive cache. + """ descriptor_id, processor = self.active_bundle_processors.pop(instruction_id) processor.reset() self.cached_bundle_processors[descriptor_id].append(processor) def shutdown(self): + """ + Shutdown all ``BundleProcessor``s in the cache. + """ for instruction_id in self.active_bundle_processors: self.active_bundle_processors[instruction_id][1].shutdown() del self.active_bundle_processors[instruction_id] @@ -345,15 +400,17 @@ def shutdown(self): class SdkWorker(object): def __init__(self, - bundle_processor_cache, - profiler_factory=None, - log_lull_timeout_ns=None): + bundle_processor_cache, # type: BundleProcessorCache + profiler_factory=None, # type: Optional[Callable[..., Profile]] + log_lull_timeout_ns=None, + ): self.bundle_processor_cache = bundle_processor_cache self.profiler_factory = profiler_factory self.log_lull_timeout_ns = (log_lull_timeout_ns or DEFAULT_LOG_LULL_TIMEOUT_NS) def do_instruction(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse request_type = request.WhichOneof('request') if request_type: # E.g. if register is set, this will call self.register(request.register)) @@ -362,7 +419,11 @@ def do_instruction(self, request): else: raise NotImplementedError - def register(self, request, instruction_id): + def register(self, + request, # type: beam_fn_api_pb2.RegisterRequest + instruction_id # type: str + ): + # type: (...) 
-> beam_fn_api_pb2.InstructionResponse """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s. This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a @@ -376,7 +437,11 @@ def register(self, request, instruction_id): instruction_id=instruction_id, register=beam_fn_api_pb2.RegisterResponse()) - def process_bundle(self, request, instruction_id): + def process_bundle(self, + request, # type: beam_fn_api_pb2.ProcessBundleRequest + instruction_id # type: str + ): + # type: (...) -> beam_fn_api_pb2.InstructionResponse bundle_processor = self.bundle_processor_cache.get( instruction_id, request.process_bundle_descriptor_id) try: @@ -401,7 +466,11 @@ def process_bundle(self, request, instruction_id): self.bundle_processor_cache.discard(instruction_id) raise - def process_bundle_split(self, request, instruction_id): + def process_bundle_split(self, + request, # type: beam_fn_api_pb2.ProcessBundleSplitRequest + instruction_id # type: str + ): + # type: (...) -> beam_fn_api_pb2.InstructionResponse processor = self.bundle_processor_cache.lookup( request.instruction_id) if processor: @@ -437,7 +506,11 @@ def _log_lull_in_bundle_processor(self, processor): logging.warning( '%s%s. Traceback:\n%s', state_lull_log, step_name_log, stack_trace) - def process_bundle_progress(self, request, instruction_id): + def process_bundle_progress(self, + request, # type: beam_fn_api_pb2.ProcessBundleProgressRequest + instruction_id # type: str + ): + # type: (...) -> beam_fn_api_pb2.InstructionResponse # It is an error to get progress for a not-in-flight bundle. processor = self.bundle_processor_cache.lookup(request.instruction_id) if processor: @@ -448,7 +521,11 @@ def process_bundle_progress(self, request, instruction_id): metrics=processor.metrics() if processor else None, monitoring_infos=processor.monitoring_infos() if processor else [])) - def finalize_bundle(self, request, instruction_id): + def finalize_bundle(self, + request, # type: beam_fn_api_pb2.FinalizeBundleRequest + instruction_id # type: str + ): + # type: (...) 
-> beam_fn_api_pb2.InstructionResponse processor = self.bundle_processor_cache.lookup( request.instruction_id) if processor: @@ -503,13 +580,14 @@ class GrpcStateHandlerFactory(StateHandlerFactory): """ def __init__(self, state_cache_size, credentials=None): - self._state_handler_cache = {} + self._state_handler_cache = {} # type: Dict[str, GrpcStateHandler] self._lock = threading.Lock() self._throwing_state_handler = ThrowingStateHandler() self._credentials = credentials self._state_cache = StateCache(state_cache_size) def create_state_handler(self, api_service_descriptor): + # type: (endpoints_pb2.ApiServiceDescriptor) -> GrpcStateHandler if not api_service_descriptor: return self._throwing_state_handler url = api_service_descriptor.url @@ -571,10 +649,11 @@ class GrpcStateHandler(object): _DONE = object() def __init__(self, state_stub): + # type: (beam_fn_api_pb2_grpc.BeamFnStateStub) -> None self._lock = threading.Lock() self._state_stub = state_stub - self._requests = queue.Queue() - self._responses_by_id = {} + self._requests = queue.Queue() # type: queue.Queue[beam_fn_api_pb2.StateRequest] + self._responses_by_id = {} # type: Dict[str, _Future] self._last_id = 0 self._exc_info = None self._context = threading.local() @@ -623,7 +702,11 @@ def done(self): self._done = True self._requests.put(self._DONE) - def get_raw(self, state_key, continuation_token=None): + def get_raw(self, + state_key, # type: beam_fn_api_pb2.StateKey + continuation_token=None # type: Optional[bytes] + ): + # type: (...) -> Tuple[bytes, Optional[bytes]] response = self._blocking_request( beam_fn_api_pb2.StateRequest( state_key=state_key, @@ -631,19 +714,25 @@ def get_raw(self, state_key, continuation_token=None): continuation_token=continuation_token))) return response.get.data, response.get.continuation_token - def append_raw(self, state_key, data): + def append_raw(self, + state_key, # type: Optional[beam_fn_api_pb2.StateKey] + data # type: bytes + ): + # type: (...) -> _Future return self._request( beam_fn_api_pb2.StateRequest( state_key=state_key, append=beam_fn_api_pb2.StateAppendRequest(data=data))) def clear(self, state_key): + # type: (Optional[beam_fn_api_pb2.StateKey]) -> _Future return self._request( beam_fn_api_pb2.StateRequest( state_key=state_key, clear=beam_fn_api_pb2.StateClearRequest())) def _request(self, request): + # type: (beam_fn_api_pb2.StateRequest) -> _Future request.id = self._next_id() request.instruction_id = self._context.process_instruction_id # Adding a new item to a dictionary is atomic in cPython @@ -667,6 +756,7 @@ def _blocking_request(self, request): return response def _next_id(self): + # type: () -> str with self._lock: # Use a lock here because this GrpcStateHandler is shared across all # requests which have the same process bundle descriptor. State requests @@ -728,7 +818,13 @@ def blocking_get(self, state_key, coder, is_cached=False): materialized) return iter(cached_value) - def extend(self, state_key, coder, elements, is_cached=False): + def extend(self, + state_key, # type: beam_fn_api_pb2.StateKey + coder, # type: coder_impl.CoderImpl + elements, # type: Iterable[Any] + is_cached=False + ): + # type: (...) 
-> _Future if self._should_be_cached(is_cached): # Update the cache cache_key = self._convert_to_cache_key(state_key) @@ -746,6 +842,7 @@ def clear(self, state_key, is_cached=False): return self._underlying.clear(state_key) def done(self): + # type: () -> None self._underlying.done() def _materialize_iter(self, state_key, coder): @@ -794,8 +891,9 @@ def set(self, value): @classmethod def done(cls): + # type: () -> _Future if not hasattr(cls, 'DONE'): done_future = _Future() done_future.set(None) - cls.DONE = done_future - return cls.DONE + cls.DONE = done_future # type: ignore[attr-defined] + return cls.DONE # type: ignore[attr-defined] diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index 528d040eb954..6a3f998f3beb 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -20,7 +20,11 @@ from __future__ import absolute_import import threading -from collections import namedtuple +from typing import TYPE_CHECKING +from typing import Dict +from typing import NamedTuple +from typing import Optional +from typing import Union from apache_beam.runners import common from apache_beam.utils.counters import Counter @@ -34,6 +38,8 @@ from apache_beam.runners.worker import statesampler_slow as statesampler_impl FAST_SAMPLER = False +if TYPE_CHECKING: + from apache_beam.metrics.execution import MetricsContainer _STATE_SAMPLERS = threading.local() @@ -54,12 +60,12 @@ def for_test(): return get_current_tracker() -StateSamplerInfo = namedtuple( +StateSamplerInfo = NamedTuple( 'StateSamplerInfo', - ['state_name', - 'transition_count', - 'time_since_transition', - 'tracked_thread']) + [('state_name', CounterName), + ('transition_count', int), + ('time_since_transition', int), + ('tracked_thread', Optional[threading.Thread])]) # Default period for sampling current state of pipeline execution. 
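The StateSamplerInfo change above swaps collections.namedtuple for the functional form of typing.NamedTuple, so each field carries a declared type while runtime behaviour stays identical. A hypothetical, self-contained sketch of the same conversion (names are illustrative, not Beam's):

import threading
from typing import NamedTuple
from typing import Optional

# Before: SamplerInfo = collections.namedtuple(
#     'SamplerInfo', ['state_name', 'transition_count', 'tracked_thread'])
SamplerInfo = NamedTuple(
    'SamplerInfo',
    [('state_name', str),
     ('transition_count', int),
     ('tracked_thread', Optional[threading.Thread])])

info = SamplerInfo('process', 3, threading.current_thread())
print(info.state_name, info.transition_count)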
@@ -68,37 +74,43 @@ def for_test(): class StateSampler(statesampler_impl.StateSampler): - def __init__(self, prefix, counter_factory, + def __init__(self, + prefix, # type: str + counter_factory, sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS): - self.states_by_name = {} self._prefix = prefix self._counter_factory = counter_factory - self._states_by_name = {} + self._states_by_name = {} # type: Dict[CounterName, statesampler_impl.ScopedState] self.sampling_period_ms = sampling_period_ms - self.tracked_thread = None + self.tracked_thread = None # type: Optional[threading.Thread] self.finished = False self.started = False super(StateSampler, self).__init__(sampling_period_ms) @property def stage_name(self): + # type: () -> str return self._prefix def stop(self): + # type: () -> None set_current_tracker(None) super(StateSampler, self).stop() def stop_if_still_running(self): + # type: () -> None if self.started and not self.finished: self.stop() def start(self): + # type: () -> None self.tracked_thread = threading.current_thread() set_current_tracker(self) super(StateSampler, self).start() self.started = True def get_info(self): + # type: () -> StateSamplerInfo """Returns StateSamplerInfo with transition statistics.""" return StateSamplerInfo( self.current_state().name, @@ -107,10 +119,12 @@ def get_info(self): self.tracked_thread) def scoped_state(self, - name_context, - state_name, + name_context, # type: Union[str, common.NameContext] + state_name, # type: str io_target=None, - metrics_container=None): + metrics_container=None # type: Optional[MetricsContainer] + ): + # type: (...) -> statesampler_impl.ScopedState """Returns a ScopedState object associated to a Step and a State. Args: @@ -143,6 +157,7 @@ def scoped_state(self, return self._states_by_name[counter_name] def commit_counters(self): + # type: () -> None """Updates output counters with latest state statistics.""" for state in self._states_by_name.values(): state_msecs = int(1e-6 * state.nsecs) diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index 00918285aee3..17bda0009e4d 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -20,6 +20,7 @@ from __future__ import absolute_import from builtins import object +from typing import Optional from apache_beam.runners import common from apache_beam.utils import counters @@ -35,6 +36,7 @@ def __init__(self, sampling_period_ms): self.time_since_transition = 0 def current_state(self): + # type: () -> ScopedState """Returns the current execution state. This operation is not thread safe, and should only be called from the @@ -42,38 +44,48 @@ def current_state(self): return self._state_stack[-1] def _scoped_state(self, - counter_name, - name_context, + counter_name, # type: counters.CounterName + name_context, # type: common.NameContext output_counter, metrics_container=None): + # type: (...) -> ScopedState assert isinstance(name_context, common.NameContext) return ScopedState( self, counter_name, name_context, output_counter, metrics_container) def _enter_state(self, state): + # type: (ScopedState) -> None self.state_transition_count += 1 self._state_stack.append(state) def _exit_state(self): + # type: () -> None self.state_transition_count += 1 self._state_stack.pop() def start(self): + # type: () -> None # Sampling not yet supported. Only state tracking at the moment. 
pass def stop(self): + # type: () -> None pass def reset(self): + # type: () -> None for state in self._states_by_name.values(): state.nsecs = 0 class ScopedState(object): - def __init__(self, sampler, name, step_name_context, - counter=None, metrics_container=None): + def __init__(self, + sampler, # type: StateSampler + name, # type: counters.CounterName + step_name_context, # type: Optional[common.NameContext] + counter=None, + metrics_container=None): self.state_sampler = sampler self.name = name self.name_context = step_name_context @@ -82,9 +94,11 @@ def __init__(self, sampler, name, step_name_context, self.metrics_container = metrics_container def sampled_seconds(self): + # type: () -> float return 1e-9 * self.nsecs def sampled_msecs_int(self): + # type: () -> int return int(1e-6 * self.nsecs) def __repr__(self): diff --git a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py index 6c9a605c3a5c..be24eb465167 100644 --- a/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py +++ b/sdks/python/apache_beam/runners/worker/worker_id_interceptor.py @@ -21,6 +21,7 @@ import collections import os +from typing import Optional import grpc @@ -41,6 +42,7 @@ class WorkerIdInterceptor(grpc.StreamStreamClientInterceptor): _worker_id = os.environ.get('WORKER_ID') def __init__(self, worker_id=None): + # type: (Optional[str]) -> None if worker_id: self._worker_id = worker_id diff --git a/sdks/python/apache_beam/runners/worker/worker_pool_main.py b/sdks/python/apache_beam/runners/worker/worker_pool_main.py index 11db233305e6..7656f2097a3f 100644 --- a/sdks/python/apache_beam/runners/worker/worker_pool_main.py +++ b/sdks/python/apache_beam/runners/worker/worker_pool_main.py @@ -36,6 +36,9 @@ import threading import time from concurrent import futures +from typing import Dict +from typing import Optional +from typing import Tuple import grpc @@ -47,19 +50,27 @@ class BeamFnExternalWorkerPoolServicer( beam_fn_api_pb2_grpc.BeamFnExternalWorkerPoolServicer): - def __init__(self, worker_threads, + def __init__(self, + worker_threads, # type: int use_process=False, - container_executable=None, - state_cache_size=0): + container_executable=None, # type: Optional[str] + state_cache_size=0 + ): self._worker_threads = worker_threads self._use_process = use_process self._container_executable = container_executable self._state_cache_size = state_cache_size - self._worker_processes = {} + self._worker_processes = {} # type: Dict[str, subprocess.Popen] @classmethod - def start(cls, worker_threads=1, use_process=False, port=0, - state_cache_size=0, container_executable=None): + def start(cls, + worker_threads=1, + use_process=False, + port=0, + state_cache_size=0, + container_executable=None # type: Optional[str] + ): + # type: (...) -> Tuple[str, grpc.Server] worker_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) worker_address = 'localhost:%s' % worker_server.add_insecure_port( '[::]:%s' % port) @@ -80,7 +91,11 @@ def kill_worker_processes(): return worker_address, worker_server - def StartWorker(self, start_worker_request, unused_context): + def StartWorker(self, + start_worker_request, # type: beam_fn_api_pb2.StartWorkerRequest + unused_context + ): + # type: (...) 
-> beam_fn_api_pb2.StartWorkerResponse try: if self._use_process: command = ['python', '-c', @@ -133,7 +148,11 @@ def StartWorker(self, start_worker_request, unused_context): except Exception as exn: return beam_fn_api_pb2.StartWorkerResponse(error=str(exn)) - def StopWorker(self, stop_worker_request, unused_context): + def StopWorker(self, + stop_worker_request, # type: beam_fn_api_pb2.StopWorkerRequest + unused_context + ): + # type: (...) -> beam_fn_api_pb2.StopWorkerResponse # applicable for process mode to ensure process cleanup # thread based workers terminate automatically worker_process = self._worker_processes.pop(stop_worker_request.worker_id, diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index ca3b3af2f633..67675123befb 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -32,6 +32,7 @@ import logging import time import uuid +from typing import List import apache_beam as beam from apache_beam.metrics import Metrics @@ -168,7 +169,7 @@ class MetricsReader(object): A :class:`MetricsReader` retrieves metrics from pipeline result, prepares it for publishers and setup publishers. """ - publishers = [] + publishers = [] # type: List[ConsoleMetricsPublisher] def __init__(self, project_name=None, bq_table=None, bq_dataset=None, filters=None): diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py index 02a860749c4b..7e6065695960 100644 --- a/sdks/python/apache_beam/testing/test_stream.py +++ b/sdks/python/apache_beam/testing/test_stream.py @@ -29,9 +29,9 @@ from future.utils import with_metaclass from apache_beam import coders -from apache_beam import core from apache_beam import pvalue from apache_beam.transforms import PTransform +from apache_beam.transforms import core from apache_beam.transforms import window from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index a3696b836030..1f674c13b920 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -219,6 +219,8 @@ def _py3__init__(self, n, **kwargs): self._py2__init__(n, None, **kwargs) # Python 3 sort does not accept a comparison operator, and nor do we. 
+ # FIXME: mypy would handle this better if we placed the _py*__init__ funcs + # inside the if/else block below: if sys.version_info[0] < 3: __init__ = _py2__init__ else: @@ -309,7 +311,7 @@ def _py3__init__(self, n, **kwargs): if sys.version_info[0] < 3: __init__ = _py2__init__ else: - __init__ = _py3__init__ + __init__ = _py3__init__ # type: ignore def default_label(self): return 'TopPerKey(%d)' % self._n diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 2ebe5575ceb9..b60f461d7328 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -65,12 +65,21 @@ from apache_beam.typehints.typehints import is_consistent_with from apache_beam.utils import urns +if typing.TYPE_CHECKING: + from google.protobuf import message # pylint: disable=ungrouped-imports + from apache_beam.io import iobase + from apache_beam.pipeline import Pipeline + from apache_beam.runners.pipeline_context import PipelineContext + from apache_beam.transforms import create_source + from apache_beam.transforms.trigger import AccumulationMode + from apache_beam.transforms.trigger import DefaultTrigger + from apache_beam.transforms.trigger import TriggerFn + try: import funcsigs # Python 2 only. except ImportError: funcsigs = None - __all__ = [ 'DoFn', 'CombineFn', @@ -240,6 +249,7 @@ class RestrictionProvider(object): """ def create_tracker(self, restriction): + # type: (...) -> iobase.RestrictionTracker """Produces a new ``RestrictionTracker`` for the given restriction. Args: @@ -292,6 +302,7 @@ def split_and_size(self, element, restriction): def get_function_arguments(obj, func): + # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]] """Return the function arguments based on the name provided. If they have a _inspect_function attached to the class then use that otherwise default to the modified version of python inspect library. @@ -308,6 +319,7 @@ def get_function_arguments(obj, func): def get_function_args_defaults(f): + # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]] """Returns the function arguments of a given function. 
Returns: @@ -405,6 +417,7 @@ class _RestrictionDoFnParam(_DoFnParam): """Restriction Provider DoFn parameter.""" def __init__(self, restriction_provider): + # type: (RestrictionProvider) -> None if not isinstance(restriction_provider, RestrictionProvider): raise ValueError( 'DoFn.RestrictionParam expected RestrictionProvider object.') @@ -417,6 +430,7 @@ class _StateDoFnParam(_DoFnParam): """State DoFn parameter.""" def __init__(self, state_spec): + # type: (StateSpec) -> None if not isinstance(state_spec, StateSpec): raise ValueError("DoFn.StateParam expected StateSpec object.") self.state_spec = state_spec @@ -427,6 +441,7 @@ class _TimerDoFnParam(_DoFnParam): """Timer DoFn parameter.""" def __init__(self, timer_spec): + # type: (TimerSpec) -> None if not isinstance(timer_spec, TimerSpec): raise ValueError("DoFn.TimerParam expected TimerSpec object.") self.timer_spec = timer_spec @@ -845,6 +860,7 @@ def from_callable(fn): @staticmethod def maybe_from_callable(fn, has_side_inputs=True): + # type: (typing.Union[CombineFn, typing.Callable], bool) -> CombineFn if isinstance(fn, CombineFn): return fn elif callable(fn) and not has_side_inputs: @@ -1030,6 +1046,7 @@ def default_label(self): return self.__class__.__name__ def partition_for(self, element, num_partitions, *args, **kwargs): + # type: (T, int, *typing.Any, **typing.Any) -> int """Specify which partition will receive this element. Args: @@ -1069,6 +1086,7 @@ def __init__(self, fn): self._fn = fn def partition_for(self, element, num_partitions, *args, **kwargs): + # type: (T, int, *typing.Any, **typing.Any) -> int return self._fn(element, num_partitions, *args, **kwargs) @@ -1108,7 +1126,7 @@ def __init__(self, fn, *args, **kwargs): super(ParDo, self).__init__(fn, *args, **kwargs) # TODO(robertwb): Change all uses of the dofn attribute to use fn instead. self.dofn = self.fn - self.output_tags = set() + self.output_tags = set() # type: typing.Set[str] if not isinstance(self.fn, DoFn): raise TypeError('ParDo must be called with a DoFn instance.') @@ -1203,6 +1221,7 @@ def _pardo_fn_data(self): return self.fn, self.args, self.kwargs, si_tags_and_types, windowing def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> typing.Tuple[str, message.Message] assert isinstance(self, ParDo), \ "expected instance of ParDo, but got %s" % self.__class__ picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data()) @@ -1212,7 +1231,7 @@ def to_runner_api_parameter(self, context): if is_splittable: restriction_coder = ( DoFnSignature(self.fn).get_restriction_provider().restriction_coder()) - restriction_coder_id = context.coders.get_id(restriction_coder) + restriction_coder_id = context.coders.get_id(restriction_coder) # type: typing.Optional[str] else: restriction_coder_id = None return ( @@ -1795,7 +1814,10 @@ def default_type_hints(self): hints.set_output_types(typehints.Tuple[K, main_output_type]) return hints - def to_runner_api_parameter(self, context): + def to_runner_api_parameter(self, + context # type: PipelineContext + ): + # type: (...) 
-> typing.Tuple[str, beam_runner_api_pb2.CombinePayload] if self.args or self.kwargs: from apache_beam.transforms.combiners import curry_combine_fn combine_fn = curry_combine_fn(self.fn, self.args, self.kwargs) @@ -1858,7 +1880,11 @@ def from_runner_api_parameter(combine_payload, context): class CombineValuesDoFn(DoFn): """DoFn for performing per-key Combine transforms.""" - def __init__(self, input_pcoll_type, combinefn, runtime_type_check): + def __init__(self, + input_pcoll_type, + combinefn, # type: CombineFn + runtime_type_check, # type: bool + ): super(CombineValuesDoFn, self).__init__() self.combinefn = combinefn self.runtime_type_check = runtime_type_check @@ -1910,7 +1936,11 @@ def default_type_hints(self): class _CombinePerKeyWithHotKeyFanout(PTransform): - def __init__(self, combine_fn, fanout): + def __init__(self, + combine_fn, # type: CombineFn + fanout, # type: typing.Union[int, typing.Callable[[typing.Any], int]] + ): + # type: (...) -> None self._combine_fn = combine_fn self._fanout_fn = ( (lambda key: fanout) if isinstance(fanout, int) else fanout) @@ -2064,6 +2094,7 @@ def infer_output_type(self, input_type): return typehints.KV[key_type, typehints.Iterable[value_type]] def to_runner_api_parameter(self, unused_context): + # type: (PipelineContext) -> typing.Tuple[str, None] return common_urns.primitives.GROUP_BY_KEY.urn, None @PTransform.register_urn(common_urns.primitives.GROUP_BY_KEY.urn, None) @@ -2167,8 +2198,12 @@ def expand(self, pcoll): class Windowing(object): - def __init__(self, windowfn, triggerfn=None, accumulation_mode=None, - timestamp_combiner=None): + def __init__(self, + windowfn, # type: WindowFn + triggerfn=None, # type: typing.Optional[TriggerFn] + accumulation_mode=None, + timestamp_combiner=None, + ): global AccumulationMode, DefaultTrigger # pylint: disable=global-variable-not-assigned # pylint: disable=wrong-import-order, wrong-import-position from apache_beam.transforms.trigger import AccumulationMode, DefaultTrigger @@ -2224,6 +2259,7 @@ def is_default(self): return self._is_default def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.WindowingStrategy return beam_runner_api_pb2.WindowingStrategy( window_fn=self.windowfn.to_runner_api(context), # TODO(robertwb): Prohibit implicit multi-level merging. @@ -2266,6 +2302,7 @@ class WindowIntoFn(DoFn): """A DoFn that applies a WindowInto operation.""" def __init__(self, windowing): + # type: (Windowing) -> None self.windowing = windowing def process(self, element, timestamp=DoFn.TimestampParam, @@ -2276,10 +2313,11 @@ def process(self, element, timestamp=DoFn.TimestampParam, yield WindowedValue(element, context.timestamp, new_windows) def __init__(self, - windowfn, - trigger=None, + windowfn, # type: typing.Union[Windowing, WindowFn] + trigger=None, # type: typing.Optional[TriggerFn] accumulation_mode=None, - timestamp_combiner=None): + timestamp_combiner=None + ): """Initializes a WindowInto transform. 
Args: @@ -2305,6 +2343,7 @@ def __init__(self, super(WindowInto, self).__init__(self.WindowIntoFn(self.windowing)) def get_windowing(self, unused_inputs): + # type: (typing.Any) -> Windowing return self.windowing def infer_output_type(self, input_type): @@ -2320,6 +2359,7 @@ def expand(self, pcoll): return super(WindowInto, self).expand(pcoll) def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> typing.Tuple[str, message.Message] return ( common_urns.primitives.ASSIGN_WINDOWS.urn, self.windowing.to_runner_api(context)) @@ -2366,7 +2406,7 @@ class Flatten(PTransform): def __init__(self, **kwargs): super(Flatten, self).__init__() - self.pipeline = kwargs.pop('pipeline', None) + self.pipeline = kwargs.pop('pipeline', None) # type: typing.Optional[Pipeline] if kwargs: raise ValueError('Unexpected keyword arguments: %s' % list(kwargs)) @@ -2388,12 +2428,14 @@ def expand(self, pcolls): return result def get_windowing(self, inputs): + # type: (typing.Any) -> Windowing if not inputs: # TODO(robertwb): Return something compatible with every windowing? return Windowing(GlobalWindows()) return super(Flatten, self).get_windowing(inputs) def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> typing.Tuple[str, None] return common_urns.primitives.FLATTEN.urn, None @staticmethod @@ -2424,6 +2466,7 @@ def __init__(self, values, reshuffle=True): self.reshuffle = reshuffle def to_runner_api_parameter(self, context): + # type: (PipelineContext) -> typing.Tuple[str, bytes] # Required as this is identified by type in PTransformOverrides. # TODO(BEAM-3812): Use an actual URN here. return self.to_runner_api_pickled(context) @@ -2474,6 +2517,7 @@ def expand(self, pcoll): | iobase.Read(source).with_output_types(self.get_output_type())) def get_windowing(self, unused_inputs): + # type: (typing.Any) -> Windowing return Windowing(GlobalWindows()) @staticmethod @@ -2482,6 +2526,7 @@ def _create_source_from_iterable(values, coder): @staticmethod def _create_source(serialized_values, coder): + # type: (typing.Any, typing.Any) -> create_source._CreateSource from apache_beam.transforms.create_source import _CreateSource return _CreateSource(serialized_values, coder) @@ -2497,12 +2542,14 @@ def expand(self, pbegin): return pvalue.PCollection(pbegin.pipeline) def get_windowing(self, inputs): + # type: (typing.Any) -> Windowing return Windowing(GlobalWindows()) def infer_output_type(self, unused_input_type): return bytes def to_runner_api_parameter(self, unused_context): + # type: (PipelineContext) -> typing.Tuple[str, None] return common_urns.primitives.IMPULSE.urn, None @PTransform.register_urn(common_urns.primitives.IMPULSE.urn, None) diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index bcbf68e8c051..8a24fe6d1895 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -44,9 +44,14 @@ from builtins import object from datetime import datetime from datetime import timedelta +from typing import TYPE_CHECKING +from typing import List from past.builtins import unicode +if TYPE_CHECKING: + from apache_beam.options.pipeline_options import PipelineOptions + __all__ = ['HasDisplayData', 'DisplayDataItem', 'DisplayData'] @@ -57,6 +62,7 @@ class HasDisplayData(object): """ def display_data(self): + # type: () -> dict """ Returns the display data associated to a pipeline component. 
It should be reimplemented in pipeline components that wish to have @@ -80,6 +86,7 @@ def display_data(self): return {} def _namespace(self): + # type: () -> str return '{}.{}'.format(self.__module__, self.__class__.__name__) @@ -87,9 +94,13 @@ class DisplayData(object): """ Static display data associated with a pipeline component. """ - def __init__(self, namespace, display_data_dict): + def __init__(self, + namespace, # type: str + display_data_dict # type: dict + ): + # type: (...) -> None self.namespace = namespace - self.items = [] + self.items = [] # type: List[DisplayDataItem] self._populate_items(display_data_dict) def _populate_items(self, display_data_dict): @@ -191,6 +202,7 @@ def __init__(self, value, url=None, label=None, self._drop_if_default = False def drop_if_none(self): + # type: () -> DisplayDataItem """ The item should be dropped if its value is None. Returns: @@ -200,6 +212,7 @@ def drop_if_none(self): return self def drop_if_default(self, default): + # type: (...) -> DisplayDataItem """ The item should be dropped if its value is equal to its default. Returns: @@ -210,6 +223,7 @@ def drop_if_default(self, default): return self def should_drop(self): + # type: () -> bool """ Return True if the item should be dropped, or False if it should not be dropped. This depends on the drop_if_none, and drop_if_default calls. @@ -223,6 +237,7 @@ def should_drop(self): return False def is_valid(self): + # type: () -> None """ Checks that all the necessary fields of the :class:`DisplayDataItem` are filled in. It checks that neither key, namespace, value or type are :data:`None`. @@ -261,6 +276,7 @@ def _get_dict(self): return res def get_dict(self): + # type: () -> dict """ Returns the internal-API dictionary representing the :class:`DisplayDataItem`. 
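In the display.py hunks above, drop_if_none and drop_if_default are annotated as returning DisplayDataItem itself. Because type comments are never evaluated at runtime, the class can be named inside its own body without quoting or a __future__ import. A hypothetical sketch of the same fluent-method pattern:

class Item(object):
  def __init__(self, value):
    # type: (object) -> None
    self.value = value
    self._drop_if_none = False

  def drop_if_none(self):
    # type: () -> Item
    # Returning self keeps the fluent style while giving callers a typed value.
    self._drop_if_none = True
    return self

print(Item(None).drop_if_none()._drop_if_none)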
diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index e8c7e3f4a924..9714bc6580b8 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -25,6 +25,7 @@ import contextlib import copy import threading +from typing import Dict from apache_beam import pvalue from apache_beam.coders import registry @@ -247,6 +248,8 @@ def __init__(self, urn, payload, expansion_service=None): else payload) self._expansion_service = expansion_service self._namespace = self._fresh_namespace() + self._inputs = {} # type: Dict[str, pvalue.PCollection] + self._output = {} # type: Dict[str, pvalue.PCollection] def __post_init__(self, expansion_service): """ @@ -276,10 +279,12 @@ def outer_namespace(cls, namespace): @classmethod def _fresh_namespace(cls): + # type: () -> str ExternalTransform._namespace_counter += 1 return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter) def expand(self, pvalueish): + # type: (pvalue.PCollection) -> pvalue.PCollection if isinstance(pvalueish, pvalue.PBegin): self._inputs = {} elif isinstance(pvalueish, (list, tuple)): diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 3cbaef2e230d..1459ff4c50f3 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -48,6 +48,17 @@ class and wrapper class that allows lambda functions to be used as from builtins import zip from functools import reduce from functools import wraps +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import overload from google.protobuf import message @@ -68,12 +79,25 @@ class and wrapper class that allows lambda functions to be used as from apache_beam.typehints.typehints import validate_composite_type_param from apache_beam.utils import proto_utils +if TYPE_CHECKING: + from apache_beam import coders + from apache_beam.pipeline import Pipeline + from apache_beam.runners.pipeline_context import PipelineContext + from apache_beam.transforms.core import Windowing + from apache_beam.portability.api import beam_runner_api_pb2 + __all__ = [ 'PTransform', 'ptransform_fn', 'label_from_callable', ] +T = TypeVar('T') +PTransformT = TypeVar('PTransformT', bound='PTransform') +ConstructorFn = Callable[ + [Optional[Any], 'PipelineContext'], + Any] + class _PValueishTransform(object): """Visitor for PValueish objects. @@ -307,27 +331,31 @@ class PTransform(WithTypeHints, HasDisplayData): with input as an argument. """ # By default, transforms don't have any side inputs. - side_inputs = () + side_inputs = () # type: Sequence[pvalue.AsSideInput] # Used for nullary transforms. - pipeline = None + pipeline = None # type: Optional[Pipeline] # Default is unset. 
- _user_label = None + _user_label = None # type: Optional[str] def __init__(self, label=None): + # type: (Optional[str]) -> None super(PTransform, self).__init__() self.label = label # type: ignore # https://github.com/python/mypy/issues/3004 @property def label(self): + # type: () -> str return self._user_label or self.default_label() @label.setter def label(self, value): + # type: (Optional[str]) -> None self._user_label = value def default_label(self): + # type: () -> str return self.__class__.__name__ def with_input_types(self, input_type_hint): @@ -409,6 +437,7 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): pvalue_.element_type)) def _infer_output_coder(self, input_type=None, input_coder=None): + # type: (...) -> Optional[coders.Coder] """Returns the output coder to use for output of this transform. Note: this API is experimental and is subject to change; please do not rely @@ -454,12 +483,14 @@ def _str_internal(self): ' side_inputs=%s' % str(self.side_inputs) if self.side_inputs else '') def _check_pcollection(self, pcoll): + # type: (pvalue.PCollection) -> None if not isinstance(pcoll, pvalue.PCollection): raise error.TransformError('Expecting a PCollection argument.') if not pcoll.pipeline: raise error.TransformError('PCollection not part of a pipeline.') def get_windowing(self, inputs): + # type: (Any) -> Windowing """Returns the window function to be associated with transform's output. By default most transforms just return the windowing function associated @@ -555,7 +586,45 @@ def _pvaluish_from_dict(self, input_dict): else: return input_dict - _known_urns = {} + _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Type[T] + ): + # type: (...) -> Callable[[Union[type, Callable[[T, PipelineContext], Any]]], Callable[[T, PipelineContext], Any]] + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: None + ): + # type: (...) -> Callable[[Union[type, Callable[[bytes, PipelineContext], Any]]], Callable[[bytes, PipelineContext], Any]] + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Type[T] + constructor # type: Callable[[T, PipelineContext], Any] + ): + # type: (...) -> None + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: None + constructor # type: Callable[[bytes, PipelineContext], Any] + ): + # type: (...) -> None + pass @classmethod def register_urn(cls, urn, parameter_type, constructor=None): @@ -588,6 +657,7 @@ def fake_static_method(): return register def to_runner_api(self, context, has_parts=False): + # type: (PipelineContext, bool) -> beam_runner_api_pb2.FunctionSpec from apache_beam.portability.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) if urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts: @@ -601,7 +671,11 @@ def to_runner_api(self, context, has_parts=False): else typed_param) @classmethod - def from_runner_api(cls, proto, context): + def from_runner_api(cls, + proto, # type: Optional[beam_runner_api_pb2.FunctionSpec] + context # type: PipelineContext + ): + # type: (...) 
-> Optional[PTransform] if proto is None or not proto.urn: return None parameter_type, constructor = cls._known_urns[proto.urn] @@ -618,12 +692,16 @@ def from_runner_api(cls, proto, context): return RunnerAPIPTransformHolder(proto, context) raise - def to_runner_api_parameter(self, unused_context): + def to_runner_api_parameter(self, + unused_context # type: PipelineContext + ): + # type: (...) -> Tuple[str, Optional[Union[message.Message, bytes, str]]] # The payload here is just to ease debugging. return (python_urns.GENERIC_COMPOSITE_TRANSFORM, getattr(self, '_fn_api_payload', str(self))) def to_runner_api_pickled(self, unused_context): + # type: (PipelineContext) -> Tuple[str, bytes] return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self)) @@ -646,6 +724,7 @@ def _unpickle_transform(pickled_bytes, unused_context): class _ChainedPTransform(PTransform): def __init__(self, *parts): + # type: (*PTransform) -> None super(_ChainedPTransform, self).__init__(label=self._chain_label(parts)) self._parts = parts @@ -677,6 +756,7 @@ class PTransformWithSideInputs(PTransform): """ def __init__(self, fn, *args, **kwargs): + # type: (WithTypeHints, *Any, **Any) -> None if isinstance(fn, type) and issubclass(fn, WithTypeHints): # Don't treat Fn class objects as callables. raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__)) diff --git a/sdks/python/apache_beam/transforms/sideinputs.py b/sdks/python/apache_beam/transforms/sideinputs.py index 21fc919b72d1..cf63b728dc49 100644 --- a/sdks/python/apache_beam/transforms/sideinputs.py +++ b/sdks/python/apache_beam/transforms/sideinputs.py @@ -27,20 +27,31 @@ from __future__ import absolute_import from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Dict from apache_beam.transforms import window +if TYPE_CHECKING: + from apache_beam import pvalue + +WindowMappingFn = Callable[[window.BoundedWindow], window.BoundedWindow] # Top-level function so we can identify it later. def _global_window_mapping_fn(w, global_window=window.GlobalWindow()): + # type: (...) 
-> window.GlobalWindow return global_window def default_window_mapping_fn(target_window_fn): + # type: (window.WindowFn) -> WindowMappingFn if target_window_fn == window.GlobalWindows(): return _global_window_mapping_fn def map_via_end(source_window): + # type: (window.BoundedWindow) -> window.BoundedWindow return list(target_window_fn.assign( window.WindowFn.AssignContext(source_window.max_timestamp())))[-1] @@ -50,15 +61,20 @@ def map_via_end(source_window): class SideInputMap(object): """Represents a mapping of windows to side input values.""" - def __init__(self, view_class, view_options, iterable): + def __init__(self, + view_class, # type: pvalue.AsSideInput + view_options, + iterable + ): self._window_mapping_fn = view_options.get( 'window_mapping_fn', _global_window_mapping_fn) self._view_class = view_class self._view_options = view_options self._iterable = iterable - self._cache = {} + self._cache = {} # type: Dict[window.BoundedWindow, Any] def __getitem__(self, window): + # type: (window.BoundedWindow) -> Any if window not in self._cache: target_window = self._window_mapping_fn(window) self._cache[window] = self._view_class._from_runtime_iterable( @@ -66,6 +82,7 @@ def __getitem__(self, window): return self._cache[window] def is_globally_windowed(self): + # type: () -> bool return self._window_mapping_fn == _global_window_mapping_fn diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index 4d7126e9597f..83caee122e64 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -24,12 +24,25 @@ import types from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TypeVar from apache_beam.coders import Coder from apache_beam.coders import coders from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.transforms.timeutil import TimeDomain +if TYPE_CHECKING: + from apache_beam.runners.pipeline_context import PipelineContext + from apache_beam.transforms.core import CombineFn + +CallableT = TypeVar('CallableT', bound=Callable) + class StateSpec(object): """Specification for a user DoFn state cell.""" @@ -48,12 +61,14 @@ class BagStateSpec(StateSpec): """Specification for a user DoFn bag state cell.""" def __init__(self, name, coder): + # type: (str, Coder) -> None assert isinstance(name, str) assert isinstance(coder, Coder) self.name = name self.coder = coder def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec return beam_runner_api_pb2.StateSpec( bag_spec=beam_runner_api_pb2.BagStateSpec( element_coder_id=context.coders.get_id(self.coder))) @@ -63,6 +78,7 @@ class SetStateSpec(StateSpec): """Specification for a user DoFn Set State cell""" def __init__(self, name, coder): + # type: (str, Coder) -> None if not isinstance(name, str): raise TypeError("SetState name is not a string") if not isinstance(coder, Coder): @@ -80,6 +96,7 @@ class CombiningValueStateSpec(StateSpec): """Specification for a user DoFn combining value state cell.""" def __init__(self, name, coder=None, combine_fn=None): + # type: (str, Optional[Coder], Any) -> None """Initialize the specification for CombiningValue state. 
CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining value @@ -118,6 +135,7 @@ def __init__(self, name, coder=None, combine_fn=None): self.coder = coder def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec return beam_runner_api_pb2.StateSpec( combining_spec=beam_runner_api_pb2.CombiningStateSpec( combine_fn=self.combine_fn.to_runner_api(context), @@ -138,6 +156,7 @@ def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self.name) def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.TimerSpec return beam_runner_api_pb2.TimerSpec( time_domain=TimeDomain.to_runner_api(self.time_domain), timer_coder_id=context.coders.get_id( @@ -145,6 +164,7 @@ def to_runner_api(self, context): def on_timer(timer_spec): + # type: (TimerSpec) -> Callable[[CallableT], CallableT] """Decorator for timer firing DoFn method. This decorator allows a user to specify an on_timer processing method @@ -174,6 +194,7 @@ def _inner(method): def get_dofn_specs(dofn): + # type: (...) -> Tuple[Set[StateSpec], Set[TimerSpec]] """Gets the state and timer specs for a DoFn, if any. Args: diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 0cba670e3c62..bda43e3bce43 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -26,12 +26,18 @@ import random import re import time -import typing import warnings from builtins import filter from builtins import object from builtins import range from builtins import zip +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import List +from typing import Tuple +from typing import TypeVar +from typing import Union from future.utils import itervalues @@ -65,6 +71,10 @@ from apache_beam.utils.annotations import deprecated from apache_beam.utils.annotations import experimental +if TYPE_CHECKING: + from apache_beam import pvalue + from apache_beam.runners.pipeline_context import PipelineContext + __all__ = [ 'BatchElements', 'CoGroupByKey', @@ -81,9 +91,9 @@ 'GroupIntoBatches' ] -K = typing.TypeVar('K') -V = typing.TypeVar('V') -T = typing.TypeVar('T') +K = TypeVar('K') +V = TypeVar('V') +T = TypeVar('T') class CoGroupByKey(PTransform): @@ -484,7 +494,7 @@ def finish_bundle(self): @typehints.with_input_types(T) -@typehints.with_output_types(typing.List[T]) +@typehints.with_output_types(List[T]) class BatchElements(PTransform): """A Transform that batches elements for amortized processing. 
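The `Callable[[CallableT], CallableT]` annotation given to `on_timer` above is the usual way to tell a checker that a decorator hands back something with the same type as the callable it received, so a decorated DoFn method keeps its own signature. A self-contained sketch of that idiom with a hypothetical registration decorator (`on_event` and `_HANDLERS` are invented for illustration and are not Beam APIs):

from typing import Callable, Dict, TypeVar

CallableT = TypeVar('CallableT', bound=Callable)

_HANDLERS = {}  # type: Dict[str, Callable]


def on_event(name):
    # type: (str) -> Callable[[CallableT], CallableT]
    """Records the decorated callable under `name` and returns it unchanged."""
    def _inner(fn):
        # type: (CallableT) -> CallableT
        _HANDLERS[name] = fn
        return fn
    return _inner


@on_event('expired')
def handle_expired(timer_id):
    # type: (str) -> None
    print('timer fired: %s' % timer_id)

Because the declared return type is `CallableT` rather than plain `Callable`, the checker still sees `handle_expired` as `(str) -> None` after decoration.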
@@ -577,8 +587,8 @@ def get_window_coder(self): return self._window_coder -@typehints.with_input_types(typing.Tuple[K, V]) -@typehints.with_output_types(typing.Tuple[K, V]) +@typehints.with_input_types(Tuple[K, V]) +@typehints.with_output_types(Tuple[K, V]) class ReshufflePerKey(PTransform): """PTransform that returns a PCollection equivalent to its input, but operationally provides some of the side effects of a GroupByKey, @@ -654,12 +664,15 @@ class Reshuffle(PTransform): """ def expand(self, pcoll): + # type: (pvalue.PValue) -> pvalue.PCollection + # FIXME: mypy plugin causing mypy to crash here: return (pcoll | 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t)) | ReshufflePerKey() | 'RemoveRandomKeys' >> Map(lambda t: t[1])) def to_runner_api_parameter(self, unused_context): + # type: (PipelineContext) -> Tuple[str, None] return common_urns.composites.RESHUFFLE.urn, None @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None) @@ -680,7 +693,7 @@ def WithKeys(pcoll, k): @experimental() -@typehints.with_input_types(typing.Tuple[K, V]) +@typehints.with_input_types(Tuple[K, V]) class GroupIntoBatches(PTransform): """PTransform that batches the input into desired batch size. Elements are buffered until they are equal to batch size provided in the argument at which @@ -760,7 +773,7 @@ def __init__(self, delimiter=None): self.delimiter = delimiter or "," def expand(self, pcoll): - input_type = typing.Tuple[typing.Any, typing.Any] + input_type = Tuple[Any, Any] output_type = str return (pcoll | ('%s:KeyVaueToString' % self.label >> (Map( lambda x: "{}{}{}".format(x[0], self.delimiter, x[1]))) @@ -790,7 +803,7 @@ def __init__(self, delimiter=None): self.delimiter = delimiter or "," def expand(self, pcoll): - input_type = typing.Iterable[typing.Any] + input_type = Iterable[Any] output_type = str return (pcoll | ('%s:IterablesToString' % self.label >> ( Map(lambda x: self.delimiter.join(str(_x) for _x in x))) @@ -830,8 +843,8 @@ def add_window_info(element, timestamp=DoFn.TimestampParam, def expand(self, pcoll): return pcoll | ParDo(self.add_window_info) - @typehints.with_input_types(typing.Tuple[K, V]) - @typehints.with_output_types(typing.Tuple[K, V]) + @typehints.with_input_types(Tuple[K, V]) + @typehints.with_output_types(Tuple[K, V]) class TimestampInValue(PTransform): """PTransform to wrap the Value in a KV pair in a TimestampedValue with the element's associated timestamp.""" @@ -844,8 +857,8 @@ def add_timestamp_info(element, timestamp=DoFn.TimestampParam): def expand(self, pcoll): return pcoll | ParDo(self.add_timestamp_info) - @typehints.with_input_types(typing.Tuple[K, V]) - @typehints.with_output_types(typing.Tuple[K, V]) + @typehints.with_input_types(Tuple[K, V]) + @typehints.with_output_types(Tuple[K, V]) class WindowInValue(PTransform): """PTransform to convert the Value in a KV pair into a tuple of (value, timestamp, window), with the whole element being wrapped inside a @@ -904,7 +917,7 @@ def _process(element): @staticmethod @typehints.with_input_types(str) - @typehints.with_output_types(typing.List[str]) + @typehints.with_output_types(List[str]) @ptransform_fn def all_matches(pcoll, regex): """ @@ -925,7 +938,7 @@ def _process(element): @staticmethod @typehints.with_input_types(str) - @typehints.with_output_types(typing.Tuple[str, str]) + @typehints.with_output_types(Tuple[str, str]) @ptransform_fn def matches_kv(pcoll, regex, keyGroup, valueGroup=0): """ @@ -970,8 +983,7 @@ def _process(element): @staticmethod @typehints.with_input_types(str) - 
@typehints.with_output_types(typing.Union[typing.List[str], - typing.Tuple[str, str]]) + @typehints.with_output_types(Union[List[str], Tuple[str, str]]) @ptransform_fn def find_all(pcoll, regex, group=0, outputEmpty=True): """ @@ -999,7 +1011,7 @@ def _process(element): @staticmethod @typehints.with_input_types(str) - @typehints.with_output_types(typing.Tuple[str, str]) + @typehints.with_output_types(Tuple[str, str]) @ptransform_fn def find_kv(pcoll, regex, keyGroup, valueGroup=0): """ @@ -1056,7 +1068,7 @@ def replace_first(pcoll, regex, replacement): @staticmethod @typehints.with_input_types(str) - @typehints.with_output_types(typing.List[str]) + @typehints.with_output_types(List[str]) @ptransform_fn def split(pcoll, regex, outputEmpty=False): """ diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index e477303eebd7..ef191fe30212 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -53,6 +53,10 @@ from builtins import object from builtins import range from functools import total_ordering +from typing import Any +from typing import Iterable +from typing import List +from typing import Union from future.utils import with_metaclass from google.protobuf import duration_pb2 @@ -118,13 +122,18 @@ class WindowFn(with_metaclass(abc.ABCMeta, urns.RunnerApiFn)): class AssignContext(object): """Context passed to WindowFn.assign().""" - def __init__(self, timestamp, element=None, window=None): + def __init__(self, + timestamp, # type: Union[int, float, Timestamp] + element=None, + window=None + ): self.timestamp = Timestamp.of(timestamp) self.element = element self.window = window @abc.abstractmethod def assign(self, assign_context): + # type: (AssignContext) -> Iterable[BoundedWindow] """Associates windows to an element. Arguments: @@ -139,6 +148,7 @@ class MergeContext(object): """Context passed to WindowFn.merge() to perform merging, if any.""" def __init__(self, windows): + # type: (Iterable[Union[IntervalWindow, GlobalWindow]]) -> None self.windows = list(windows) def merge(self, to_be_merged, merge_result): @@ -146,6 +156,7 @@ def merge(self, to_be_merged, merge_result): @abc.abstractmethod def merge(self, merge_context): + # type: (WindowFn.MergeContext) -> None """Returns a window that is the result of merging a set of windows.""" raise NotImplementedError @@ -187,6 +198,7 @@ class BoundedWindow(object): """ def __init__(self, end): + # type: (Union[int, float, Timestamp]) -> None self.end = Timestamp.of(end) def max_timestamp(self): @@ -257,6 +269,7 @@ class TimestampedValue(object): """ def __init__(self, value, timestamp): + # type: (Any, Union[int, float, Timestamp]) -> None self.value = value self.timestamp = Timestamp.of(timestamp) @@ -318,6 +331,7 @@ def is_merging(self): return False def merge(self, merge_context): + # type: (WindowFn.MergeContext) -> None pass # No merging. @@ -367,7 +381,10 @@ class FixedWindows(NonMergingWindowFn): range. """ - def __init__(self, size, offset=0): + def __init__(self, + size, # type: Union[int, float, Duration] + offset=0 # type: Union[int, float, Timestamp] + ): """Initialize a ``FixedWindows`` function for a given size and offset. Args: @@ -432,7 +449,11 @@ class SlidingWindows(NonMergingWindowFn): in range [0, period). If it is not it will be normalized to this range. 
""" - def __init__(self, size, period, offset=0): + def __init__(self, + size, # type: Union[int, float, Duration] + period, # type: Union[int, float, Duration] + offset=0, # type: Union[int, float, Timestamp] + ): if size <= 0: raise ValueError('The size parameter must be strictly positive.') self.size = Duration.of(size) @@ -493,6 +514,7 @@ class Sessions(WindowFn): """ def __init__(self, gap_size): + # type: (Union[int, float, Duration]) -> None if gap_size <= 0: raise ValueError('The size parameter must be strictly positive.') self.gap_size = Duration.of(gap_size) @@ -505,7 +527,8 @@ def get_window_coder(self): return coders.IntervalWindowCoder() def merge(self, merge_context): - to_merge = [] + # type: (WindowFn.MergeContext) -> None + to_merge = [] # type: List[Union[IntervalWindow, GlobalWindow]] end = MIN_TIMESTAMP for w in sorted(merge_context.windows, key=lambda w: w.start): if to_merge: diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 01e2be4019d8..b68276eef150 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -92,6 +92,12 @@ def foo((a, b)): from builtins import next from builtins import object from builtins import zip +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import TypeVar from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import typehints @@ -113,6 +119,9 @@ def foo((a, b)): 'TypeCheckError', ] +T = TypeVar('T') +WithTypeHintsT = TypeVar('WithTypeHintsT', bound='WithTypeHints') # pylint: disable=invalid-name + # This is missing in the builtin types module. str.upper is arbitrary, any # method on a C-implemented type will do. 
# pylint: disable=invalid-name @@ -219,7 +228,10 @@ class IOTypeHints(object): """ __slots__ = ('input_types', 'output_types') - def __init__(self, input_types=None, output_types=None): + def __init__(self, + input_types=None, # type: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] + output_types=None # type: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] + ): self.input_types = input_types self.output_types = output_types @@ -312,9 +324,11 @@ def strip_iterable(self): return res def copy(self): + # type: () -> IOTypeHints return IOTypeHints(self.input_types, self.output_types) def with_defaults(self, hints): + # type: (Optional[IOTypeHints]) -> IOTypeHints if not hints: return self if self._has_input_types(): @@ -349,6 +363,7 @@ def __init__(self, *unused_args, **unused_kwargs): self._type_hints = IOTypeHints() def _get_or_create_type_hints(self): + # type: () -> IOTypeHints # __init__ may have not been called try: return self._type_hints @@ -372,12 +387,14 @@ def default_type_hints(self): return None def with_input_types(self, *arg_hints, **kwarg_hints): + # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._get_or_create_type_hints().set_input_types(*arg_hints, **kwarg_hints) return self def with_output_types(self, *arg_hints, **kwarg_hints): + # type: (WithTypeHintsT, *Any, **Any) -> WithTypeHintsT arg_hints = native_type_compatibility.convert_to_beam_types(arg_hints) kwarg_hints = native_type_compatibility.convert_to_beam_types(kwarg_hints) self._get_or_create_type_hints().set_output_types(*arg_hints, **kwarg_hints) @@ -574,6 +591,7 @@ def getcallargs_forhints_impl_py3(func, type_args, type_kwargs): def get_type_hints(fn): + # type: (Any) -> IOTypeHints """Gets the type hint associated with an arbitrary object fn. Always returns a valid IOTypeHints object, creating one if necessary. @@ -595,6 +613,7 @@ def get_type_hints(fn): def with_input_types(*positional_hints, **keyword_hints): + # type: (*Any, **Any) -> Callable[[T], T] """A decorator that type-checks defined type-hints with passed func arguments. All type-hinted arguments can be specified using positional arguments, @@ -677,6 +696,7 @@ def annotate(f): def with_output_types(*return_type_hint, **kwargs): + # type: (*Any, **Any) -> Callable[[T], T] """A decorator that type-checks defined type-hints for return values(s). This decorator will type-check the return value(s) of the decorated function. 
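The `WithTypeHintsT` TypeVar bound to `'WithTypeHints'` introduced in the decorators.py hunk, used as the type of `self` in `with_input_types`/`with_output_types`, is what lets fluent chaining preserve the concrete subclass. The same idiom in a toy, standalone form (the `Base`/`Child` names are illustrative only, not from the patch):

from typing import List, TypeVar

BaseT = TypeVar('BaseT', bound='Base')


class Base(object):
    def __init__(self):
        # type: () -> None
        self.hints = []  # type: List[object]

    def with_hint(self, hint):
        # type: (BaseT, object) -> BaseT
        # Typing `self` with the bound TypeVar makes the return type the
        # caller's concrete subclass rather than plain `Base`.
        self.hints.append(hint)
        return self


class Child(Base):
    def child_only(self):
        # type: () -> str
        return 'ok'


# The checker infers `Child` for the chained call, so `child_only` resolves.
result = Child().with_hint(int).child_only()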
diff --git a/sdks/python/apache_beam/typehints/decorators_test_py3.py b/sdks/python/apache_beam/typehints/decorators_test_py3.py index d48f84575f89..32ec205497f0 100644 --- a/sdks/python/apache_beam/typehints/decorators_test_py3.py +++ b/sdks/python/apache_beam/typehints/decorators_test_py3.py @@ -38,7 +38,7 @@ class IOTypeHintsTest(unittest.TestCase): def test_from_callable(self): - def fn(a: int, b: str = None, *args: Tuple[T], foo: List[int], + def fn(a: int, b: str = '', *args: Tuple[T], foo: List[int], **kwargs: Dict[str, str]) -> Tuple: return a, b, args, foo, kwargs th = decorators.IOTypeHints.from_callable(fn) @@ -78,7 +78,7 @@ def method(self, arg: T = None) -> None: self.assertEqual(th.output_types, ((None,), {})) def test_getcallargs_forhints(self): - def fn(a: int, b: str = None, *args: Tuple[T], foo: List[int], + def fn(a: int, b: str = '', *args: Tuple[T], foo: List[int], **kwargs: Dict[str, str]) -> Tuple: return a, b, args, foo, kwargs callargs = decorators.getcallargs_forhints(fn, float, foo=List[str]) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 43cdedc60163..0b65458df490 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -140,7 +140,7 @@ def _match_is_union(user_type): # Mapping from typing.TypeVar/typehints.TypeVariable ids to an object of the # other type. Bidirectional mapping preserves typing.TypeVar instances. -_type_var_cache = {} +_type_var_cache = {} # type: typing.Dict[int, typehints.TypeVariable] def convert_to_beam_type(typ): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 6062e6f00d32..084470381815 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -70,6 +70,7 @@ import logging import sys import types +import typing from builtins import next from builtins import zip @@ -1112,7 +1113,7 @@ def __getitem__(self, type_params): # There is a circular dependency between defining this mapping # and using it in normalize(). Initialize it here and populate # it below. -_KNOWN_PRIMITIVE_TYPES = {} +_KNOWN_PRIMITIVE_TYPES = {} # type: typing.Dict[type, typing.Any] def normalize(x, none_as_type=False): diff --git a/sdks/python/apache_beam/utils/counters.py b/sdks/python/apache_beam/utils/counters.py index dcb5683f62de..94a19d1c4a2e 100644 --- a/sdks/python/apache_beam/utils/counters.py +++ b/sdks/python/apache_beam/utils/counters.py @@ -29,9 +29,14 @@ from builtins import hex from builtins import object from collections import namedtuple +from typing import TYPE_CHECKING +from typing import Dict from apache_beam.transforms import cy_combiners +if TYPE_CHECKING: + from apache_beam.transforms import core + # Information identifying the IO being measured by a counter. # # A CounterName with IOTarget helps identify the IO being measured by a @@ -45,6 +50,7 @@ def side_input_id(step_name, input_index): + # type: (str, int) -> IOTargetName """Create an IOTargetName that identifies the reading of a side input. Given a step "s4" that receives two side inputs, then the CounterName @@ -60,6 +66,7 @@ def side_input_id(step_name, input_index): def shuffle_id(step_name): + # type: (str) -> IOTargetName """Create an IOTargetName that identifies a GBK step. 
Given a step "s6" that is downstream from a GBK "s5", then "s6" will read @@ -141,6 +148,7 @@ class Counter(object): DATAFLOW_DISTRIBUTION = cy_combiners.DataflowDistributionCounterFn() def __init__(self, name, combine_fn): + # type: (CounterName, core.CombineFn) -> None """Creates a Counter object. Args: @@ -177,6 +185,7 @@ class AccumulatorCombineFnCounter(Counter): """Counter optimized for a mutating accumulator that holds all the logic.""" def __init__(self, name, combine_fn): + # type: (CounterName, cy_combiners.AccumulatorCombineFn) -> None assert isinstance(combine_fn, cy_combiners.AccumulatorCombineFn) super(AccumulatorCombineFnCounter, self).__init__(name, combine_fn) self.reset() @@ -193,12 +202,13 @@ class CounterFactory(object): """Keeps track of unique counters.""" def __init__(self): - self.counters = {} + self.counters = {} # type: Dict[CounterName, Counter] # Lock to be acquired when accessing the counters map. self._lock = threading.Lock() def get_counter(self, name, combine_fn): + # type: (CounterName, core.CombineFn) -> Counter """Returns a counter with the requested name. Passing in the same name will return the same counter; the diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index 0606744bc1b9..5de27ca82fc0 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -33,6 +33,8 @@ import warnings from builtins import object from threading import Timer +from typing import Callable +from typing import Optional from apache_beam.io import filesystems @@ -95,11 +97,13 @@ def default_file_copy_fn(src, dest): @staticmethod def factory_from_options(options): + # type: (...) -> Optional[Callable[..., Profile]] if options.profile_cpu: def create_profiler(profile_id, **kwargs): if random.random() < options.profile_sample_rate: return Profile(profile_id, options.profile_location, **kwargs) return create_profiler + return None class MemoryReporter(object): diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 9a76448b1068..c37b0ad39f5d 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -19,9 +19,29 @@ from __future__ import absolute_import +from typing import Type +from typing import TypeVar +from typing import Union +from typing import overload + from google.protobuf import any_pb2 +from google.protobuf import message from google.protobuf import struct_pb2 +MessageT = TypeVar('MessageT', bound=message.Message) + + +@overload +def pack_Any(msg): + # type: (message.Message) -> any_pb2.Any + pass + + +@overload +def pack_Any(msg): + # type: (None) -> None + pass + def pack_Any(msg): """Creates a protobuf Any with msg as its content. @@ -36,6 +56,18 @@ def pack_Any(msg): return result +@overload +def unpack_Any(any_msg, msg_class): + # type: (any_pb2.Any, Type[MessageT]) -> MessageT + pass + + +@overload +def unpack_Any(any_msg, msg_class): + # type: (any_pb2.Any, None) -> None + pass + + def unpack_Any(any_msg, msg_class): """Unpacks any_msg into msg_class. @@ -48,6 +80,18 @@ def unpack_Any(any_msg, msg_class): return msg +@overload +def parse_Bytes(serialized_bytes, msg_class): + # type: (bytes, Type[MessageT]) -> MessageT + pass + + +@overload +def parse_Bytes(serialized_bytes, msg_class): + # type: (bytes, Union[Type[bytes], None]) -> bytes + pass + + def parse_Bytes(serialized_bytes, msg_class): """Parses the String of bytes into msg_class. 
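The `@overload` declarations added to proto_utils.py above let the checker pick the return type from the argument types (passing `None` narrows the result accordingly) while a single undecorated implementation does the work at runtime. A hypothetical, self-contained helper showing the same overload-with-type-comments pattern (not the Beam function):

from typing import Optional, Union, overload


@overload
def to_bytes(value):
    # type: (str) -> bytes
    pass


@overload
def to_bytes(value):
    # type: (None) -> None
    pass


def to_bytes(value):
    # type: (Union[str, None]) -> Optional[bytes]
    # Single runtime implementation; the overload stubs above are checker-only.
    if value is None:
        return None
    return value.encode('utf-8')


encoded = to_bytes('abc')   # checker infers bytes
missing = to_bytes(None)    # checker infers None

Only the final, undecorated definition is bound at runtime; the `@overload` stubs exist purely for the type checker, which is why the patch can add them without changing behaviour.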
@@ -60,6 +104,7 @@ def parse_Bytes(serialized_bytes, msg_class): def pack_Struct(**kwargs): + # type: (...) -> struct_pb2.Struct """Returns a struct containing the values indicated by kwargs. """ msg = struct_pb2.Struct() diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py index 9bccdfd399f8..bf3f4dd99682 100644 --- a/sdks/python/apache_beam/utils/timestamp.py +++ b/sdks/python/apache_beam/utils/timestamp.py @@ -26,6 +26,9 @@ import datetime import functools from builtins import object +from typing import Any +from typing import Union +from typing import overload import dateutil.parser import pytz @@ -47,6 +50,7 @@ class Timestamp(object): """ def __init__(self, seconds=0, micros=0): + # type: (Union[int, long, float], Union[int, long, float]) -> None if not isinstance(seconds, (int, long, float)): raise TypeError('Cannot interpret %s %s as seconds.' % ( seconds, type(seconds))) @@ -57,6 +61,7 @@ def __init__(self, seconds=0, micros=0): @staticmethod def of(seconds): + # type: (Union[int, long, float, Timestamp]) -> Timestamp """Return the Timestamp for the given number of seconds. If the input is already a Timestamp, the input itself will be returned. @@ -136,14 +141,17 @@ def to_rfc3339(self): return self.to_utc_datetime().isoformat() + 'Z' def __float__(self): + # type: () -> float # Note that the returned value may have lost precision. return self.micros / 1000000 def __int__(self): + # type: () -> int # Note that the returned value may have lost precision. return self.micros // 1000000 def __eq__(self, other): + # type: (Union[int, long, float, Timestamp, Duration]) -> bool # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Duration): try: @@ -153,10 +161,12 @@ def __eq__(self, other): return self.micros == other.micros def __ne__(self, other): + # type: (Any) -> bool # TODO(BEAM-5949): Needed for Python 2 compatibility. return not self == other def __lt__(self, other): + # type: (Union[int, long, float, Timestamp, Duration]) -> bool # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Duration): other = Timestamp.of(other) @@ -166,17 +176,21 @@ def __hash__(self): return hash(self.micros) def __add__(self, other): + # type: (Union[int, long, float, Duration]) -> Timestamp other = Duration.of(other) return Timestamp(micros=self.micros + other.micros) def __radd__(self, other): + # type: (Union[int, long, float, Duration]) -> Timestamp return self + other def __sub__(self, other): + # type: (Union[int, long, float, Duration]) -> Timestamp other = Duration.of(other) return Timestamp(micros=self.micros - other.micros) def __mod__(self, other): + # type: (Union[int, long, float, Duration]) -> Duration other = Duration.of(other) return Duration(micros=self.micros % other.micros) @@ -200,10 +214,12 @@ class Duration(object): """ def __init__(self, seconds=0, micros=0): + # type: (Union[int, long, float], Union[int, long, float]) -> None self.micros = int(seconds * 1000000) + int(micros) @staticmethod def of(seconds): + # type: (Union[int, long, float, Duration]) -> Duration """Return the Duration for the given number of seconds since Unix epoch. If the input is already a Duration, the input itself will be returned. @@ -234,20 +250,24 @@ def __repr__(self): return 'Duration(%s%d)' % (sign, int_part) def __float__(self): + # type: () -> float # Note that the returned value may have lost precision. 
return self.micros / 1000000 def __eq__(self, other): + # type: (Union[int, long, float, Duration, Timestamp]) -> bool # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Timestamp): other = Duration.of(other) return self.micros == other.micros def __ne__(self, other): + # type: (Any) -> bool # TODO(BEAM-5949): Needed for Python 2 compatibility. return not self == other def __lt__(self, other): + # type: (Union[int, long, float, Duration, Timestamp]) -> bool # Allow comparisons between Duration and Timestamp values. if not isinstance(other, Timestamp): other = Duration.of(other) @@ -257,8 +277,19 @@ def __hash__(self): return hash(self.micros) def __neg__(self): + # type: () -> Duration return Duration(micros=-self.micros) + @overload + def __add__(self, other): + # type: (Timestamp) -> Timestamp + pass + + @overload + def __add__(self, other): + # type: (Union[int, long, float, Duration]) -> Duration + pass + def __add__(self, other): if isinstance(other, Timestamp): return other + self @@ -269,6 +300,7 @@ def __radd__(self, other): return self + other def __sub__(self, other): + # type: (Union[int, long, float, Duration]) -> Duration other = Duration.of(other) return Duration(micros=self.micros - other.micros) @@ -276,6 +308,7 @@ def __rsub__(self, other): return -(self - other) def __mul__(self, other): + # type: (Union[int, long, float, Duration]) -> Duration other = Duration.of(other) return Duration(micros=self.micros * other.micros // 1000000) @@ -283,6 +316,7 @@ def __rmul__(self, other): return self * other def __mod__(self, other): + # type: (Union[int, long, float, Duration]) -> Duration other = Duration.of(other) return Duration(micros=self.micros % other.micros) diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py index 4e9c357e8e7e..4402ea40a2ac 100644 --- a/sdks/python/apache_beam/utils/urns.py +++ b/sdks/python/apache_beam/utils/urns.py @@ -22,6 +22,16 @@ import abc import inspect from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import overload from google.protobuf import message from google.protobuf import wrappers_pb2 @@ -29,6 +39,16 @@ from apache_beam.internal import pickler from apache_beam.utils import proto_utils +if TYPE_CHECKING: + from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.pipeline_context import PipelineContext + +T = TypeVar('T') +ConstructorFn = Callable[ + [Union['message.Message', bytes], + 'PipelineContext'], + Any] + class RunnerApiFn(object): """Abstract base class that provides urn registration utilities. @@ -44,16 +64,55 @@ class RunnerApiFn(object): # TODO(BEAM-2685): Issue with dill + local classes + abc metaclass # __metaclass__ = abc.ABCMeta - _known_urns = {} + _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] @abc.abstractmethod def to_runner_api_parameter(self, unused_context): + # type: (PipelineContext) -> Tuple[str, Any] """Returns the urn and payload for this Fn. The returned urn(s) should be registered with `register_urn`. """ pass + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Type[T] + ): + # type: (...) 
-> Callable[[Callable[[T, PipelineContext], Any]], Callable[[T, PipelineContext], Any]] + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: None + ): + # type: (...) -> Callable[[Callable[[bytes, PipelineContext], Any]], Callable[[bytes, PipelineContext], Any]] + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: Type[T] + fn # type: Callable[[T, PipelineContext], Any] + ): + # type: (...) -> None + pass + + @classmethod + @overload + def register_urn(cls, + urn, # type: str + parameter_type, # type: None + fn # type: Callable[[bytes, PipelineContext], Any] + ): + # type: (...) -> None + pass + @classmethod def register_urn(cls, urn, parameter_type, fn=None): """Registers a urn with a constructor. @@ -90,6 +149,7 @@ def register_pickle_urn(cls, pickle_urn): lambda proto, unused_context: pickler.loads(proto.value)) def to_runner_api(self, context): + # type: (PipelineContext) -> beam_runner_api_pb2.SdkFunctionSpec """Returns an SdkFunctionSpec encoding this Fn. Prefer overriding self.to_runner_api_parameter. @@ -106,6 +166,7 @@ def to_runner_api(self, context): @classmethod def from_runner_api(cls, fn_proto, context): + # type: (beam_runner_api_pb2.SdkFunctionSpec, PipelineContext) -> Any """Converts from an SdkFunctionSpec to a Fn object. Prefer registering a urn with its parameter type and constructor. diff --git a/sdks/python/apache_beam/utils/windowed_value.py b/sdks/python/apache_beam/utils/windowed_value.py index 5570c4513b8c..2efb4c05d1f9 100644 --- a/sdks/python/apache_beam/utils/windowed_value.py +++ b/sdks/python/apache_beam/utils/windowed_value.py @@ -30,11 +30,21 @@ from __future__ import absolute_import from builtins import object +from typing import TYPE_CHECKING +from typing import Any +from typing import Optional +from typing import Tuple +from typing import Union + +from past.builtins import long from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp +if TYPE_CHECKING: + from apache_beam.transforms.window import BoundedWindow + class PaneInfoTiming(object): """The timing of a PaneInfo.""" @@ -168,12 +178,20 @@ class WindowedValue(object): PANE_INFO_UNKNOWN. """ - def __init__(self, value, timestamp, windows, pane_info=PANE_INFO_UNKNOWN): + def __init__(self, + value, + timestamp, # type: Union[int, long, float, Timestamp] + windows, # type: Tuple[BoundedWindow, ...] + pane_info=PANE_INFO_UNKNOWN + ): + # type: (...) -> None # For performance reasons, only timestamp_micros is stored by default # (as a C int). The Timestamp object is created on demand below. self.value = value if isinstance(timestamp, int): self.timestamp_micros = timestamp * 1000000 + if TYPE_CHECKING: + self.timestamp_object = None # type: Optional[Timestamp] else: self.timestamp_object = (timestamp if isinstance(timestamp, Timestamp) else Timestamp.of(timestamp)) @@ -183,6 +201,7 @@ def __init__(self, value, timestamp, windows, pane_info=PANE_INFO_UNKNOWN): @property def timestamp(self): + # type: () -> Timestamp if self.timestamp_object is None: self.timestamp_object = Timestamp(0, self.timestamp_micros) return self.timestamp_object @@ -214,6 +233,7 @@ def __hash__(self): 11 * (hash(self.pane_info) & 0xFFFFFFFFFFFFF)) def with_value(self, new_value): + # type: (Any) -> WindowedValue """Creates a new WindowedValue with the same timestamps and windows as this. 
This is the fasted way to create a new WindowedValue. @@ -237,6 +257,7 @@ def create(value, timestamp_micros, windows, pane_info=PANE_INFO_UNKNOWN): try: + # FIXME: for review: why not add this as a class attribute? WindowedValue.timestamp_object = None except TypeError: # When we're compiled, we can't dynamically add attributes to @@ -249,10 +270,13 @@ class _IntervalWindowBase(object): """Optimized form of IntervalWindow storing only microseconds for endpoints. """ - def __init__(self, start, end): + def __init__(self, + start, # type: Optional[Union[int, long, float, Timestamp]] + end # type: Optional[Union[int, long, float, Timestamp]] + ): if start is not None or end is not None: - self._start_object = Timestamp.of(start) - self._end_object = Timestamp.of(end) + self._start_object = Timestamp.of(start) # type: Optional[Timestamp] + self._end_object = Timestamp.of(end) # type: Optional[Timestamp] try: self._start_micros = self._start_object.micros except OverflowError: @@ -271,12 +295,14 @@ def __init__(self, start, end): @property def start(self): + # type: () -> Timestamp if self._start_object is None: self._start_object = Timestamp(0, self._start_micros) return self._start_object @property def end(self): + # type: () -> Timestamp if self._end_object is None: self._end_object = Timestamp(0, self._end_micros) return self._end_object diff --git a/sdks/python/gen_protos.py b/sdks/python/gen_protos.py index 06867efa172c..5f98369ad86d 100644 --- a/sdks/python/gen_protos.py +++ b/sdks/python/gen_protos.py @@ -122,14 +122,20 @@ def generate_proto_files(force=False, log=None): if p.exitcode: raise ValueError("Proto generation failed (see log for details).") else: - log.info('Regenerating Python proto definitions (%s).' % regenerate) + + ret_code = subprocess.call(["pip", "install", "mypy-protobuf==1.12"]) + if ret_code: + raise RuntimeError( + 'Error installing mypy-protobuf during proto generation') + builtin_protos = pkg_resources.resource_filename('grpc_tools', '_proto') args = ( [sys.executable] + # expecting to be called from command line ['--proto_path=%s' % builtin_protos] + ['--proto_path=%s' % d for d in proto_dirs] + ['--python_out=%s' % out_dir] + + ['--mypy_out=%s' % out_dir] + # TODO(robertwb): Remove the prefix once it's the default. ['--grpc_python_out=grpc_2_0:%s' % out_dir] + proto_files)