reuse parquet sink for writing rows and pa.Tables
easadler committed Sep 11, 2022
1 parent e0acb72 commit 67ccf9f
Showing 2 changed files with 128 additions and 158 deletions.
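
For orientation: after this change both write transforms feed pyarrow Tables into a single _ParquetSink. WriteToParquet converts row dictionaries to Tables with the new _RowDictionariesToArrowTable DoFn before writing, while WriteToParquetBatched hands Tables to the sink directly. A minimal usage sketch follows; the schema, records, and /tmp paths are illustrative assumptions, not part of this commit, and it presumes a Beam release that includes WriteToParquetBatched.

import apache_beam as beam
import pyarrow as pa
from apache_beam.io.parquetio import WriteToParquet, WriteToParquetBatched

SCHEMA = pa.schema([('name', pa.string()), ('age', pa.int64())])
ROWS = [{'name': 'a', 'age': 1}, {'name': 'b', 'age': 2}]

# Row dictionaries are batched into Arrow tables inside the transform, then written.
with beam.Pipeline() as p:
  _ = (
      p
      | 'CreateRows' >> beam.Create(ROWS)
      | 'WriteRows' >> WriteToParquet('/tmp/rows', SCHEMA))

# Pre-built pyarrow.Tables go straight to the shared sink.
with beam.Pipeline() as p:
  _ = (
      p
      | 'CreateTables' >> beam.Create(
          [pa.Table.from_pydict({'name': ['a', 'b'], 'age': [1, 2]}, schema=SCHEMA)])
      | 'WriteTables' >> WriteToParquetBatched('/tmp/tables', SCHEMA))
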
267 changes: 112 additions & 155 deletions sdks/python/apache_beam/io/parquetio.py
@@ -43,6 +43,8 @@
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms.util import BatchElements
from apache_beam.transforms import window

try:
import pyarrow as pa
@@ -84,6 +86,67 @@ def process(self, table, with_filename=False):
yield row


class _RowDictionariesToArrowTable(DoFn):
"""A DoFn that consumes python dictionaries and yields pyarrow tables."""
def __init__(
self,
schema,
row_group_buffer_size=64 * 1024 * 1024,
record_batch_size=1000):
self._schema = schema
self._row_group_buffer_size = row_group_buffer_size
self._buffer = [[] for _ in range(len(schema.names))]
self._buffer_size = record_batch_size
self._record_batches = []
self._record_batches_byte_size = 0

def process(self, row):
if len(self._buffer[0]) >= self._buffer_size:
self._flush_buffer()

if self._record_batches_byte_size >= self._row_group_buffer_size:
table = self._create_table()
yield table

# reorder the data in columnar format.
for i, n in enumerate(self._schema.names):
self._buffer[i].append(row[n])

def finish_bundle(self):
if len(self._buffer[0]) > 0:
self._flush_buffer()
if self._record_batches_byte_size > 0:
table = self._create_table()
yield window.GlobalWindows.windowed_value_at_end_of_window(table)

def display_data(self):
res = super().display_data()
res['row_group_buffer_size'] = str(self._row_group_buffer_size)
res['buffer_size'] = str(self._buffer_size)

return res

def _create_table(self):
table = pa.Table.from_batches(self._record_batches, schema=self._schema)
self._record_batches = []
self._record_batches_byte_size = 0
return table

def _flush_buffer(self):
arrays = [[] for _ in range(len(self._schema.names))]
for x, y in enumerate(self._buffer):
arrays[x] = pa.array(y, type=self._schema.types[x])
self._buffer[x] = []
rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
self._record_batches.append(rb)
size = 0
for x in arrays:
for b in x.buffers():
if b is not None:
size = size + b.size
self._record_batches_byte_size = self._record_batches_byte_size + size
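
The byte-size accounting in _flush_buffer above walks every Arrow buffer behind each column array to decide when enough data has accumulated to cut a row group. A standalone sketch of that bookkeeping; the schema and values are made up for the example, not taken from the commit:

import pyarrow as pa

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
columns = [
    pa.array(['a', 'b', 'c'], type=pa.string()),
    pa.array([1, 2, 3], type=pa.int64()),
]
batch = pa.RecordBatch.from_arrays(columns, schema=schema)

# Sum the validity/offset/data buffers of every column, skipping absent ones,
# just as _flush_buffer does before comparing against row_group_buffer_size.
size = sum(b.size for col in columns for b in col.buffers() if b is not None)
print(batch.num_rows, size)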


class ReadFromParquetBatched(PTransform):
"""A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
@@ -454,13 +517,15 @@ def __init__(
A WriteToParquet transform usable for writing.
"""
super().__init__()
self._schema = schema
self._row_group_buffer_size = row_group_buffer_size
self._record_batch_size = record_batch_size

self._sink = \
_create_parquet_sink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -470,149 +535,16 @@
)

def expand(self, pcoll):
return pcoll | Write(self._sink)

def display_data(self):
return {'sink_dd': self._sink}


def _create_parquet_sink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type):
return \
_ParquetSink(
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type
)


class _ParquetSink(filebasedsink.FileBasedSink):
"""A sink for parquet files."""
def __init__(
self,
file_path_prefix,
schema,
codec,
row_group_buffer_size,
record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type):
super().__init__(
file_path_prefix,
file_name_suffix=file_name_suffix,
num_shards=num_shards,
shard_name_template=shard_name_template,
coder=None,
mime_type=mime_type,
# Compression happens at the block level using the supplied codec, and
# not at the file level.
compression_type=CompressionTypes.UNCOMPRESSED)
self._schema = schema
self._codec = codec
if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4":
raise ValueError(
"Due to ARROW-9424, writing with LZ4 compression is not supported in "
"pyarrow 1.x, please use a different pyarrow version or a different "
f"codec. Your pyarrow version: {pa.__version__}")
self._row_group_buffer_size = row_group_buffer_size
self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
raise ValueError(
"With ARROW-11497, use_compliant_nested_type is only supported in "
"pyarrow version >= 4.x, please use a different pyarrow version. "
f"Your pyarrow version: {pa.__version__}")
self._use_compliant_nested_type = use_compliant_nested_type
self._buffer = [[] for _ in range(len(schema.names))]
self._buffer_size = record_batch_size
self._record_batches = []
self._record_batches_byte_size = 0
self._file_handle = None

def open(self, temp_path):
self._file_handle = super().open(temp_path)
if ARROW_MAJOR_VERSION < 4:
return pq.ParquetWriter(
self._file_handle,
self._schema,
compression=self._codec,
use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps)
return pq.ParquetWriter(
self._file_handle,
self._schema,
compression=self._codec,
use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
use_compliant_nested_type=self._use_compliant_nested_type)

def write_record(self, writer, value):
if len(self._buffer[0]) >= self._buffer_size:
self._flush_buffer()

if self._record_batches_byte_size >= self._row_group_buffer_size:
self._write_batches(writer)

# reorder the data in columnar format.
for i, n in enumerate(self._schema.names):
self._buffer[i].append(value[n])

def close(self, writer):
if len(self._buffer[0]) > 0:
self._flush_buffer()
if self._record_batches_byte_size > 0:
self._write_batches(writer)

writer.close()
if self._file_handle:
self._file_handle.close()
self._file_handle = None
return pcoll | ParDo(
_RowDictionariesToArrowTable(
self._schema, self._row_group_buffer_size,
self._record_batch_size)) | Write(self._sink)

def display_data(self):
res = super().display_data()
res['codec'] = str(self._codec)
res['schema'] = str(self._schema)
res['row_group_buffer_size'] = str(self._row_group_buffer_size)
return res

def _write_batches(self, writer):
table = pa.Table.from_batches(self._record_batches, schema=self._schema)
self._record_batches = []
self._record_batches_byte_size = 0
writer.write_table(table)

def _flush_buffer(self):
arrays = [[] for _ in range(len(self._schema.names))]
for x, y in enumerate(self._buffer):
arrays[x] = pa.array(y, type=self._schema.types[x])
self._buffer[x] = []
rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
self._record_batches.append(rb)
size = 0
for x in arrays:
for b in x.buffers():
if b is not None:
size = size + b.size
self._record_batches_byte_size = self._record_batches_byte_size + size
return {
'sink_dd': self._sink,
'row_group_buffer_size': str(self._row_group_buffer_size)
}


class WriteToParquetBatched(PTransform):
@@ -634,17 +566,18 @@
mime_type='application/x-parquet',
):
super().__init__()
self._sink = _BatchedParquetSink(
file_path_prefix=file_path_prefix,
schema=schema,
codec=codec,
use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
use_compliant_nested_type=use_compliant_nested_type,
file_name_suffix=file_name_suffix,
num_shards=num_shards,
shard_name_template=shard_name_template,
mime_type=mime_type,
)
self._sink = \
_create_parquet_sink(
file_path_prefix,
schema,
codec,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type
)

def expand(self, pcoll):
return pcoll | Write(self._sink)
@@ -653,7 +586,31 @@ def display_data(self):
return {'sink_dd': self._sink}


class _BatchedParquetSink(filebasedsink.FileBasedSink):
def _create_parquet_sink(
file_path_prefix,
schema,
codec,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type):
return \
_ParquetSink(
file_path_prefix,
schema,
codec,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
num_shards,
shard_name_template,
mime_type
)


class _ParquetSink(filebasedsink.FileBasedSink):
"""A sink that writes pyarrow Tables to parquet files."""
def __init__(
self,
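
Because row batching now happens upstream in _RowDictionariesToArrowTable, the sink no longer takes row_group_buffer_size or record_batch_size. A rough sketch of the new call shape, mirroring the updated test in parquetio_test.py below; the trailing arguments are inferred from the _create_parquet_sink signature above and are placeholders, not copied from the collapsed part of the diff:

import pyarrow as pa
from apache_beam.io.parquetio import _create_parquet_sink

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
sink = _create_parquet_sink(
    'some_parquet_sink',      # file_path_prefix
    schema,
    'none',                   # codec
    False,                    # use_deprecated_int96_timestamps
    False,                    # use_compliant_nested_type
    '.end',                   # file_name_suffix
    0,                        # num_shards
    None,                     # shard_name_template
    'application/x-parquet')  # mime_type
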
19 changes: 16 additions & 3 deletions sdks/python/apache_beam/io/parquetio_test.py
@@ -286,8 +286,6 @@ def test_sink_display_data(self):
file_name,
self.SCHEMA,
'none',
1024 * 1024,
1000,
False,
False,
'.end',
@@ -301,7 +299,6 @@
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('row_group_buffer_size', str(1024 * 1024)),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -310,6 +307,7 @@ def test_write_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquet(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)

expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
@@ -321,6 +319,21 @@
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

def test_write_batched_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquetBatched(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)

expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

def test_sink_transform_int96(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
