
Commit

initial draft
easadler committed Sep 5, 2022
1 parent 31561e2 commit e0acb72
Showing 2 changed files with 128 additions and 1 deletion.
111 changes: 110 additions & 1 deletion sdks/python/apache_beam/io/parquetio.py
@@ -60,7 +60,8 @@
'ReadAllFromParquet',
'ReadFromParquetBatched',
'ReadAllFromParquetBatched',
'WriteToParquet'
'WriteToParquet',
'WriteToParquetBatched'
]


@@ -612,3 +613,111 @@ def _flush_buffer(self):
        if b is not None:
          size = size + b.size
    self._record_batches_byte_size = self._record_batches_byte_size + size


class WriteToParquetBatched(PTransform):
"""Initialize a WriteToParquetBatched transform.
Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
batches. Each batch is a pa.Table. Schema must be specified like the example below.
"""
  def __init__(
      self,
      file_path_prefix,
      schema=None,
      codec='none',
      use_deprecated_int96_timestamps=False,
      use_compliant_nested_type=False,
      file_name_suffix='',
      num_shards=0,
      shard_name_template=None,
      mime_type='application/x-parquet',
  ):
    super().__init__()
    self._sink = _BatchedParquetSink(
        file_path_prefix=file_path_prefix,
        schema=schema,
        codec=codec,
        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
        use_compliant_nested_type=use_compliant_nested_type,
        file_name_suffix=file_name_suffix,
        num_shards=num_shards,
        shard_name_template=shard_name_template,
        mime_type=mime_type,
    )

  def expand(self, pcoll):
    return pcoll | Write(self._sink)

  def display_data(self):
    return {'sink_dd': self._sink}
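
Not part of the committed file: a minimal usage sketch for the new transform, assuming pyarrow is installed, a writable output prefix /tmp/output, and made-up column names and values for illustration.

import apache_beam as beam
import pyarrow as pa

from apache_beam.io.parquetio import WriteToParquetBatched

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
table = pa.Table.from_pydict(
    {'name': ['alice', 'bob'], 'age': [30, 25]}, schema=schema)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([table])  # the PCollection holds pa.Table batches
      | WriteToParquetBatched(
          '/tmp/output', schema, file_name_suffix='.parquet'))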


class _BatchedParquetSink(filebasedsink.FileBasedSink):
"""A sink for parquet files from batches."""
  def __init__(
      self,
      file_path_prefix,
      schema,
      codec,
      use_deprecated_int96_timestamps,
      use_compliant_nested_type,
      file_name_suffix,
      num_shards,
      shard_name_template,
      mime_type):
    super().__init__(
        file_path_prefix,
        file_name_suffix=file_name_suffix,
        num_shards=num_shards,
        shard_name_template=shard_name_template,
        coder=None,
        mime_type=mime_type,
        # Compression happens at the block level using the supplied codec, and
        # not at the file level.
        compression_type=CompressionTypes.UNCOMPRESSED)
    self._schema = schema
    self._codec = codec
    if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4":
      raise ValueError(
          "Due to ARROW-9424, writing with LZ4 compression is not supported in "
          "pyarrow 1.x, please use a different pyarrow version or a different "
          f"codec. Your pyarrow version: {pa.__version__}")
    self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
    if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
      raise ValueError(
          "With ARROW-11497, use_compliant_nested_type is only supported in "
          "pyarrow version >= 4.x, please use a different pyarrow version. "
          f"Your pyarrow version: {pa.__version__}")
    self._use_compliant_nested_type = use_compliant_nested_type
    self._file_handle = None

  def open(self, temp_path):
    self._file_handle = super().open(temp_path)
    if ARROW_MAJOR_VERSION < 4:
      return pq.ParquetWriter(
          self._file_handle,
          self._schema,
          compression=self._codec,
          use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps)
    return pq.ParquetWriter(
        self._file_handle,
        self._schema,
        compression=self._codec,
        use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
        use_compliant_nested_type=self._use_compliant_nested_type)

  def write_record(self, writer, value):
    writer.write_table(value)

  def close(self, writer):
    writer.close()
    if self._file_handle:
      self._file_handle.close()
      self._file_handle = None

  def display_data(self):
    res = super().display_data()
    res['codec'] = str(self._codec)
    res['schema'] = str(self._schema)
    return res
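
For orientation (not in the diff): per shard, the sink above boils down to the following plain pyarrow sequence; the file path and data are placeholders, and Beam's temp-file handling through FileBasedSink is elided.

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
batches = [pa.Table.from_pydict({'name': ['alice'], 'age': [30]}, schema=schema)]

with open('/tmp/shard-00000.parquet', 'wb') as fh:  # open(temp_path)
  writer = pq.ParquetWriter(fh, schema, compression='none')
  for table in batches:
    writer.write_table(table)  # write_record(writer, value)
  writer.close()  # close(writer)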
18 changes: 18 additions & 0 deletions sdks/python/apache_beam/io/parquetio_test.py
@@ -40,6 +40,8 @@
from apache_beam.io.parquetio import ReadFromParquet
from apache_beam.io.parquetio import ReadFromParquetBatched
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.io.parquetio import WriteToParquetBatched

from apache_beam.io.parquetio import _create_parquet_sink
from apache_beam.io.parquetio import _create_parquet_source
from apache_beam.testing.test_pipeline import TestPipeline
@@ -348,6 +350,22 @@ def test_sink_transform(self):
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_batched(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname, "tmp_filename")
      with TestPipeline() as p:
        _ = p \
            | Create([self._records_as_arrow()]) \
            | WriteToParquetBatched(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_compliant_nested_type(self):
    if ARROW_MAJOR_VERSION < 4:
      return unittest.skip(
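The batched test relies on a _records_as_arrow helper defined elsewhere in parquetio_test.py and not shown in this diff. A plausible standalone equivalent, with illustrative field names and values standing in for the suite's RECORDS/SCHEMA fixtures:

import pyarrow as pa

RECORDS = [
    {'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'},
    {'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'},
]
SCHEMA = pa.schema([
    ('name', pa.string()),
    ('favorite_number', pa.int64()),
    ('favorite_color', pa.string()),
])

def records_as_arrow(records, schema):
  # Pivot the row dicts into per-column lists and build one pa.Table,
  # which is the element type WriteToParquetBatched expects.
  columns = {name: [r[name] for r in records] for name in schema.names}
  return pa.Table.from_pydict(columns, schema=schema)

table = records_as_arrow(RECORDS, SCHEMA)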
