Skip to content

Commit

Permalink
Allow ib.collect to work with non-interactive runners. (#32383)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Sep 3, 2024
1 parent f06df5d commit 2b02fd3
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 27 deletions.
16 changes: 13 additions & 3 deletions sdks/python/apache_beam/runners/interactive/interactive_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,12 @@ def show(


@progress_indicated
def collect(pcoll, n='inf', duration='inf', include_window_info=False):
def collect(
pcoll,
n='inf',
duration='inf',
include_window_info=False,
force_compute=False):
"""Materializes the elements from a PCollection into a Dataframe.
This reads each element from file and reads only the amount that it needs
Expand All @@ -889,6 +894,8 @@ def collect(pcoll, n='inf', duration='inf', include_window_info=False):
a string duration. Default 'inf'.
include_window_info: (optional) if True, appends the windowing information
to each row. Default False.
force_compute: (optional) if True, forces recomputation rather than using
cached PCollections
For example::
Expand Down Expand Up @@ -938,7 +945,7 @@ def collect(pcoll, n='inf', duration='inf', include_window_info=False):
user_pipeline, create_if_absent=True)

# If already computed, directly read the stream and return.
if pcoll in ie.current_env().computed_pcollections:
if pcoll in ie.current_env().computed_pcollections and not force_compute:
pcoll_name = find_pcoll_name(pcoll)
elements = list(
recording_manager.read(pcoll_name, pcoll, n, duration).read())
Expand All @@ -947,7 +954,10 @@ def collect(pcoll, n='inf', duration='inf', include_window_info=False):
include_window_info=include_window_info,
element_type=element_type)

recording = recording_manager.record([pcoll], max_n=n, max_duration=duration)
recording = recording_manager.record([pcoll],
max_n=n,
max_duration=duration,
force_compute=force_compute)

try:
elements = list(recording.stream(pcoll).read())
Expand Down
50 changes: 35 additions & 15 deletions sdks/python/apache_beam/runners/interactive/pipeline_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import apache_beam as beam
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import pipeline_instrument as instr
from apache_beam.testing.test_stream import TestStream


Expand Down Expand Up @@ -65,7 +67,6 @@ def __init__(self, pcolls, options=None):
# into a pipeline fragment that later run by the underlying runner.
self._runner_pipeline = self._build_runner_pipeline()
_, self._context = self._runner_pipeline.to_runner_api(return_context=True)
from apache_beam.runners.interactive import pipeline_instrument as instr
self._runner_pcoll_to_id = instr.pcoll_to_pcoll_id(
self._runner_pipeline, self._context)
# Correlate components in the runner pipeline to components in the user
Expand Down Expand Up @@ -104,23 +105,42 @@ def deduce_fragment(self):

def run(self, display_pipeline_graph=False, use_cache=True, blocking=False):
"""Shorthand to run the pipeline fragment."""
fragment = self.deduce_fragment()
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
if not isinstance(self._runner_pipeline.runner, InteractiveRunner):
raise RuntimeError(
'Please specify InteractiveRunner when creating '
'the Beam pipeline to use this function.')
try:
preserved_skip_display = self._runner_pipeline.runner._skip_display
preserved_force_compute = self._runner_pipeline.runner._force_compute
preserved_blocking = self._runner_pipeline.runner._blocking
self._runner_pipeline.runner._skip_display = not display_pipeline_graph
self._runner_pipeline.runner._force_compute = not use_cache
self._runner_pipeline.runner._blocking = blocking
return self.deduce_fragment().run()
if isinstance(self._runner_pipeline.runner, InteractiveRunner):
preserved_skip_display = self._runner_pipeline.runner._skip_display
preserved_force_compute = self._runner_pipeline.runner._force_compute
preserved_blocking = self._runner_pipeline.runner._blocking
self._runner_pipeline.runner._skip_display = not display_pipeline_graph
self._runner_pipeline.runner._force_compute = not use_cache
self._runner_pipeline.runner._blocking = blocking
return fragment.run()
else:
pipeline_instrument = instr.build_pipeline_instrument(
fragment, self._runner_pipeline._options)
pipeline_instrument_proto = (
pipeline_instrument.instrumented_pipeline_proto())
if any(pcoll.is_bounded == beam_runner_api_pb2.IsBounded.UNBOUNDED
for pcoll in
pipeline_instrument_proto.components.pcollections.values()):
raise RuntimeError(
'Please specify InteractiveRunner when creating '
'the Beam pipeline to use this function '
'on unbouded PCollections.')
result = beam.pipeline.Pipeline.from_runner_api(
pipeline_instrument_proto,
self._runner_pipeline.runner,
self._runner_pipeline._options).run()
result.wait_until_finish()
ie.current_env().mark_pcollection_computed(
pipeline_instrument.cached_pcolls)
return result
finally:
self._runner_pipeline.runner._skip_display = preserved_skip_display
self._runner_pipeline.runner._force_compute = preserved_force_compute
self._runner_pipeline.runner._blocking = preserved_blocking
if isinstance(self._runner_pipeline.runner, InteractiveRunner):
self._runner_pipeline.runner._skip_display = preserved_skip_display
self._runner_pipeline.runner._force_compute = preserved_force_compute
self._runner_pipeline.runner._blocking = preserved_blocking

def _build_runner_pipeline(self):
runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,6 @@ def Foo(pcoll):
result = pf.PipelineFragment([pc]).run()
self.assertEqual([0, 6, 12, 18, 24], list(result.get(pc)))

def test_ib_show_without_using_ir(self):
"""Tests that ib.show is called when ir is not specified.
"""
p = beam.Pipeline()
print_words = p | beam.Create(["this is a test"]) | beam.Map(print)
with self.assertRaises(RuntimeError):
ib.show(print_words)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.options import capture_control
from apache_beam.runners.runner import PipelineState

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -384,11 +385,17 @@ def record(
self,
pcolls: List[beam.pvalue.PCollection],
max_n: int,
max_duration: Union[int, str]) -> Recording:
max_duration: Union[int, str],
force_compute: bool = False) -> Recording:
# noqa: F821

"""Records the given PCollections."""

if not ie.current_env().options.enable_recording_replay:
capture_control.evict_captured_data()
if force_compute:
ie.current_env().evict_computed_pcollections()

# Assert that all PCollection come from the same user_pipeline.
for pcoll in pcolls:
assert pcoll.pipeline is self.user_pipeline, (
Expand Down

0 comments on commit 2b02fd3

Please sign in to comment.