diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index a6fd3b184d4c..297915023c5c 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -3683,8 +3683,16 @@ def _extract_input_pvalues(self, pvalueish): return pvalueish, pvalueish def expand(self, pcolls): + windowing = self.get_windowing(pcolls) for pcoll in pcolls: self._check_pcollection(pcoll) + if pcoll.windowing != windowing: + _LOGGER.warning( + 'All input pcollections must have the same window. Windowing for ' + 'flatten set to %s, windowing of pcoll %s set to %s', + windowing, + pcoll, + pcoll.windowing) is_bounded = all(pcoll.is_bounded for pcoll in pcolls) return pvalue.PCollection(self.pipeline, is_bounded=is_bounded) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index a60974ceb706..4fbeaa2ee97a 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -24,6 +24,9 @@ import pytest import apache_beam as beam +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.transforms.window import FixedWindows class TestDoFn1(beam.DoFn): @@ -136,6 +139,37 @@ def process(self, element): self.assertEqual(p6.is_bounded, False) +class FlattenTest(unittest.TestCase): + def test_flatten_identical_windows(self): + with beam.testing.test_pipeline.TestPipeline() as p: + source1 = p | "c1" >> beam.Create( + [1, 2, 3, 4, 5]) | "w1" >> beam.WindowInto(FixedWindows(100)) + source2 = p | "c2" >> beam.Create([6, 7, 8]) | "w2" >> beam.WindowInto( + FixedWindows(100)) + source3 = p | "c3" >> beam.Create([9, 10]) | "w3" >> beam.WindowInto( + FixedWindows(100)) + out = (source1, source2, source3) | "flatten" >> beam.Flatten() + assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) + + def test_flatten_no_windows(self): + with beam.testing.test_pipeline.TestPipeline() as p: + source1 = p | "c1" >> beam.Create([1, 2, 3, 4, 5]) + source2 = p | "c2" >> beam.Create([6, 7, 8]) + source3 = p | "c3" >> beam.Create([9, 10]) + out = (source1, source2, source3) | "flatten" >> beam.Flatten() + assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) + + def test_flatten_mismatched_windows(self): + with beam.testing.test_pipeline.TestPipeline() as p: + source1 = p | "c1" >> beam.Create( + [1, 2, 3, 4, 5]) | "w1" >> beam.WindowInto(FixedWindows(25)) + source2 = p | "c2" >> beam.Create([6, 7, 8]) | "w2" >> beam.WindowInto( + FixedWindows(100)) + source3 = p | "c3" >> beam.Create([9, 10]) | "w3" >> beam.WindowInto( + FixedWindows(100)) + _ = (source1, source2, source3) | "flatten" >> beam.Flatten() + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()