diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py
index acbf1e23f2034..c5c83dcb81e8e 100644
--- a/sdks/python/apache_beam/io/parquetio.py
+++ b/sdks/python/apache_beam/io/parquetio.py
@@ -60,7 +60,8 @@
     'ReadAllFromParquet',
     'ReadFromParquetBatched',
     'ReadAllFromParquetBatched',
-    'WriteToParquet'
+    'WriteToParquet',
+    'WriteToParquetBatched'
 ]


@@ -612,3 +613,127 @@ def _flush_buffer(self):
         if b is not None:
           size = size + b.size
     self._record_batches_byte_size = self._record_batches_byte_size + size
+
+
+class WriteToParquetBatched(PTransform):
+  """A ``PTransform`` for writing parquet files from a
+  :class:`~apache_beam.pvalue.PCollection` of batches, where each batch is a
+  pa.Table. The schema must be specified, as in the example below.
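+
+  A minimal sketch of the intended usage (the field names, values, and
+  output path here are illustrative only)::
+
+    import pyarrow
+    import apache_beam as beam
+
+    with beam.Pipeline() as p:
+      table = pyarrow.Table.from_pydict({'name': ['foo', 'bar'],
+                                         'age': [10, 20]})
+      _ = p \
+          | beam.Create([table]) \
+          | beam.io.WriteToParquetBatched(
+              '/path/to/output',
+              pyarrow.schema([('name', pyarrow.string()),
+                              ('age', pyarrow.int64())]))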
" + f"Your pyarrow version: {pa.__version__}") + self._use_compliant_nested_type = use_compliant_nested_type + self._file_handle = None + + def open(self, temp_path): + self._file_handle = super().open(temp_path) + if ARROW_MAJOR_VERSION < 4: + return pq.ParquetWriter( + self._file_handle, + self._schema, + compression=self._codec, + use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps) + return pq.ParquetWriter( + self._file_handle, + self._schema, + compression=self._codec, + use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps, + use_compliant_nested_type=self._use_compliant_nested_type) + + def write_record(self, writer, value): + writer.write_table(value) + + def close(self, writer): + writer.close() + if self._file_handle: + self._file_handle.close() + self._file_handle = None + + def display_data(self): + res = super().display_data() + res['codec'] = str(self._codec) + res['schema'] = str(self._schema) + return res diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py index 454a45493c4af..065db351d086f 100644 --- a/sdks/python/apache_beam/io/parquetio_test.py +++ b/sdks/python/apache_beam/io/parquetio_test.py @@ -40,6 +40,8 @@ from apache_beam.io.parquetio import ReadFromParquet from apache_beam.io.parquetio import ReadFromParquetBatched from apache_beam.io.parquetio import WriteToParquet +from apache_beam.io.parquetio import WriteToParquetBatched + from apache_beam.io.parquetio import _create_parquet_sink from apache_beam.io.parquetio import _create_parquet_source from apache_beam.testing.test_pipeline import TestPipeline @@ -348,6 +350,22 @@ def test_sink_transform(self): | Map(json.dumps) assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS])) + def test_sink_transform_batched(self): + with TemporaryDirectory() as tmp_dirname: + path = os.path.join(tmp_dirname + "tmp_filename") + with TestPipeline() as p: + _ = p \ + | Create([self._records_as_arrow()]) \ + | WriteToParquetBatched( + path, self.SCHEMA, num_shards=1, shard_name_template='') + with TestPipeline() as p: + # json used for stable sortability + readback = \ + p \ + | ReadFromParquet(path) \ + | Map(json.dumps) + assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS])) + def test_sink_transform_compliant_nested_type(self): if ARROW_MAJOR_VERSION < 4: return unittest.skip(