Add WriteParquetBatched (#23030)
* initial draft

* reuse parquet sink for writing rows and pa.Tables

* lint

* fix imports

* cr - rename value to table

* add doc string and example test

* fix doc strings

* specify doctest group to separate tests

Co-authored-by: Evan Sadler <easadler@gmail.com>
peridotml and easadler authored Nov 1, 2022
1 parent 09cab57 commit ef7c0c9
Showing 2 changed files with 216 additions and 54 deletions.
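For orientation, here is a minimal sketch of how the new batched transform is meant to sit next to the existing row-based WriteToParquet after this change. The output prefixes and step labels are hypothetical, the schema mirrors the docstring example further down, and this snippet is not part of the diff:

import apache_beam as beam
import pyarrow as pa

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])

# Existing path: one Python dict per element; the transform batches rows into
# pyarrow Tables internally before handing them to the shared parquet sink.
with beam.Pipeline() as p:
    _ = (
        p
        | 'CreateRows' >> beam.Create([{'name': 'foo', 'age': 10}])
        | 'WriteRows' >> beam.io.WriteToParquet('/tmp/rows', schema))

# New path: one pyarrow.Table per element, forwarded to the sink as-is.
with beam.Pipeline() as p:
    table = pa.Table.from_pylist([{'name': 'foo', 'age': 10},
                                  {'name': 'bar', 'age': 20}])
    _ = (
        p
        | 'CreateTables' >> beam.Create([table])
        | 'WriteTables' >> beam.io.WriteToParquetBatched('/tmp/tables', schema))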
234 changes: 183 additions & 51 deletions sdks/python/apache_beam/io/parquetio.py
@@ -43,6 +43,7 @@
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms import window

try:
import pyarrow as pa
Expand All @@ -60,7 +61,8 @@
'ReadAllFromParquet',
'ReadFromParquetBatched',
'ReadAllFromParquetBatched',
'WriteToParquet'
'WriteToParquet',
'WriteToParquetBatched'
]


@@ -83,6 +85,67 @@ def process(self, table, with_filename=False):
yield row


class _RowDictionariesToArrowTable(DoFn):
""" A DoFn that consumes python dictionarys and yields a pyarrow table."""
def __init__(
self,
schema,
row_group_buffer_size=64 * 1024 * 1024,
record_batch_size=1000):
self._schema = schema
self._row_group_buffer_size = row_group_buffer_size
self._buffer = [[] for _ in range(len(schema.names))]
self._buffer_size = record_batch_size
self._record_batches = []
self._record_batches_byte_size = 0

def process(self, row):
if len(self._buffer[0]) >= self._buffer_size:
self._flush_buffer()

if self._record_batches_byte_size >= self._row_group_buffer_size:
table = self._create_table()
yield table

# reorder the data in columnar format.
for i, n in enumerate(self._schema.names):
self._buffer[i].append(row[n])

def finish_bundle(self):
if len(self._buffer[0]) > 0:
self._flush_buffer()
if self._record_batches_byte_size > 0:
table = self._create_table()
yield window.GlobalWindows.windowed_value_at_end_of_window(table)

def display_data(self):
res = super().display_data()
res['row_group_buffer_size'] = str(self._row_group_buffer_size)
res['buffer_size'] = str(self._buffer_size)

return res

def _create_table(self):
table = pa.Table.from_batches(self._record_batches, schema=self._schema)
self._record_batches = []
self._record_batches_byte_size = 0
return table

def _flush_buffer(self):
arrays = [[] for _ in range(len(self._schema.names))]
for x, y in enumerate(self._buffer):
arrays[x] = pa.array(y, type=self._schema.types[x])
self._buffer[x] = []
rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
self._record_batches.append(rb)
size = 0
for x in arrays:
for b in x.buffers():
if b is not None:
size = size + b.size
self._record_batches_byte_size = self._record_batches_byte_size + size
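The DoFn above reorders incoming dict rows into per-column buffers, flushes them into pyarrow RecordBatches, and assembles those batches into a Table once the row-group size threshold is reached. A minimal standalone sketch of that conversion, using an illustrative two-column schema (not part of the diff):

import pyarrow as pa

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
rows = [{'name': 'foo', 'age': 10}, {'name': 'bar', 'age': 20}]

# Reorder row dicts into per-column lists, as _flush_buffer does.
columns = [[row[name] for row in rows] for name in schema.names]
arrays = [pa.array(col, type=typ) for col, typ in zip(columns, schema.types)]

# One flush yields a RecordBatch; accumulated batches become the Table that
# _create_table emits downstream to the parquet sink.
batch = pa.RecordBatch.from_arrays(arrays, schema=schema)
table = pa.Table.from_batches([batch], schema=schema)
assert table.num_rows == 2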


class ReadFromParquetBatched(PTransform):
"""A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
@@ -453,13 +516,127 @@ def __init__(
A WriteToParquet transform usable for writing.
"""
super().__init__()
self._schema = schema
self._row_group_buffer_size = row_group_buffer_size
self._record_batch_size = record_batch_size

self._sink = \
_create_parquet_sink(
file_path_prefix,
schema,
codec,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type
)

def expand(self, pcoll):
return pcoll | ParDo(
_RowDictionariesToArrowTable(
self._schema, self._row_group_buffer_size,
self._record_batch_size)) | Write(self._sink)

def display_data(self):
return {
'sink_dd': self._sink,
'row_group_buffer_size': str(self._row_group_buffer_size)
}


class WriteToParquetBatched(PTransform):
"""A ``PTransform`` for writing parquet files from a `PCollection` of
`pyarrow.Table`.
This ``PTransform`` is currently experimental. No backward-compatibility
guarantees.
"""
def __init__(
self,
file_path_prefix,
schema=None,
codec='none',
use_deprecated_int96_timestamps=False,
use_compliant_nested_type=False,
file_name_suffix='',
num_shards=0,
shard_name_template=None,
mime_type='application/x-parquet',
):
"""Initialize a WriteToParquetBatched transform.
Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
records. Each record is a ``pa.Table``. The schema must be specified, as in
the example below.
.. testsetup:: batched
from tempfile import NamedTemporaryFile
import glob
import os
import pyarrow
filename = NamedTemporaryFile(delete=False).name
.. testcode:: batched
table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10},
{'name': 'bar', 'age': 20}])
with beam.Pipeline() as p:
records = p | 'Read' >> beam.Create([table])
_ = records | 'Write' >> beam.io.WriteToParquetBatched(filename,
pyarrow.schema(
[('name', pyarrow.string()), ('age', pyarrow.int64())]
)
)
.. testcleanup:: batched
for output in glob.glob('{}*'.format(filename)):
os.remove(output)
For more information on supported types and schemas, please see the pyarrow
documentation.
Args:
file_path_prefix: The file path to write to. The files written will begin
with this prefix, followed by a shard identifier (see num_shards), and
end in a common extension, if given by file_name_suffix. In most cases,
only this argument is specified and num_shards, shard_name_template, and
file_name_suffix use default values.
schema: The schema to use, as type of ``pyarrow.Schema``.
codec: The codec to use for block-level compression. Any string supported
by the pyarrow specification is accepted.
use_deprecated_int96_timestamps: Write nanosecond resolution timestamps to
INT96 Parquet format. Defaults to False.
use_compliant_nested_type: Write compliant Parquet nested type (lists).
file_name_suffix: Suffix for the files written.
num_shards: The number of files (shards) used for output. If not set, the
service will decide on the optimal number of shards.
Constraining the number of shards is likely to reduce
the performance of a pipeline. Setting this value is not recommended
unless you require a specific number of output files.
shard_name_template: A template string containing placeholders for
the shard number and shard count. When constructing a filename for a
particular shard number, the upper-case letters 'S' and 'N' are
replaced with the 0-padded shard number and shard count respectively.
This argument can be '' in which case it behaves as if num_shards was
set to 1 and only one file will be generated. The default pattern used
is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
mime_type: The MIME type to use for the produced files, if the filesystem
supports specifying MIME types.
Returns:
A WriteToParquetBatched transform usable for writing.
"""
super().__init__()
self._sink = \
_create_parquet_sink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -479,8 +656,6 @@ def _create_parquet_sink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -492,8 +667,6 @@ def _create_parquet_sink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -504,14 +677,12 @@


class _ParquetSink(filebasedsink.FileBasedSink):
"""A sink for parquet files."""
"""A sink for parquet files from batches."""
def __init__(
self,
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -535,18 +706,13 @@ def __init__(
"Due to ARROW-9424, writing with LZ4 compression is not supported in "
"pyarrow 1.x, please use a different pyarrow version or a different "
f"codec. Your pyarrow version: {pa.__version__}")
self._row_group_buffer_size = row_group_buffer_size
self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
raise ValueError(
"With ARROW-11497, use_compliant_nested_type is only supported in "
"pyarrow version >= 4.x, please use a different pyarrow version. "
f"Your pyarrow version: {pa.__version__}")
self._use_compliant_nested_type = use_compliant_nested_type
self._buffer = [[] for _ in range(len(schema.names))]
self._buffer_size = record_batch_size
self._record_batches = []
self._record_batches_byte_size = 0
self._file_handle = None

def open(self, temp_path):
@@ -564,23 +730,10 @@ def open(self, temp_path):
use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
use_compliant_nested_type=self._use_compliant_nested_type)

def write_record(self, writer, value):
if len(self._buffer[0]) >= self._buffer_size:
self._flush_buffer()

if self._record_batches_byte_size >= self._row_group_buffer_size:
self._write_batches(writer)

# reorder the data in columnar format.
for i, n in enumerate(self._schema.names):
self._buffer[i].append(value[n])
def write_record(self, writer, table: pa.Table):
writer.write_table(table)

def close(self, writer):
if len(self._buffer[0]) > 0:
self._flush_buffer()
if self._record_batches_byte_size > 0:
self._write_batches(writer)

writer.close()
if self._file_handle:
self._file_handle.close()
@@ -590,25 +743,4 @@ def display_data(self):
res = super().display_data()
res['codec'] = str(self._codec)
res['schema'] = str(self._schema)
res['row_group_buffer_size'] = str(self._row_group_buffer_size)
return res

def _write_batches(self, writer):
table = pa.Table.from_batches(self._record_batches, schema=self._schema)
self._record_batches = []
self._record_batches_byte_size = 0
writer.write_table(table)

def _flush_buffer(self):
arrays = [[] for _ in range(len(self._schema.names))]
for x, y in enumerate(self._buffer):
arrays[x] = pa.array(y, type=self._schema.types[x])
self._buffer[x] = []
rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
self._record_batches.append(rb)
size = 0
for x in arrays:
for b in x.buffers():
if b is not None:
size = size + b.size
self._record_batches_byte_size = self._record_batches_byte_size + size
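With row batching moved into the DoFn, the sink's write path collapses to pyarrow's own writer: open creates a pq.ParquetWriter over the Beam-managed file handle, and write_record simply forwards each table. A standalone sketch of that behaviour, using a hypothetical local path outside Beam (not part of the diff):

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
table = pa.Table.from_pylist([{'name': 'foo', 'age': 10}], schema=schema)

# Equivalent of _ParquetSink.open followed by write_record: each incoming
# pyarrow.Table is written as-is via ParquetWriter.write_table.
with pq.ParquetWriter('/tmp/example.parquet', schema) as writer:
    writer.write_table(table)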
36 changes: 33 additions & 3 deletions sdks/python/apache_beam/io/parquetio_test.py
@@ -40,6 +40,7 @@
from apache_beam.io.parquetio import ReadFromParquet
from apache_beam.io.parquetio import ReadFromParquetBatched
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.io.parquetio import WriteToParquetBatched
from apache_beam.io.parquetio import _create_parquet_sink
from apache_beam.io.parquetio import _create_parquet_source
from apache_beam.testing.test_pipeline import TestPipeline
@@ -284,8 +285,6 @@ def test_sink_display_data(self):
file_name,
self.SCHEMA,
'none',
1024 * 1024,
1000,
False,
False,
'.end',
@@ -299,7 +298,6 @@
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('row_group_buffer_size', str(1024 * 1024)),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -308,6 +306,7 @@ def test_write_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquet(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)

expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
@@ -319,6 +318,21 @@
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

def test_write_batched_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquetBatched(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)

expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

def test_sink_transform_int96(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
@@ -348,6 +362,22 @@ def test_sink_transform(self):
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

def test_sink_transform_batched(self):
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname + "tmp_filename")
with TestPipeline() as p:
_ = p \
| Create([self._records_as_arrow()]) \
| WriteToParquetBatched(
path, self.SCHEMA, num_shards=1, shard_name_template='')
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquet(path) \
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
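The test above writes tables and reads rows back through ReadFromParquet for comparison; since ReadFromParquetBatched yields pyarrow.Table elements, a fully batched round trip is also possible. A hedged sketch with a hypothetical file pattern (not part of the diff):

import apache_beam as beam
from apache_beam.io.parquetio import ReadFromParquetBatched

with beam.Pipeline() as p:
    # Each element is a pyarrow.Table (roughly one per row group), so a
    # pipeline can stay columnar from read to write.
    _ = (
        p
        | ReadFromParquetBatched('/tmp/tables*')
        | beam.Map(lambda table: table.num_rows)
        | beam.Map(print))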

def test_sink_transform_compliant_nested_type(self):
if ARROW_MAJOR_VERSION < 4:
return unittest.skip(
