Skip to content

Commit

Permalink
Ensure flatten windows match (apache#30410)
Browse files Browse the repository at this point in the history
* Ensure flatten windows match

* Downgrade to warning

* Lint
  • Loading branch information
damccorm authored Feb 27, 2024
1 parent 2a84a20 commit b7a58bf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3683,8 +3683,16 @@ def _extract_input_pvalues(self, pvalueish):
return pvalueish, pvalueish

def expand(self, pcolls):
windowing = self.get_windowing(pcolls)
for pcoll in pcolls:
self._check_pcollection(pcoll)
if pcoll.windowing != windowing:
_LOGGER.warning(
'All input pcollections must have the same window. Windowing for '
'flatten set to %s, windowing of pcoll %s set to %s',
windowing,
pcoll,
pcoll.windowing)
is_bounded = all(pcoll.is_bounded for pcoll in pcolls)
return pvalue.PCollection(self.pipeline, is_bounded=is_bounded)

Expand Down
34 changes: 34 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import pytest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.window import FixedWindows


class TestDoFn1(beam.DoFn):
Expand Down Expand Up @@ -136,6 +139,37 @@ def process(self, element):
self.assertEqual(p6.is_bounded, False)


class FlattenTest(unittest.TestCase):
def test_flatten_identical_windows(self):
with beam.testing.test_pipeline.TestPipeline() as p:
source1 = p | "c1" >> beam.Create(
[1, 2, 3, 4, 5]) | "w1" >> beam.WindowInto(FixedWindows(100))
source2 = p | "c2" >> beam.Create([6, 7, 8]) | "w2" >> beam.WindowInto(
FixedWindows(100))
source3 = p | "c3" >> beam.Create([9, 10]) | "w3" >> beam.WindowInto(
FixedWindows(100))
out = (source1, source2, source3) | "flatten" >> beam.Flatten()
assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))

def test_flatten_no_windows(self):
with beam.testing.test_pipeline.TestPipeline() as p:
source1 = p | "c1" >> beam.Create([1, 2, 3, 4, 5])
source2 = p | "c2" >> beam.Create([6, 7, 8])
source3 = p | "c3" >> beam.Create([9, 10])
out = (source1, source2, source3) | "flatten" >> beam.Flatten()
assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))

def test_flatten_mismatched_windows(self):
with beam.testing.test_pipeline.TestPipeline() as p:
source1 = p | "c1" >> beam.Create(
[1, 2, 3, 4, 5]) | "w1" >> beam.WindowInto(FixedWindows(25))
source2 = p | "c2" >> beam.Create([6, 7, 8]) | "w2" >> beam.WindowInto(
FixedWindows(100))
source3 = p | "c3" >> beam.Create([9, 10]) | "w3" >> beam.WindowInto(
FixedWindows(100))
_ = (source1, source2, source3) | "flatten" >> beam.Flatten()


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()

0 comments on commit b7a58bf

Please sign in to comment.