diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 04ecbf161c75..1044c136671f 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -78,6 +78,9 @@ from apache_beam.typehints import typehints from apache_beam.utils.annotations import deprecated +if typing.TYPE_CHECKING: + from apache_beam.portability.api import beam_runner_api_pb2 + __all__ = ['Pipeline', 'PTransformOverride'] @@ -611,6 +614,7 @@ def visit_value(self, value, _): def to_runner_api( self, return_context=False, context=None, use_fake_coders=False, default_environment=None): + # type: (...) -> beam_runner_api_pb2.Pipeline """For internal use only; no backwards-compatibility guarantees.""" from apache_beam.runners import pipeline_context from apache_beam.portability.api import beam_runner_api_pb2 diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index 23605b535c33..d29c48059489 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -26,6 +26,7 @@ import collections import threading +import typing import pydot @@ -37,9 +38,10 @@ class PipelineGraph(object): """Creates a DOT representation of the pipeline. Thread-safe.""" def __init__(self, - pipeline, + pipeline, # type: typing.Union[beam_runner_api_pb2.Pipeline, beam.Pipeline] default_vertex_attrs=None, - default_edge_attrs=None): + default_edge_attrs=None + ): """Constructor of PipelineGraph. Examples: @@ -57,7 +59,7 @@ def __init__(self, default_edge_attrs: (Dict[str, str]) a dict of default edge attributes """ self._lock = threading.Lock() - self._graph = None + self._graph = None # type: pydot.Dot if isinstance(pipeline, beam_runner_api_pb2.Pipeline): self._pipeline_proto = pipeline @@ -93,6 +95,7 @@ def __init__(self, default_edge_attrs) def get_dot(self): + # type: () -> str return self._get_graph().to_string() def _top_level_transforms(self): diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py index a216d195a1f1..d07f7c79dfaa 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_renderer.py @@ -27,11 +27,15 @@ import abc import os import subprocess +import typing from future.utils import with_metaclass from apache_beam.utils.plugin import BeamPlugin +if typing.TYPE_CHECKING: + from apache_beam.runners.interactive.display.pipeline_graph import PipelineGraph + class PipelineGraphRenderer(with_metaclass(abc.ABCMeta, BeamPlugin)): """Abstract class for renderers, who decide how pipeline graphs are rendered. @@ -40,12 +44,14 @@ class PipelineGraphRenderer(with_metaclass(abc.ABCMeta, BeamPlugin)): @classmethod @abc.abstractmethod def option(cls): + # type: () -> str """The corresponding rendering option for the renderer. """ raise NotImplementedError @abc.abstractmethod def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str """Renders the pipeline graph in HTML-compatible format. Args: @@ -63,9 +69,11 @@ class MuteRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'mute' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return '' @@ -75,9 +83,11 @@ class TextRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'text' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return pipeline_graph.get_dot() @@ -91,13 +101,16 @@ class PydotRenderer(PipelineGraphRenderer): @classmethod def option(cls): + # type: () -> str return 'graph' def render_pipeline_graph(self, pipeline_graph): + # type: (PipelineGraph) -> str return pipeline_graph._get_graph().create_svg() # pylint: disable=protected-access def get_renderer(option=None): + # type: (typing.Optional[str]) -> typing.Type[PipelineGraphRenderer] """Get an instance of PipelineGraphRenderer given rendering option. Args: diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py index 117477af93c6..511fe14e13ed 100644 --- a/sdks/python/apache_beam/runners/portability/portable_runner.py +++ b/sdks/python/apache_beam/runners/portability/portable_runner.py @@ -27,6 +27,7 @@ import sys import threading import time +import typing from concurrent import futures import grpc @@ -52,6 +53,9 @@ from apache_beam.runners.worker import sdk_worker from apache_beam.runners.worker import sdk_worker_main +if typing.TYPE_CHECKING: + from apache_beam.pipeline import Pipeline + __all__ = ['PortableRunner'] MESSAGE_LOG_LEVELS = { @@ -112,6 +116,7 @@ def default_docker_image(): @staticmethod def _create_environment(options): + # type: (...) -> beam_runner_api_pb2.Environment portable_options = options.view_as(PortableOptions) environment_urn = common_urns.environments.DOCKER.urn if portable_options.environment_type == 'DOCKER': @@ -166,6 +171,7 @@ def init_dockerized_job_server(self): self._job_endpoint = docker.start() def run_pipeline(self, pipeline, options): + # type: (Pipeline, typing.Any) -> PipelineResult portable_options = options.view_as(PortableOptions) job_endpoint = portable_options.job_endpoint diff --git a/sdks/python/apache_beam/runners/portability/portable_stager.py b/sdks/python/apache_beam/runners/portability/portable_stager.py index 09ff18fd4565..1fbe8f370fc5 100644 --- a/sdks/python/apache_beam/runners/portability/portable_stager.py +++ b/sdks/python/apache_beam/runners/portability/portable_stager.py @@ -22,6 +22,7 @@ import hashlib import os +import typing from apache_beam.portability.api import beam_artifact_api_pb2 from apache_beam.portability.api import beam_artifact_api_pb2_grpc @@ -54,9 +55,10 @@ def __init__(self, artifact_service_channel, staging_session_token): self._artifact_staging_stub = beam_artifact_api_pb2_grpc.\ ArtifactStagingServiceStub(channel=artifact_service_channel) self._staging_session_token = staging_session_token - self._artifacts = [] + self._artifacts = [] # type: typing.List[beam_artifact_api_pb2.ArtifactMetadata] def stage_artifact(self, local_path_to_artifact, artifact_name): + # type: (str, str) -> None """Stage a file to ArtifactStagingService. Args: @@ -69,6 +71,7 @@ def stage_artifact(self, local_path_to_artifact, artifact_name): .format(local_path_to_artifact)) def artifact_request_generator(): + # type: () -> typing.Iterator[beam_artifact_api_pb2.PutArtifactRequest] artifact_metadata = beam_artifact_api_pb2.ArtifactMetadata( name=artifact_name, sha256=_get_file_hash(local_path_to_artifact), diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py index 0448a255aea2..7de5c3b48355 100644 --- a/sdks/python/apache_beam/runners/portability/stager.py +++ b/sdks/python/apache_beam/runners/portability/stager.py @@ -52,6 +52,7 @@ import shutil import sys import tempfile +import typing import pkg_resources @@ -107,10 +108,11 @@ def get_sdk_package_name(): def stage_job_resources(self, options, - build_setup_args=None, - temp_dir=None, - populate_requirements_cache=None, - staging_location=None): + build_setup_args=None, # type: typing.Optional[typing.List[str]] + temp_dir=None, # type: typing.Optional[str] + populate_requirements_cache=None, # type: typing.Optional[str] + staging_location=None + ): """For internal use only; no backwards-compatibility guarantees. Creates (if needed) and stages job resources to staging_location. diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 0c522744de9c..077c34cea3e0 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -80,7 +80,7 @@ def __init__( data_channel_factory=self._data_channel_factory, fns=self._fns) # workers for process/finalize bundle. - self.workers = queue.Queue() + self.workers = queue.Queue() # type: queue.Queue[SdkWorker] # one worker for progress/split request. self.progress_worker = SdkWorker(self._bundle_processor_cache, profiler_factory=self._profiler_factory) @@ -330,7 +330,10 @@ def shutdown(self): class SdkWorker(object): - def __init__(self, bundle_processor_cache, profiler_factory=None): + def __init__(self, + bundle_processor_cache, # type: BundleProcessorCache + profiler_factory=None + ): self.bundle_processor_cache = bundle_processor_cache self.profiler_factory = profiler_factory @@ -526,6 +529,7 @@ class GrpcStateHandler(object): _DONE = object() def __init__(self, state_stub): + # type: (beam_fn_api_pb2_grpc.BeamFnStateStub) -> None self._lock = threading.Lock() self._state_stub = state_stub self._requests = queue.Queue()