GH-26685: [Python] use IPC for pickle serialisation #37683

Closed
wants to merge 5 commits
32 changes: 26 additions & 6 deletions python/pyarrow/array.pxi
@@ -674,12 +674,18 @@ cdef shared_ptr[CArrayData] _reconstruct_array_data(data):
offset)


def _restore_array(data):
def _restore_array(buffer):
"""
Reconstruct an Array from pickled ArrayData.
Restore an IPC-serialized Arrow Array.

Workaround for the issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
cdef shared_ptr[CArrayData] ad = _reconstruct_array_data(data)
return pyarrow_wrap_array(MakeArray(ad))
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_next_batch().column(0)


cdef class _PandasConvertible(_Weakrefable):
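The new _restore_array opens the pickled payload as an Arrow IPC stream and takes the first column of the first batch; the matching writer side lands in Array.__reduce__ in the next hunk. A minimal sketch of the full round trip against the public pyarrow API (illustrative only, not part of the patch):

import pyarrow as pa

arr = pa.array([1, 2, 3, 4, 5])

# Writer side: wrap the Array in a single-column RecordBatch and
# serialize it to an in-memory IPC stream.
batch = pa.RecordBatch.from_arrays([arr], [''])
sink = pa.BufferOutputStream()
with pa.ipc.RecordBatchStreamWriter(sink, schema=batch.schema) as writer:
    writer.write_batch(batch)
buf = sink.getvalue()

# Reader side: the equivalent of what _restore_array does with the buffer.
with pa.ipc.RecordBatchStreamReader(buf) as reader:
    restored = reader.read_next_batch().column(0)

assert restored.equals(arr)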
@@ -1100,8 +1106,22 @@ cdef class Array(_PandasConvertible):
memory_pool=memory_pool)

def __reduce__(self):
return _restore_array, \
(_reduce_array_data(self.sp_array.get().data().get()),)
"""
Use Arrow IPC format for serialization.

Workaround for the issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

batch = RecordBatch.from_arrays([self], [''])
sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=batch.schema) as writer:
writer.write_batch(batch)

return _restore_array, (sink.getvalue(),)

@staticmethod
def from_buffers(DataType type, length, buffers, null_count=-1, offset=0,
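The user-visible effect of the array.pxi change: pickling a slice no longer ships the parent buffers. A quick check along the lines of the new test below, with indicative rather than exact sizes (illustrative only, not part of the patch):

import pickle
import pyarrow as pa

arr = pa.array(range(100_000), type=pa.int64())
sliced = arr.slice(10, 2)

whole = len(pickle.dumps(arr))
small = len(pickle.dumps(sliced))

# Before this patch both payloads were roughly the same size, because
# the slice carried the full 100_000-element parent buffer. With
# IPC-based pickling the sliced payload stays near the fixed IPC
# overhead of a few hundred bytes.
assert small < whole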
91 changes: 79 additions & 12 deletions python/pyarrow/table.pxi
@@ -16,6 +16,7 @@
# under the License.

import warnings
import functools


cdef class ChunkedArray(_PandasConvertible):
@@ -67,7 +68,21 @@ cdef class ChunkedArray(_PandasConvertible):
self.chunked_array = chunked_array.get()

def __reduce__(self):
return chunked_array, (self.chunks, self.type)
"""
Use Arrow IPC format for serialization.

Workaround for the issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685

Adds ~230 extra bytes to the pickled payload per Array chunk.
"""
import pyarrow as pa

# IPC serialization requires wrapping the chunks in a single-column Table
table = pa.Table.from_arrays([self], names=[""])
reconstruct_table, serialised = table.__reduce__()
return functools.partial(_reconstruct_chunked_array, reconstruct_table), serialised

@property
def data(self):
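For ChunkedArray the reduction is delegated to Table.__reduce__: the chunks are wrapped in a single-column Table, and the stored callable is a functools.partial that applies _reconstruct_chunked_array on top of the table reconstructor. A sketch of the user-facing round trip (illustrative only, not part of the patch):

import pickle
import pyarrow as pa

chunked = pa.chunked_array([[1, 2, 3], [4, 5, 6]])
sliced = chunked.slice(2, 2)   # zero-copy view spanning both chunks

restored = pickle.loads(pickle.dumps(sliced))

# The reconstructed object is a ChunkedArray again, not a Table.
assert isinstance(restored, pa.ChunkedArray)
assert restored.equals(sliced)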
@@ -1391,6 +1406,17 @@ def chunked_array(arrays, type=None):
return pyarrow_wrap_chunked_array(c_result)


def _reconstruct_chunked_array(restore_table, buffer):
"""
Restore an IPC-serialized ChunkedArray.

Workaround for the issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return restore_table(buffer).column(0)


cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema):
cdef:
Py_ssize_t K = len(arrays)
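Because pickle only stores a (callable, arguments) pair, the composition above can be exercised directly, without going through pickle; this is essentially what pickle.loads ends up calling (illustrative only, not part of the patch):

import pyarrow as pa

chunked = pa.chunked_array([[1, 2], [3, 4]])

# __reduce__ hands back the reconstruction callable and its arguments:
# here a functools.partial over _reconstruct_chunked_array and the
# IPC buffer produced by the wrapping Table.
reconstruct, args = chunked.__reduce__()

rebuilt = reconstruct(*args)
assert rebuilt.equals(chunked)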
@@ -2196,7 +2222,21 @@ cdef class RecordBatch(_Tabular):
return self.batch != NULL

def __reduce__(self):
return _reconstruct_record_batch, (self.columns, self.schema)
"""
Use Arrow IPC format for serialization.

Workaround for the issue where pickling a sliced RecordBatch
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=self.schema) as writer:
writer.write_batch(self)

return _reconstruct_record_batch, (sink.getvalue(),)

def validate(self, *, full=False):
"""
@@ -2984,11 +3024,18 @@ cdef class RecordBatch(_Tabular):
return pyarrow_wrap_batch(c_batch)


def _reconstruct_record_batch(columns, schema):
def _reconstruct_record_batch(buffer):
"""
Internal: reconstruct RecordBatch from pickled components.
Restore an IPC-serialized Arrow RecordBatch.

Workaround for the issue where pickling a sliced RecordBatch
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return RecordBatch.from_arrays(columns, schema=schema)
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_next_batch()


def table_to_blocks(options, Table table, categories, extension_columns):
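The RecordBatch pair mirrors the Array one, minus the single-column wrapping, since a batch already carries its own schema. A round-trip sketch against the public API (illustrative only, not part of the patch):

import pickle
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({
    'x': list(range(1000)),
    'y': [str(i) for i in range(1000)],
})
sliced = batch.slice(5, 3)

restored = pickle.loads(pickle.dumps(sliced))

# The restored batch is re-read from an IPC stream, so its buffers hold
# only the three sliced rows instead of the full parent columns.
assert restored.equals(sliced)
assert restored.num_rows == 3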
@@ -3170,10 +3217,23 @@ cdef class Table(_Tabular):
check_status(self.table.Validate())

def __reduce__(self):
# Reduce the columns as ChunkedArrays to avoid serializing schema
# data twice
columns = [col for col in self.columns]
return _reconstruct_table, (columns, self.schema)
"""
Use Arrow IPC format for serialization.

Workaround for the issue where pickling a sliced Table
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685

Adds ~230 extra bytes to the pickled payload per Array chunk.
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=self.schema) as writer:
writer.write_table(self)

return _reconstruct_table, (sink.getvalue(),)

def slice(self, offset=0, length=None):
"""
@@ -4754,11 +4814,18 @@
)


def _reconstruct_table(arrays, schema):
def _reconstruct_table(buffer):
"""
Internal: reconstruct pa.Table from pickled components.
Restore an IPC-serialized Arrow Table.

Workaround for the issue where pickling a sliced Table
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return Table.from_arrays(arrays, schema=schema)
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_all()


def record_batch(data, names=None, schema=None, metadata=None):
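Note that _reconstruct_table uses read_all() rather than read_next_batch(): write_table writes the table chunk-wise as a sequence of record batches, so a Table whose columns have several chunks comes back as several batches, and read_all() stitches them into a single Table again. A round-trip sketch (illustrative only, not part of the patch):

import pickle
import pyarrow as pa

table = pa.table({
    'a': list(range(1000)),
    'b': [str(i) for i in range(1000)],
})
sliced = table.slice(100, 5)

restored = pickle.loads(pickle.dumps(sliced))

assert restored.equals(sliced)
# The schema, including any metadata, survives the IPC round trip.
assert restored.schema.equals(sliced.schema)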
56 changes: 56 additions & 0 deletions python/pyarrow/tests/test_array.py
@@ -2001,6 +2001,7 @@ def test_cast_identities(ty, values):
)



@pickle_test_parametrize
def test_array_pickle(data, typ, pickle_module):
# Allocate here so that we don't have any Arrow data allocated.
@@ -2050,6 +2051,61 @@ def test_array_pickle_protocol5(data, typ, pickle_module):
assert result_addresses == addresses


@pytest.mark.parametrize(
('data', 'typ'),
[
# Int array
(list(range(999)) + [None], pa.int64()),
# Float array
(list(map(float, range(999))) + [None], pa.float64()),
# Boolean array
([True, False, None, True] * 250, pa.bool_()),
# String array
(['a', 'b', 'cd', None, 'efg'] * 200, pa.string()),
# List array
([[1, 2], [3], [None, 4, 5], [6]] * 250, pa.list_(pa.int64())),
# Large list array
(
[[4, 5], [6], [None, 7], [8, 9, 10]] * 250,
pa.large_list(pa.int16())
),
# String list array
(
[['a'], None, ['b', 'cd'], ['efg']] * 250,
pa.list_(pa.string())
),
# Struct array
(
[(1, 'a'), (2, 'c'), None, (3, 'b')] * 250,
pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())])
),
# Empty array
])
def test_array_pickle_slice_truncation(data, typ, pickle_module):
arr = pa.array(data, type=typ)
serialized_arr = pickle_module.dumps(arr)

slice_arr = arr.slice(10, 2)
serialized_slice = pickle_module.dumps(slice_arr)

# Check truncation upon serialization
assert len(serialized_slice) <= 0.2 * len(serialized_arr)

post_pickle_slice = pickle_module.loads(serialized_slice)

# Check for post-roundtrip equality
assert post_pickle_slice.equals(slice_arr)

# Check that pickling resets the offset to zero
assert post_pickle_slice.offset == 0

# After pickling, the slice's buffers are trimmed to contain only the sliced
# data: the buffer-size ratio must not exceed the length ratio by more than
# a small tolerance
buf_size = arr.get_total_buffer_size()
post_pickle_slice_buf_size = post_pickle_slice.get_total_buffer_size()
assert buf_size / post_pickle_slice_buf_size - \
len(arr) / len(post_pickle_slice) < 10


@pytest.mark.parametrize(
'narr',
[
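A back-of-the-envelope reading of the buffer-size assertion in test_array_pickle_slice_truncation, for the int64 case (illustrative only; assumes 8-byte values and ignores the small validity bitmap):

import pickle
import pyarrow as pa

arr = pa.array(list(range(999)) + [None], type=pa.int64())
sliced = pickle.loads(pickle.dumps(arr.slice(10, 2)))

# The parent keeps roughly 1000 * 8 bytes of data; the restored slice keeps
# only about 2 * 8 bytes, so the buffer-size ratio tracks the length ratio
# of 500 instead of collapsing towards 1, as it would if the whole parent
# buffer had been pickled.
ratio_buffers = arr.get_total_buffer_size() / sliced.get_total_buffer_size()
ratio_lengths = len(arr) / len(sliced)
assert ratio_buffers - ratio_lengths < 10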