diff --git a/sdks/python/apache_beam/dataframe/io.py b/sdks/python/apache_beam/dataframe/io.py
index 176e78f385bee..2eebfaf522a82 100644
--- a/sdks/python/apache_beam/dataframe/io.py
+++ b/sdks/python/apache_beam/dataframe/io.py
@@ -521,7 +521,7 @@ def expand(self, pcoll):
     return pcoll | fileio.WriteToFiles(
         path=dir,
         file_naming=fileio.default_file_naming(name),
-        sink=_WriteToPandasFileSink(
+        sink=lambda _: _WriteToPandasFileSink(
             self.writer, self.args, self.kwargs, self.incremental, self.binary))


diff --git a/sdks/python/apache_beam/dataframe/io_test.py b/sdks/python/apache_beam/dataframe/io_test.py
index 9d484ccb0c4b8..374eb0c03f7da 100644
--- a/sdks/python/apache_beam/dataframe/io_test.py
+++ b/sdks/python/apache_beam/dataframe/io_test.py
@@ -21,7 +21,9 @@
 import platform
 import shutil
 import tempfile
+import typing
 import unittest
+from datetime import datetime
 from io import BytesIO
 from io import StringIO

@@ -38,6 +40,11 @@
 from apache_beam.testing.util import assert_that


+class MyRow(typing.NamedTuple):
+  timestamp: int
+  value: int
+
+
 @unittest.skipIf(platform.system() == 'Windows', 'BEAM-10929')
 class IOTest(unittest.TestCase):
   def setUp(self):
@@ -56,12 +63,14 @@ def temp_dir(self, files=None):
         fout.write(contents)
     return dir + os.path.sep

-  def read_all_lines(self, pattern):
+  def read_all_lines(self, pattern, delete=False):
     for path in glob.glob(pattern):
       with open(path) as fin:
         # TODO(Py3): yield from
         for line in fin:
           yield line.rstrip('\n')
+      if delete:
+        os.remove(path)

   def test_read_write_csv(self):
     input = self.temp_dir({'1.csv': 'a,b\n1,2\n', '2.csv': 'a,b\n3,4\n'})
@@ -304,6 +313,36 @@ def test_file_not_found(self):
     with beam.Pipeline() as p:
       _ = p | io.read_csv('/tmp/fake_dir/**')

+  def test_windowed_write(self):
+    output = self.temp_dir()
+    with beam.Pipeline() as p:
+      pc = (
+          p | beam.Create([MyRow(timestamp=i, value=i % 3) for i in range(20)])
+          | beam.Map(lambda v: beam.window.TimestampedValue(v, v.timestamp)).
+          with_output_types(MyRow)
+          | beam.WindowInto(
+              beam.window.FixedWindows(10)).with_output_types(MyRow))
+
+      deferred_df = convert.to_dataframe(pc)
+      deferred_df.to_csv(output + 'out.csv', index=False)
+
+    first_window_files = (
+        f'{output}out.csv-'
+        f'{datetime.utcfromtimestamp(0).isoformat()}*')
+    self.assertCountEqual(
+        ['timestamp,value'] + [f'{i},{i%3}' for i in range(10)],
+        set(self.read_all_lines(first_window_files, delete=True)))
+
+    second_window_files = (
+        f'{output}out.csv-'
+        f'{datetime.utcfromtimestamp(10).isoformat()}*')
+    self.assertCountEqual(
+        ['timestamp,value'] + [f'{i},{i%3}' for i in range(10, 20)],
+        set(self.read_all_lines(second_window_files, delete=True)))
+
+    # Check that we've read (and removed) every output file
+    self.assertEqual(len(glob.glob(f'{output}out.csv*')), 0)
+

 if __name__ == '__main__':
   unittest.main()