From 4ee26065d9b4ea074d3d59d38cc4fbc5aa9dbfa5 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 12 Sep 2024 11:56:41 -0700 Subject: [PATCH] Accept runner and options in ib.collect. (#32434) --- .../display/pcoll_visualization_test.py | 2 +- .../runners/interactive/interactive_beam.py | 12 +++++++- .../non_interactive_runner_test.py | 30 +++++++++++++++++++ .../runners/interactive/pipeline_fragment.py | 12 ++++---- .../runners/interactive/recording_manager.py | 17 +++++++++-- 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py index d34b966b0efa..7fc76feb7494 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization_test.py @@ -66,7 +66,7 @@ def setUp(self): ie.current_env().track_user_pipelines() recording_manager = RecordingManager(self._p) - recording = recording_manager.record([self._pcoll], 5, 5) + recording = recording_manager.record([self._pcoll], max_n=5, max_duration=5) self._stream = recording.stream(self._pcoll) def test_pcoll_visualization_generate_unique_display_id(self): diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 5c76f9c228c8..0e170eb0f508 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -880,6 +880,8 @@ def collect( n='inf', duration='inf', include_window_info=False, + runner=None, + options=None, force_compute=False, force_tuple=False): """Materializes the elements from a PCollection into a Dataframe. @@ -896,6 +898,9 @@ def collect( a string duration. Default 'inf'. include_window_info: (optional) if True, appends the windowing information to each row. Default False. + runner: (optional) the runner with which to compute the results + options: (optional) any additional pipeline options to use to compute the + results force_compute: (optional) if True, forces recomputation rather than using cached PCollections force_tuple: (optional) if True, return a 1-tuple or results rather than @@ -969,7 +974,12 @@ def as_pcollection(pcoll_or_df): uncomputed = set(pcolls) - set(computed.keys()) if uncomputed: recording = recording_manager.record( - uncomputed, max_n=n, max_duration=duration, force_compute=force_compute) + uncomputed, + max_n=n, + max_duration=duration, + runner=runner, + options=options, + force_compute=force_compute) try: for pcoll in uncomputed: diff --git a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py index 47adf7b36b33..f7fd052fecc4 100644 --- a/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py +++ b/sdks/python/apache_beam/runners/interactive/non_interactive_runner_test.py @@ -257,6 +257,36 @@ def test_dataframes_same_cell_twice(self): df_expected['cube'], ib.collect(df['cube'], n=10).reset_index(drop=True)) + @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]") + def test_new_runner_and_options(self): + class MyRunner(beam.runners.PipelineRunner): + run_count = 0 + + @classmethod + def run_pipeline(cls, pipeline, options): + assert options._all_options['my_option'] == 123 + cls.run_count += 1 + return direct_runner.DirectRunner().run_pipeline(pipeline, options) + + clear_side_effect() + p = beam.Pipeline(direct_runner.DirectRunner()) + + # Initial collection runs the pipeline. + pcoll1 = p | beam.Create(['a', 'b', 'c']) | beam.Map(cause_side_effect) + collected1 = ib.collect(pcoll1) + self.assertEqual(set(collected1[0]), set(['a', 'b', 'c'])) + self.assertEqual(count_side_effects('a'), 1) + + # Using the PCollection uses the cache with a different runner and options. + pcoll2 = pcoll1 | beam.Map(str.upper) + collected2 = ib.collect( + pcoll2, + runner=MyRunner(), + options=beam.options.pipeline_options.PipelineOptions(my_option=123)) + self.assertEqual(set(collected2[0]), set(['A', 'B', 'C'])) + self.assertEqual(count_side_effects('a'), 1) + self.assertEqual(MyRunner.run_count, 1) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py index 5b385d3f8a0f..20dee2b71163 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py @@ -34,7 +34,7 @@ class PipelineFragment(object): A pipeline fragment is built from the original pipeline definition to include only PTransforms that are necessary to produce the given PCollections. """ - def __init__(self, pcolls, options=None): + def __init__(self, pcolls, options=None, runner=None): """Constructor of PipelineFragment. Args: @@ -42,6 +42,8 @@ def __init__(self, pcolls, options=None): fragment for. options: (PipelineOptions) the pipeline options for the implicit pipeline run. + runner: (Runner) the pipeline runner for the implicit + pipeline run. """ assert len(pcolls) > 0, ( 'Need at least 1 PCollection as the target data to build a pipeline ' @@ -61,6 +63,7 @@ def __init__(self, pcolls, options=None): 'given and cannot be used to build a pipeline fragment that produces ' 'the given PCollections.'.format(pcoll)) self._options = options + self._runner = runner # A copied pipeline instance for modification without changing the user # pipeline instance held by the end user. This instance can be processed @@ -98,7 +101,7 @@ def deduce_fragment(self): """Deduce the pipeline fragment as an apache_beam.Pipeline instance.""" fragment = beam.pipeline.Pipeline.from_runner_api( self._runner_pipeline.to_runner_api(), - self._runner_pipeline.runner, + self._runner or self._runner_pipeline.runner, self._options) ie.current_env().add_derived_pipeline(self._runner_pipeline, fragment) return fragment @@ -129,9 +132,8 @@ def run(self, display_pipeline_graph=False, use_cache=True, blocking=False): 'the Beam pipeline to use this function ' 'on unbouded PCollections.') result = beam.pipeline.Pipeline.from_runner_api( - pipeline_instrument_proto, - self._runner_pipeline.runner, - self._runner_pipeline._options).run() + pipeline_instrument_proto, fragment.runner, + fragment._options).run() result.wait_until_finish() ie.current_env().mark_pcollection_computed( pipeline_instrument.cached_pcolls) diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index cb28a61a95f1..6811d3e0d345 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -28,7 +28,9 @@ import apache_beam as beam from apache_beam.dataframe.frame_base import DeferredBase +from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners import runner from apache_beam.runners.interactive import background_caching_job as bcj from apache_beam.runners.interactive import interactive_environment as ie from apache_beam.runners.interactive import interactive_runner as ir @@ -384,8 +386,11 @@ def record_pipeline(self) -> bool: def record( self, pcolls: List[beam.pvalue.PCollection], + *, max_n: int, max_duration: Union[int, str], + runner: runner.PipelineRunner = None, + options: pipeline_options.PipelineOptions = None, force_compute: bool = False) -> Recording: # noqa: F821 @@ -427,12 +432,20 @@ def record( # incomplete. self._clear() + merged_options = pipeline_options.PipelineOptions( + **{ + **self.user_pipeline.options.get_all_options( + drop_default=True, retain_unknown_options=True), + **options.get_all_options( + drop_default=True, retain_unknown_options=True) + }) if options else self.user_pipeline.options + cache_path = ie.current_env().options.cache_root is_remote_run = cache_path and ie.current_env( ).options.cache_root.startswith('gs://') pf.PipelineFragment( - list(uncomputed_pcolls), - self.user_pipeline.options).run(blocking=is_remote_run) + list(uncomputed_pcolls), merged_options, + runner=runner).run(blocking=is_remote_run) result = ie.current_env().pipeline_result(self.user_pipeline) else: result = None