ARROW-369: [Python] Convert multiple record batches at once to Pandas #216

Closed
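
For orientation: the new entry point added here, pyarrow.dataframe_from_batches, converts a list of RecordBatches that share a schema into a single pandas.DataFrame in one step. A minimal usage sketch, mirroring the test added below (the data values are illustrative):

import numpy as np
import pandas as pd
import pyarrow as pa

data1 = pd.DataFrame({'c1': np.array([1, 1, 2], dtype='uint32')})
data2 = pd.DataFrame({'c1': np.array([3, 5], dtype='uint32')})

batch1 = pa.RecordBatch.from_pandas(data1)
batch2 = pa.RecordBatch.from_pandas(data2)

# All batches must have equal schemas, otherwise ArrowException is raised.
df = pa.dataframe_from_batches([batch1, batch2])

# The result is expected to match concatenating the per-batch conversions,
# which is what the new unit test below asserts with assert_frame_equal:
# pd.concat([data1, data2], ignore_index=True)
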
4 changes: 3 additions & 1 deletion python/pyarrow/__init__.py
@@ -41,5 +41,7 @@
                            list_, struct, field,
                            DataType, Field, Schema, schema)

-from pyarrow.table import Column, RecordBatch, Table, from_pandas_dataframe
+from pyarrow.table import (Column, RecordBatch, dataframe_from_batches, Table,
+                           from_pandas_dataframe)

from pyarrow.version import version as __version__
3 changes: 3 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -158,6 +158,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
        CColumn(const shared_ptr[CField]& field,
                const shared_ptr[CArray]& data)

        CColumn(const shared_ptr[CField]& field,
                const vector[shared_ptr[CArray]]& chunks)

        int64_t length()
        int64_t null_count()
        const c_string& name()
47 changes: 47 additions & 0 deletions python/pyarrow/table.pyx
@@ -28,6 +28,7 @@ cimport pyarrow.includes.pyarrow as pyarrow
import pyarrow.config

from pyarrow.array cimport Array, box_arrow_array
from pyarrow.error import ArrowException
from pyarrow.error cimport check_status
from pyarrow.schema cimport box_data_type, box_schema

@@ -414,6 +415,52 @@ cdef class RecordBatch:
        return result


def dataframe_from_batches(batches):
    """
    Convert a list of Arrow RecordBatches to a pandas.DataFrame

    Parameters
    ----------
    batches : list of RecordBatch
        RecordBatch list to be converted; all schemas must be equal
    """

    cdef:
        vector[shared_ptr[CArray]] c_array_chunks
        vector[shared_ptr[CColumn]] c_columns
        shared_ptr[CTable] c_table
        Array arr
        Schema schema

    import pandas as pd

    schema = batches[0].schema

    # check schemas are equal
    if any(not schema.equals(other.schema) for other in batches[1:]):
        raise ArrowException("Error converting list of RecordBatches to "
                             "DataFrame, not all schemas are equal")
Review comment (Member): Later we'll want to display the mismatched schemas in the error message, but this is ok for now.

Reply (Member Author): Yeah, I should have added that. I'll make a note to do that later.
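
A minimal sketch of that follow-up (hypothetical, not part of this diff): the check could be written as a loop that names the first batch whose schema differs, assuming Schema objects have a readable str() form.

# Hypothetical follow-up, not part of this PR: report which schema differs.
for i, other in enumerate(batches[1:], start=1):
    if not schema.equals(other.schema):
        raise ArrowException("Error converting list of RecordBatches to "
                             "DataFrame, schema of batch %d does not match "
                             "batch 0:\n%s\nvs.\n%s"
                             % (i, other.schema, schema))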


    cdef int K = batches[0].num_columns

    # create chunked columns from the batches
    c_columns.resize(K)
    for i in range(K):
        for batch in batches:
            arr = batch[i]
            c_array_chunks.push_back(arr.sp_array)
        c_columns[i].reset(new CColumn(schema.sp_schema.get().field(i),
                                       c_array_chunks))
        c_array_chunks.clear()

    # create a Table from columns and convert to DataFrame
    c_table.reset(new CTable('', schema.sp_schema, c_columns))
    table = Table()
    table.init(c_table)
    return table.to_pandas()


cdef class Table:
    """
    A collection of top-level named, equal length Arrow arrays.
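
To make the chunk-gathering loop in dataframe_from_batches concrete: the outer loop walks column indices and the inner loop walks batches, so column i of the resulting Table holds one chunk per input batch, in batch order. A plain-Python stand-in for that gathering order (illustrative only, not pyarrow internals):

# Stand-in dicts of lists take the place of RecordBatches and Array chunks.
batches = [
    {'c1': [1, 1, 2], 'c2': [1.0, 2.0, 3.0]},   # "RecordBatch" 1
    {'c1': [3, 5],    'c2': [4.0, 5.0]},        # "RecordBatch" 2
]
names = list(batches[0])
columns = {name: [batch[name] for batch in batches] for name in names}
# columns == {'c1': [[1, 1, 2], [3, 5]], 'c2': [[1.0, 2.0, 3.0], [4.0, 5.0]]}
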
35 changes: 35 additions & 0 deletions python/pyarrow/tests/test_table.py
@@ -19,6 +19,7 @@

from pandas.util.testing import assert_frame_equal
import pandas as pd
import pytest

import pyarrow as pa

@@ -50,6 +51,40 @@ def test_recordbatch_from_to_pandas():
    assert_frame_equal(data, result)


def test_recordbatchlist_to_pandas():
    data1 = pd.DataFrame({
        'c1': np.array([1, 1, 2], dtype='uint32'),
        'c2': np.array([1.0, 2.0, 3.0], dtype='float64'),
        'c3': [True, None, False],
        'c4': ['foo', 'bar', None]
    })

    data2 = pd.DataFrame({
        'c1': np.array([3, 5], dtype='uint32'),
        'c2': np.array([4.0, 5.0], dtype='float64'),
        'c3': [True, True],
        'c4': ['baz', 'qux']
    })

    batch1 = pa.RecordBatch.from_pandas(data1)
    batch2 = pa.RecordBatch.from_pandas(data2)

    result = pa.dataframe_from_batches([batch1, batch2])
    data = pd.concat([data1, data2], ignore_index=True)
    assert_frame_equal(data, result)


def test_recordbatchlist_schema_equals():
    data1 = pd.DataFrame({'c1': np.array([1], dtype='uint32')})
    data2 = pd.DataFrame({'c1': np.array([4.0, 5.0], dtype='float64')})

    batch1 = pa.RecordBatch.from_pandas(data1)
    batch2 = pa.RecordBatch.from_pandas(data2)

    with pytest.raises(pa.ArrowException):
        pa.dataframe_from_batches([batch1, batch2])


def test_table_basics():
    data = [
        pa.from_pylist(range(5)),