From 67ccf9fa8bce805e3b496cd36bee551991a0c5d7 Mon Sep 17 00:00:00 2001
From: Evan Sadler
Date: Sun, 11 Sep 2022 14:07:48 -0400
Subject: [PATCH] reuse parquet sink for writing rows and pa.Tables

---
 sdks/python/apache_beam/io/parquetio.py      | 267 ++++++++-----------
 sdks/python/apache_beam/io/parquetio_test.py |  19 +-
 2 files changed, 128 insertions(+), 158 deletions(-)

diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py
index c5c83dcb81e8e..23da6de73b9ab 100644
--- a/sdks/python/apache_beam/io/parquetio.py
+++ b/sdks/python/apache_beam/io/parquetio.py
@@ -43,6 +43,8 @@
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import ParDo
 from apache_beam.transforms import PTransform
+from apache_beam.transforms.util import BatchElements
+from apache_beam.transforms import window

 try:
   import pyarrow as pa
@@ -84,6 +86,67 @@ def process(self, table, with_filename=False):
       yield row


+class _RowDictionariesToArrowTable(DoFn):
+  """A DoFn that consumes Python dictionaries and yields a pyarrow table."""
+  def __init__(
+      self,
+      schema,
+      row_group_buffer_size=64 * 1024 * 1024,
+      record_batch_size=1000):
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._buffer = [[] for _ in range(len(schema.names))]
+    self._buffer_size = record_batch_size
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+
+  def process(self, row):
+    if len(self._buffer[0]) >= self._buffer_size:
+      self._flush_buffer()
+
+    if self._record_batches_byte_size >= self._row_group_buffer_size:
+      table = self._create_table()
+      yield window.GlobalWindows.windowed_value_at_end_of_window(table)
+
+    # reorder the data in columnar format.
+    for i, n in enumerate(self._schema.names):
+      self._buffer[i].append(row[n])
+
+  def finish_bundle(self):
+    if len(self._buffer[0]) > 0:
+      self._flush_buffer()
+    if self._record_batches_byte_size > 0:
+      table = self._create_table()
+      yield window.GlobalWindows.windowed_value_at_end_of_window(table)
+
+  def display_data(self):
+    res = super().display_data()
+    res['row_group_buffer_size'] = str(self._row_group_buffer_size)
+    res['buffer_size'] = str(self._buffer_size)
+
+    return res
+
+  def _create_table(self):
+    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+    return table
+
+  def _flush_buffer(self):
+    arrays = [[] for _ in range(len(self._schema.names))]
+    for x, y in enumerate(self._buffer):
+      arrays[x] = pa.array(y, type=self._schema.types[x])
+      self._buffer[x] = []
+    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
+    self._record_batches.append(rb)
+    size = 0
+    for x in arrays:
+      for b in x.buffers():
+        if b is not None:
+          size = size + b.size
+    self._record_batches_byte_size = self._record_batches_byte_size + size
+
+
 class ReadFromParquetBatched(PTransform):
   """A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
   Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
@@ -454,13 +517,15 @@ def __init__(
       A WriteToParquet transform usable for writing.
""" super().__init__() + self._schema = schema + self._row_group_buffer_size = row_group_buffer_size + self._record_batch_size = record_batch_size + self._sink = \ _create_parquet_sink( file_path_prefix, schema, codec, - row_group_buffer_size, - record_batch_size, use_deprecated_int96_timestamps, use_compliant_nested_type, file_name_suffix, @@ -470,149 +535,16 @@ def __init__( ) def expand(self, pcoll): - return pcoll | Write(self._sink) - - def display_data(self): - return {'sink_dd': self._sink} - - -def _create_parquet_sink( - file_path_prefix, - schema, - codec, - row_group_buffer_size, - record_batch_size, - use_deprecated_int96_timestamps, - use_compliant_nested_type, - file_name_suffix, - num_shards, - shard_name_template, - mime_type): - return \ - _ParquetSink( - file_path_prefix, - schema, - codec, - row_group_buffer_size, - record_batch_size, - use_deprecated_int96_timestamps, - use_compliant_nested_type, - file_name_suffix, - num_shards, - shard_name_template, - mime_type - ) - - -class _ParquetSink(filebasedsink.FileBasedSink): - """A sink for parquet files.""" - def __init__( - self, - file_path_prefix, - schema, - codec, - row_group_buffer_size, - record_batch_size, - use_deprecated_int96_timestamps, - use_compliant_nested_type, - file_name_suffix, - num_shards, - shard_name_template, - mime_type): - super().__init__( - file_path_prefix, - file_name_suffix=file_name_suffix, - num_shards=num_shards, - shard_name_template=shard_name_template, - coder=None, - mime_type=mime_type, - # Compression happens at the block level using the supplied codec, and - # not at the file level. - compression_type=CompressionTypes.UNCOMPRESSED) - self._schema = schema - self._codec = codec - if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4": - raise ValueError( - "Due to ARROW-9424, writing with LZ4 compression is not supported in " - "pyarrow 1.x, please use a different pyarrow version or a different " - f"codec. Your pyarrow version: {pa.__version__}") - self._row_group_buffer_size = row_group_buffer_size - self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps - if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4: - raise ValueError( - "With ARROW-11497, use_compliant_nested_type is only supported in " - "pyarrow version >= 4.x, please use a different pyarrow version. " - f"Your pyarrow version: {pa.__version__}") - self._use_compliant_nested_type = use_compliant_nested_type - self._buffer = [[] for _ in range(len(schema.names))] - self._buffer_size = record_batch_size - self._record_batches = [] - self._record_batches_byte_size = 0 - self._file_handle = None - - def open(self, temp_path): - self._file_handle = super().open(temp_path) - if ARROW_MAJOR_VERSION < 4: - return pq.ParquetWriter( - self._file_handle, - self._schema, - compression=self._codec, - use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps) - return pq.ParquetWriter( - self._file_handle, - self._schema, - compression=self._codec, - use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps, - use_compliant_nested_type=self._use_compliant_nested_type) - - def write_record(self, writer, value): - if len(self._buffer[0]) >= self._buffer_size: - self._flush_buffer() - - if self._record_batches_byte_size >= self._row_group_buffer_size: - self._write_batches(writer) - - # reorder the data in columnar format. 
-    for i, n in enumerate(self._schema.names):
-      self._buffer[i].append(value[n])
-
-  def close(self, writer):
-    if len(self._buffer[0]) > 0:
-      self._flush_buffer()
-    if self._record_batches_byte_size > 0:
-      self._write_batches(writer)
-
-    writer.close()
-    if self._file_handle:
-      self._file_handle.close()
-      self._file_handle = None
+    return pcoll | ParDo(
+        _RowDictionariesToArrowTable(
+            self._schema, self._row_group_buffer_size,
+            self._record_batch_size)) | Write(self._sink)

   def display_data(self):
-    res = super().display_data()
-    res['codec'] = str(self._codec)
-    res['schema'] = str(self._schema)
-    res['row_group_buffer_size'] = str(self._row_group_buffer_size)
-    return res
-
-  def _write_batches(self, writer):
-    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
-    self._record_batches = []
-    self._record_batches_byte_size = 0
-    writer.write_table(table)
-
-  def _flush_buffer(self):
-    arrays = [[] for _ in range(len(self._schema.names))]
-    for x, y in enumerate(self._buffer):
-      arrays[x] = pa.array(y, type=self._schema.types[x])
-      self._buffer[x] = []
-    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
-    self._record_batches.append(rb)
-    size = 0
-    for x in arrays:
-      for b in x.buffers():
-        if b is not None:
-          size = size + b.size
-    self._record_batches_byte_size = self._record_batches_byte_size + size
+    return {
+        'sink_dd': self._sink,
+        'row_group_buffer_size': str(self._row_group_buffer_size)
+    }


 class WriteToParquetBatched(PTransform):
@@ -634,17 +566,18 @@ def __init__(
       mime_type='application/x-parquet',
   ):
     super().__init__()
-    self._sink = _BatchedParquetSink(
-        file_path_prefix=file_path_prefix,
-        schema=schema,
-        codec=codec,
-        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
-        use_compliant_nested_type=use_compliant_nested_type,
-        file_name_suffix=file_name_suffix,
-        num_shards=num_shards,
-        shard_name_template=shard_name_template,
-        mime_type=mime_type,
-    )
+    self._sink = \
+        _create_parquet_sink(
+            file_path_prefix,
+            schema,
+            codec,
+            use_deprecated_int96_timestamps,
+            use_compliant_nested_type,
+            file_name_suffix,
+            num_shards,
+            shard_name_template,
+            mime_type
+        )

   def expand(self, pcoll):
     return pcoll | Write(self._sink)
@@ -653,7 +586,31 @@ def display_data(self):
     return {'sink_dd': self._sink}


-class _BatchedParquetSink(filebasedsink.FileBasedSink):
+def _create_parquet_sink(
+    file_path_prefix,
+    schema,
+    codec,
+    use_deprecated_int96_timestamps,
+    use_compliant_nested_type,
+    file_name_suffix,
+    num_shards,
+    shard_name_template,
+    mime_type):
+  return \
+      _ParquetSink(
+          file_path_prefix,
+          schema,
+          codec,
+          use_deprecated_int96_timestamps,
+          use_compliant_nested_type,
+          file_name_suffix,
+          num_shards,
+          shard_name_template,
+          mime_type
+      )
+
+
+class _ParquetSink(filebasedsink.FileBasedSink):
   """A sink for parquet files from batches."""
   def __init__(
       self,
diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py
index 065db351d086f..ddedcf19622fe 100644
--- a/sdks/python/apache_beam/io/parquetio_test.py
+++ b/sdks/python/apache_beam/io/parquetio_test.py
@@ -286,8 +286,6 @@ def test_sink_display_data(self):
         file_name,
         self.SCHEMA,
         'none',
-        1024 * 1024,
-        1000,
         False,
         False,
         '.end',
@@ -301,7 +299,6 @@
             'file_pattern',
             'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
         DisplayDataItemMatcher('codec', 'none'),
-        DisplayDataItemMatcher('row_group_buffer_size', str(1024 * 1024)),
         DisplayDataItemMatcher('compression', 'uncompressed')
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -310,6 +307,7 @@ def test_write_display_data(self):
     file_name = 'some_parquet_sink'
     write = WriteToParquet(file_name, self.SCHEMA)
     dd = DisplayData.create_from(write)
+
     expected_items = [
         DisplayDataItemMatcher('codec', 'none'),
         DisplayDataItemMatcher('schema', str(self.SCHEMA)),
@@ -321,6 +319,21 @@
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

+  def test_write_batched_display_data(self):
+    file_name = 'some_parquet_sink'
+    write = WriteToParquetBatched(file_name, self.SCHEMA)
+    dd = DisplayData.create_from(write)
+
+    expected_items = [
+        DisplayDataItemMatcher('codec', 'none'),
+        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
+        DisplayDataItemMatcher(
+            'file_pattern',
+            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
+        DisplayDataItemMatcher('compression', 'uncompressed')
+    ]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
   def test_sink_transform_int96(self):
     with tempfile.NamedTemporaryFile() as dst:
       path = dst.name
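
Note for reviewers: a minimal usage sketch of the two write paths this patch unifies behind a single _ParquetSink (not part of the patch itself; the output paths and sample records below are made up). WriteToParquet still consumes row dictionaries, now converting them to pyarrow Tables with the new _RowDictionariesToArrowTable DoFn, while WriteToParquetBatched hands pyarrow Tables straight to the shared sink.

    import pyarrow as pa

    import apache_beam as beam
    from apache_beam.io.parquetio import WriteToParquet
    from apache_beam.io.parquetio import WriteToParquetBatched

    # Hypothetical schema and records used only for illustration.
    SCHEMA = pa.schema([('name', pa.string()), ('age', pa.int64())])

    # Row-dictionary path: dicts are buffered into Arrow record batches by
    # _RowDictionariesToArrowTable before reaching the shared sink.
    with beam.Pipeline() as p:
      _ = (
          p
          | 'CreateRows' >> beam.Create([{'name': 'a', 'age': 1},
                                         {'name': 'b', 'age': 2}])
          | 'WriteRows' >> WriteToParquet('/tmp/rows_output', SCHEMA))

    # Batched path: pyarrow Tables go straight to the same sink.
    with beam.Pipeline() as p:
      _ = (
          p
          | 'CreateTables' >> beam.Create(
              [pa.Table.from_pydict({'name': ['a', 'b'], 'age': [1, 2]},
                                    schema=SCHEMA)])
          | 'WriteTables' >> WriteToParquetBatched('/tmp/tables_output', SCHEMA))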