-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
direct_runner.py
606 lines (498 loc) · 23.1 KB
/
direct_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""DirectRunner, executing on the local machine.
The DirectRunner is a runner implementation that executes the entire
graph of transformations belonging to a pipeline on the local machine.
"""
# pytype: skip-file
import itertools
import logging
import time
import typing
from google.protobuf import wrappers_pb2
import apache_beam as beam
from apache_beam import coders
from apache_beam import typehints
from apache_beam.internal.util import ArgumentPlaceholder
from apache_beam.options.pipeline_options import DirectOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.pvalue import PCollection
from apache_beam.runners.direct.bundle_factory import BundleFactory
from apache_beam.runners.direct.clock import RealClock
from apache_beam.runners.direct.clock import TestClock
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms import userstate
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import CombineValuesDoFn
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.typehints import trivial_inference
# Public names exported by ``from apache_beam.runners.direct.direct_runner
# import *``.
__all__ = ['BundleBasedDirectRunner', 'DirectRunner', 'SwitchingDirectRunner']
# Module-level logger shared by every runner/result class in this file.
_LOGGER = logging.getLogger(__name__)
class SwitchingDirectRunner(PipelineRunner):
  """Runs a pipeline locally, choosing the most suitable local runner.

  Pipelines whose transforms are all supported by the FnApiRunner (which
  has high throughput for batch jobs) are executed there. Anything the
  FnApiRunner cannot handle -- TestStream or Pub/Sub I/O (streaming),
  CombineFns with deferred side inputs, or stateful DoFns using real-time
  timers -- falls back to the BundleBasedDirectRunner, which supports
  streaming execution and those extra primitives.
  """
  def is_fnapi_compatible(self):
    return BundleBasedDirectRunner.is_fnapi_compatible()

  def run_pipeline(self, pipeline, options):
    from apache_beam.pipeline import PipelineVisitor
    from apache_beam.testing.test_stream import TestStream
    from apache_beam.io.gcp.pubsub import ReadFromPubSub
    from apache_beam.io.gcp.pubsub import WriteToPubSub

    class _FnApiRunnerSupportVisitor(PipelineVisitor):
      """Walks a Pipeline and records whether the FnApiRunner can run it."""
      def accept(self, pipeline):
        # Assume support, then let any offending transform clear the flag.
        self.supported_by_fnapi_runner = True
        pipeline.visit(self)
        return self.supported_by_fnapi_runner

      def enter_composite_transform(self, applied_ptransform):
        # Pub/Sub I/O implies streaming, which the FnApiRunner cannot run.
        pubsub_io = (ReadFromPubSub, WriteToPubSub)
        if isinstance(applied_ptransform.transform, pubsub_io):
          self.supported_by_fnapi_runner = False

      def visit_transform(self, applied_ptransform):
        transform = applied_ptransform.transform
        # TestStream implies streaming execution as well.
        if isinstance(transform, TestStream):
          self.supported_by_fnapi_runner = False
        if not isinstance(transform, beam.ParDo):
          return
        dofn = transform.dofn
        if isinstance(dofn, CombineValuesDoFn):
          # CombineFns with deferred (placeholder) side inputs are not
          # supported by the FnApiRunner.
          args, kwargs = transform.raw_side_inputs
          all_side_args = itertools.chain(args, kwargs.values())
          if any(isinstance(side_arg, ArgumentPlaceholder)
                 for side_arg in all_side_args):
            self.supported_by_fnapi_runner = False
        if userstate.is_stateful_dofn(dofn):
          # Real-time (processing-time) timers are not supported either.
          _, timer_specs = userstate.get_dofn_specs(dofn)
          if any(spec.time_domain == TimeDomain.REAL_TIME
                 for spec in timer_specs):
            self.supported_by_fnapi_runner = False

    # Fall back to the bundle-based runner whenever any transform in the
    # pipeline is unsupported by the FnApiRunner.
    if not _FnApiRunnerSupportVisitor().accept(pipeline):
      return BundleBasedDirectRunner().run_pipeline(pipeline, options)

    from apache_beam.portability.api import beam_provision_api_pb2
    from apache_beam.runners.portability.fn_api_runner import fn_runner
    from apache_beam.runners.portability.portable_runner import JobServiceHandle

    # Hand every pipeline option to the FnApiRunner via its provision info.
    all_options = options.get_all_options()
    encoded_options = JobServiceHandle.encode_pipeline_options(all_options)
    provision_info = fn_runner.ExtendedProvisionInfo(
        beam_provision_api_pb2.ProvisionInfo(
            pipeline_options=encoded_options))
    fnapi_runner = fn_runner.FnApiRunner(provision_info=provision_info)
    return fnapi_runner.run_pipeline(pipeline, options)
# Type variables for the type-hint decorators below: K stands for the key
# type and V for the value type of KV-shaped elements (e.g.
# ``typing.Tuple[K, V]`` in, ``typing.Tuple[K, typing.Iterable[V]]`` out).
K = typing.TypeVar('K')
V = typing.TypeVar('V')
@typehints.with_input_types(typing.Tuple[K, V])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupByKeyOnly(PTransform):
  """Groups values by key alone, without regard to their windows."""
  def infer_output_type(self, input_type):
    # KV[K, V] in -> KV[K, Iterable[V]] out.
    key_type, value_type = trivial_inference.key_value_types(input_type)
    grouped_values_type = typehints.Iterable[value_type]
    return typehints.KV[key_type, grouped_values_type]

  def expand(self, pcoll):
    self._check_pcollection(pcoll)
    # Yields a fresh PCollection derived from the input; the grouping itself
    # is presumably carried out by a runner-side evaluator (not shown here).
    return PCollection.from_(pcoll)
@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupAlsoByWindow(ParDo):
  """ParDo that splits each key's grouped values by window.

  The per-element work is delegated to ``_GroupAlsoByWindowDoFn`` built
  from ``windowing``, which is also kept on the transform itself.
  """
  def __init__(self, windowing):
    window_grouping_fn = _GroupAlsoByWindowDoFn(windowing)
    super().__init__(window_grouping_fn)
    self.windowing = windowing

  def expand(self, pcoll):
    self._check_pcollection(pcoll)
    return PCollection.from_(pcoll)
class _GroupAlsoByWindowDoFn(DoFn):
  """DoFn that runs each key's grouped values through a trigger driver.

  The driver is created from the transform's windowing strategy at the
  start of every bundle.
  """
  # TODO(robertwb): Support combiner lifting.
  def __init__(self, windowing):
    super().__init__()
    self.windowing = windowing

  def infer_output_type(self, input_type):
    key_type, values_type = trivial_inference.key_value_types(input_type)
    # Peel Iterable[WindowedValue[V]] down to the bare element type V.
    element_type = values_type.inner_type.inner_type
    return typehints.KV[key_type, typehints.Iterable[element_type]]

  def start_bundle(self):
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms.trigger import create_trigger_driver
    # pylint: enable=wrong-import-order, wrong-import-position
    # NOTE(review): the ``True`` flag presumably selects batch-mode
    # processing in the trigger module -- confirm against create_trigger_driver.
    self.driver = create_trigger_driver(self.windowing, True)

  def process(self, element):
    key, grouped_values = element
    return self.driver.process_entire_key(key, grouped_values)
@typehints.with_input_types(typing.Tuple[K, V])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _StreamingGroupByKeyOnly(_GroupByKeyOnly):
  """Placeholder for a streaming GroupByKeyOnly.

  Substituted for ``_GroupByKeyOnly`` by the DirectRunner's streaming
  overrides.
  """
  urn = "direct_runner:streaming_gbko:v0.1"

  # Runner API round-trip hooks; needed because of apply overloads.
  def to_runner_api_parameter(self, unused_context):
    # This transform carries no configuration, so its payload is empty.
    payload = None
    return _StreamingGroupByKeyOnly.urn, payload

  @staticmethod
  @PTransform.register_urn(urn, None)
  def from_runner_api_parameter(
      unused_ptransform, unused_payload, unused_context):
    return _StreamingGroupByKeyOnly()
@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _StreamingGroupAlsoByWindow(_GroupAlsoByWindow):
  """Placeholder for a streaming GroupAlsoByWindow.

  Substituted for ``_GroupAlsoByWindow`` by the DirectRunner's streaming
  overrides. Its windowing strategy is serialized by id.
  """
  urn = "direct_runner:streaming_gabw:v0.1"

  # Runner API round-trip hooks; needed because of apply overloads.
  def to_runner_api_parameter(self, context):
    urn = _StreamingGroupAlsoByWindow.urn
    windowing_id = context.windowing_strategies.get_id(self.windowing)
    payload = wrappers_pb2.BytesValue(value=windowing_id)
    return urn, payload

  @staticmethod
  @PTransform.register_urn(urn, wrappers_pb2.BytesValue)
  def from_runner_api_parameter(unused_ptransform, payload, context):
    windowing = context.windowing_strategies.get_by_id(payload.value)
    return _StreamingGroupAlsoByWindow(windowing)
@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupByKey(PTransform):
  """GroupByKey as expanded by the DirectRunner.

  Expands into three steps: reify each element's window into its value
  ('ReifyWindows'), group by key ignoring windows ('GroupByKey'), then
  group each key's values by window ('GroupByWindow'). When the input has
  an element type, every step is given explicit type hints.
  """
  def expand(self, pcoll):
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.coders import typecoders

    input_type = pcoll.element_type
    if input_type is None:
      # No type information: apply the three steps untyped.
      reified = pcoll | 'ReifyWindows' >> ParDo(beam.GroupByKey.ReifyWindows())
      grouped = reified | 'GroupByKey' >> _GroupByKeyOnly()
      return grouped | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing)

    # Derive key/value types to enforce type-checking here and to pass hints
    # to downstream PTransforms.
    key_type, value_type = trivial_inference.key_value_types(input_type)
    # A GBK's input must have a KV element type.
    pcoll.element_type = typehints.typehints.coerce_to_kv_type(
        pcoll.element_type)
    # Grouping by key requires the key coder to be deterministic.
    key_coder = typecoders.registry.get_coder(key_type)
    typecoders.registry.verify_deterministic(
        key_coder, 'GroupByKey operation "%s"' % self.label)

    reify_output_type = typehints.KV[
        key_type, typehints.WindowedValue[value_type]]  # type: ignore[misc]
    windowed_values_type = typehints.Iterable[
        typehints.WindowedValue[value_type]]  # type: ignore[misc]
    gbk_input_type = typehints.KV[key_type, windowed_values_type]
    gbk_output_type = typehints.KV[key_type, typehints.Iterable[value_type]]

    # pylint: disable=bad-option-value
    reified = pcoll | 'ReifyWindows' >> ParDo(
        beam.GroupByKey.ReifyWindows()).with_output_types(reify_output_type)
    grouped = reified | 'GroupByKey' >> _GroupByKeyOnly().with_input_types(
        reify_output_type).with_output_types(gbk_input_type)
    return grouped | 'GroupByWindow' >> _GroupAlsoByWindow(
        pcoll.windowing).with_input_types(gbk_input_type).with_output_types(
            gbk_output_type)
def _get_transform_overrides(pipeline_options):
  """Returns the PTransformOverrides to apply before running on DirectRunner.

  Currently this only works for overrides where the input and output types
  do not change. For internal use only; no backwards-compatibility
  guarantees.

  Args:
    pipeline_options: the pipeline's options. When
      ``StandardOptions.streaming`` is set, the streaming GroupByKeyOnly /
      GroupAlsoByWindow overrides are included.

  Returns:
    A list of ``PTransformOverride`` objects in the order they must be
    applied.
  """
  # Importing following locally to avoid a circular dependency.
  from apache_beam.pipeline import PTransformOverride
  from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey
  from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride
  from apache_beam.runners.direct.sdf_direct_runner import SplittableParDoOverride

  class CombinePerKeyOverride(PTransformOverride):
    """Replaces CombinePerKey with a lifted combine when possible."""
    def matches(self, applied_ptransform):
      # Only combines over default-windowed inputs are lifted. Always return
      # a bool (previously this fell through to an implicit None).
      return bool(
          isinstance(applied_ptransform.transform, CombinePerKey) and
          applied_ptransform.inputs[0].windowing.is_default())

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # Bind outside the try so the fallback below can always refer to it.
      transform = applied_ptransform.transform
      try:
        return LiftedCombinePerKey(
            transform.fn, transform.args, transform.kwargs)
      except NotImplementedError:
        # Lifting is not possible for this combine; keep the original.
        return transform

  class StreamingGroupByKeyOverride(PTransformOverride):
    """Swaps _GroupByKeyOnly for its streaming placeholder."""
    def matches(self, applied_ptransform):
      # Note: we match the exact class, since we replace it with a subclass.
      return applied_ptransform.transform.__class__ == _GroupByKeyOnly

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # Use specialized streaming implementation.
      transform = _StreamingGroupByKeyOnly()
      return transform

  class StreamingGroupAlsoByWindowOverride(PTransformOverride):
    """Swaps GroupAlsoByWindow ParDos for the streaming placeholder."""
    def matches(self, applied_ptransform):
      # Match any ParDo running a _GroupAlsoByWindowDoFn, but skip ones that
      # are already the streaming subclass (which is what we replace with).
      transform = applied_ptransform.transform
      return (
          isinstance(applied_ptransform.transform, ParDo) and
          isinstance(transform.dofn, _GroupAlsoByWindowDoFn) and
          transform.__class__ != _StreamingGroupAlsoByWindow)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      # Use specialized streaming implementation.
      transform = _StreamingGroupAlsoByWindow(
          applied_ptransform.transform.dofn.windowing)
      return transform

  class TestStreamOverride(PTransformOverride):
    """Replaces TestStream with its expandable DirectRunner implementation."""
    def matches(self, applied_ptransform):
      from apache_beam.testing.test_stream import TestStream
      # NOTE(review): stored but not read within this file; kept in case
      # other code relies on it.
      self.applied_ptransform = applied_ptransform
      return isinstance(applied_ptransform.transform, TestStream)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      from apache_beam.runners.direct.test_stream_impl import _ExpandableTestStream
      return _ExpandableTestStream(applied_ptransform.transform)

  class GroupByKeyPTransformOverride(PTransformOverride):
    """A ``PTransformOverride`` for ``GroupByKey``.

    This replaces the Beam implementation as a primitive.
    """
    def matches(self, applied_ptransform):
      # Imported here to avoid circular dependencies.
      # pylint: disable=wrong-import-order, wrong-import-position
      from apache_beam.transforms.core import GroupByKey
      return isinstance(applied_ptransform.transform, GroupByKey)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      return _GroupByKey()

  overrides = [
      # This needs to be the first and the last override. Other overrides
      # depend on the GroupByKey implementation to be composed of
      # _GroupByKeyOnly and _GroupAlsoByWindow.
      GroupByKeyPTransformOverride(),
      SplittableParDoOverride(),
      ProcessKeyedElementsViaKeyedWorkItemsOverride(),
      CombinePerKeyOverride(),
      TestStreamOverride(),
  ]

  # Add streaming overrides, if necessary.
  if pipeline_options.view_as(StandardOptions).streaming:
    overrides.append(StreamingGroupByKeyOverride())
    overrides.append(StreamingGroupAlsoByWindowOverride())

  # Add PubSub overrides, if PubSub is available.
  try:
    from apache_beam.io.gcp import pubsub as unused_pubsub
    overrides += _get_pubsub_transform_overrides(pipeline_options)
  except ImportError:
    pass

  # This also needs to be last because other transforms apply GBKs which need
  # to be translated into a DirectRunner-compatible transform.
  overrides.append(GroupByKeyPTransformOverride())

  return overrides
class _DirectReadFromPubSub(PTransform):
  """DirectRunner stand-in for a Pub/Sub read.

  Wraps the original read's source; expansion only creates the output
  PCollection, since this is handled as a native transform.
  """
  def __init__(self, source):
    self._source = source

  def _infer_output_coder(
      self, unused_input_type=None, unused_input_coder=None):
    # type: (...) -> typing.Optional[coders.Coder]
    # Output elements are encoded as raw bytes.
    return coders.BytesCoder()

  def get_windowing(self, unused_inputs):
    # Everything read lands in the single global window.
    global_windows = beam.window.GlobalWindows()
    return beam.Windowing(global_windows)

  def expand(self, pvalue):
    # This is handled as a native transform: just create the output
    # PCollection, bounded exactly when the source is.
    source_is_bounded = self._source.is_bounded()
    return PCollection(self.pipeline, is_bounded=source_is_bounded)
class _DirectWriteToPubSubFn(DoFn):
  """Publishes elements to a Pub/Sub topic on behalf of the DirectRunner.

  Elements are buffered and published in batches of up to
  ``BUFFER_SIZE_ELEMENTS``; whatever remains is published when the bundle
  finishes. Each flush waits at most ``FLUSH_TIMEOUT_SECS`` in total for
  all of its publish futures.

  Raises:
    NotImplementedError: at construction time, if the wrapped transform
      sets ``id_label`` or ``timestamp_attribute`` (unsupported here).
  """
  BUFFER_SIZE_ELEMENTS = 100
  # Total wait budget for one flush: 0.5s per element of a full buffer.
  FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5

  def __init__(self, transform):
    # Initialize DoFn state, consistent with the other DoFns in this module.
    super().__init__()
    self.project = transform.project
    self.short_topic_name = transform.topic_name
    self.id_label = transform.id_label
    self.timestamp_attribute = transform.timestamp_attribute
    self.with_attributes = transform.with_attributes

    # TODO(https://github.com/apache/beam/issues/18939): Add support for
    # id_label and timestamp_attribute.
    if transform.id_label:
      raise NotImplementedError(
          'DirectRunner: id_label is not supported for '
          'PubSub writes')
    if transform.timestamp_attribute:
      raise NotImplementedError(
          'DirectRunner: timestamp_attribute is not '
          'supported for PubSub writes')

  def start_bundle(self):
    self._buffer = []

  def process(self, elem):
    self._buffer.append(elem)
    if len(self._buffer) >= self.BUFFER_SIZE_ELEMENTS:
      self._flush()

  def finish_bundle(self):
    self._flush()

  def _flush(self):
    """Publishes every buffered element, then clears the buffer.

    Waits for each publish future; a future not done within the remaining
    shared time budget raises from ``future.result``.
    """
    # Nothing to publish (e.g. an empty bundle): don't create a client.
    if not self._buffer:
      return

    from google.cloud import pubsub
    pub_client = pubsub.PublisherClient()
    topic = pub_client.topic_path(self.project, self.short_topic_name)

    if self.with_attributes:
      # Elements expose ``data`` and ``attributes`` -- presumably
      # PubsubMessage-like objects; confirm against WriteToPubSub.
      futures = [
          pub_client.publish(topic, elem.data, **elem.attributes)
          for elem in self._buffer
      ]
    else:
      futures = [pub_client.publish(topic, elem) for elem in self._buffer]

    # The timeout budget is shared by all futures, not applied per future.
    # A monotonic clock keeps the elapsed time immune to wall-clock jumps.
    timer_start = time.monotonic()
    for future in futures:
      remaining = self.FLUSH_TIMEOUT_SECS - (time.monotonic() - timer_start)
      future.result(remaining)
    self._buffer = []
def _get_pubsub_transform_overrides(pipeline_options):
  """Returns DirectRunner overrides for Pub/Sub reads and writes.

  Args:
    pipeline_options: the pipeline's options. Pub/Sub I/O is only allowed
      when ``StandardOptions.streaming`` is set.

  Returns:
    A list holding one override for ``ReadFromPubSub`` and one for
    ``WriteToPubSub``. When a replacement is requested for a
    non-streaming pipeline, the override raises ``ValueError`` (a subclass
    of ``Exception``, so existing ``except Exception`` handlers still
    catch it).
  """
  from apache_beam.io.gcp import pubsub as beam_pubsub
  from apache_beam.pipeline import PTransformOverride

  def _require_streaming():
    """Raises ValueError unless the pipeline runs in streaming mode."""
    if not pipeline_options.view_as(StandardOptions).streaming:
      raise ValueError(
          'PubSub I/O is only available in streaming mode '
          '(use the --streaming flag).')

  class ReadFromPubSubOverride(PTransformOverride):
    """Replaces ReadFromPubSub with the native DirectRunner read."""
    def matches(self, applied_ptransform):
      return isinstance(
          applied_ptransform.transform, beam_pubsub.ReadFromPubSub)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      _require_streaming()
      return _DirectReadFromPubSub(applied_ptransform.transform._source)

  class WriteToPubSubOverride(PTransformOverride):
    """Replaces WriteToPubSub with a ParDo over _DirectWriteToPubSubFn."""
    def matches(self, applied_ptransform):
      return isinstance(applied_ptransform.transform, beam_pubsub.WriteToPubSub)

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform):
      _require_streaming()
      return beam.ParDo(_DirectWriteToPubSubFn(applied_ptransform.transform))

  return [ReadFromPubSubOverride(), WriteToPubSubOverride()]
class BundleBasedDirectRunner(PipelineRunner):
  """Runs a whole pipeline locally, bundle by bundle, on background threads."""
  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    """Starts executing ``pipeline`` and returns a DirectPipelineResult.

    Execution is kicked off without blocking; call ``wait_until_finish`` on
    the returned result to wait for it.
    """
    # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
    # with resolving imports when they are at top.
    # pylint: disable=wrong-import-position
    from apache_beam.pipeline import PipelineVisitor
    from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import \
      ConsumerTrackingPipelineVisitor
    from apache_beam.runners.direct.evaluation_context import EvaluationContext
    from apache_beam.runners.direct.executor import Executor
    from apache_beam.runners.direct.transform_evaluator import \
      TransformEvaluatorRegistry
    from apache_beam.testing.test_stream import TestStream

    class TestStreamUsageVisitor(PipelineVisitor):
      """Flags whether any transform in a Pipeline is a TestStream."""
      def __init__(self):
        self.uses_test_stream = False

      def visit_transform(self, applied_ptransform):
        if isinstance(applied_ptransform.transform, TestStream):
          self.uses_test_stream = True

    test_stream_finder = TestStreamUsageVisitor()
    pipeline.visit(test_stream_finder)
    # TestStream-driven pipelines get a mock test clock; others a real one.
    if test_stream_finder.uses_test_stream:
      clock = TestClock()
    else:
      clock = RealClock()

    # Performing configured PTransform overrides.
    pipeline.replace_all(_get_transform_overrides(options))

    _LOGGER.info('Running pipeline with DirectRunner.')
    tracker = ConsumerTrackingPipelineVisitor()
    self.consumer_tracking_visitor = tracker
    pipeline.visit(tracker)

    use_stacked_bundles = options.view_as(
        DirectOptions).direct_runner_use_stacked_bundle
    evaluation_context = EvaluationContext(
        options,
        BundleFactory(stacked=use_stacked_bundles),
        tracker.root_transforms,
        tracker.value_to_consumers,
        tracker.step_names,
        tracker.views,
        clock)

    executor = Executor(
        tracker.value_to_consumers,
        TransformEvaluatorRegistry(evaluation_context),
        evaluation_context)

    # Runtime injection of PipelineOptions values is not supported here.
    RuntimeValueProvider.set_runtime_options({})

    # Non-blocking: execution continues on background threads.
    executor.start(tracker.root_transforms)
    return DirectPipelineResult(executor, evaluation_context)
# Use the SwitchingDirectRunner as the default. ``DirectRunner`` is an alias,
# so it runs on the FnApiRunner when every transform is supported there and
# otherwise falls back to the BundleBasedDirectRunner.
DirectRunner = SwitchingDirectRunner
class DirectPipelineResult(PipelineResult):
  """Handle on a pipeline started by the BundleBasedDirectRunner.

  Starts in the RUNNING state and exposes waiting, metrics, and
  cancellation.
  """
  def __init__(self, executor, evaluation_context):
    super().__init__(PipelineState.RUNNING)
    self._executor = executor
    self._evaluation_context = evaluation_context

  def __del__(self):
    # Warn only if the handle is dropped while the pipeline is still running.
    if self._state != PipelineState.RUNNING:
      return
    _LOGGER.warning(
        'The DirectPipelineResult is being garbage-collected while the '
        'DirectRunner is still running the corresponding pipeline. This may '
        'lead to incomplete execution of the pipeline if the main thread '
        'exits before pipeline completion. Consider using '
        'result.wait_until_finish() to wait for completion of pipeline '
        'execution.')

  def wait_until_finish(self, duration=None):
    """Blocks until the pipeline finishes and returns its final state.

    Raises:
      NotImplementedError: if a (truthy) ``duration`` is given.
      Any exception raised while awaiting completion, after the state has
      been set to FAILED.
    """
    if PipelineState.is_terminal(self.state):
      return self._state
    if duration:
      raise NotImplementedError(
          'DirectRunner does not support duration argument.')
    try:
      self._executor.await_completion()
    except BaseException:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    else:
      self._state = PipelineState.DONE
    return self._state

  def aggregated_values(self, aggregator_or_name):
    context = self._evaluation_context
    return context.get_aggregator_values(aggregator_or_name)

  def metrics(self):
    return self._evaluation_context.metrics()

  def cancel(self):
    """Shuts down pipeline workers.

    For testing use only. Does not properly wait for pipeline workers to
    shut down.
    """
    self._state = PipelineState.CANCELLING
    self._executor.shutdown()
    self._state = PipelineState.CANCELLED