Skip to content

Commit

Permalink
Support both array and stream arrow objects in data constructors (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jul 31, 2024
1 parent d0709bb commit d5a529a
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 41 deletions.
20 changes: 14 additions & 6 deletions arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Array:
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
@classmethod
def from_arrow(cls, input: ArrowArrayExportable) -> Array:
def from_arrow(cls, input: ArrowArrayExportable | ArrowStreamExportable) -> Array:
"""
Construct this object from an existing Arrow object.
Expand Down Expand Up @@ -75,7 +75,9 @@ class ArrayReader:
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...
def __repr__(self) -> str: ...
@classmethod
def from_arrow(cls, input: ArrowStreamExportable) -> ArrayReader: ...
def from_arrow(
cls, input: ArrowArrayExportable | ArrowStreamExportable
) -> ArrayReader: ...
@classmethod
def from_arrow_pycapsule(cls, capsule) -> ArrayReader:
"""Construct this object from a bare Arrow PyCapsule"""
Expand Down Expand Up @@ -103,7 +105,9 @@ class ChunkedArray:
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
@classmethod
def from_arrow(cls, input: ArrowStreamExportable) -> ChunkedArray: ...
def from_arrow(
cls, input: ArrowArrayExportable | ArrowStreamExportable
) -> ChunkedArray: ...
@classmethod
def from_arrow_pycapsule(cls, capsule) -> ChunkedArray:
"""Construct this object from a bare Arrow PyCapsule"""
Expand Down Expand Up @@ -630,7 +634,9 @@ class RecordBatch:
New RecordBatch
"""
@classmethod
def from_arrow(cls, input: ArrowArrayExportable) -> RecordBatch: ...
def from_arrow(
cls, input: ArrowArrayExportable | ArrowStreamExportable
) -> RecordBatch: ...
@classmethod
def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> RecordBatch:
"""Construct this object from bare Arrow PyCapsules"""
Expand Down Expand Up @@ -720,7 +726,9 @@ class RecordBatchReader:
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...
def __repr__(self) -> str: ...
@classmethod
def from_arrow(cls, input: ArrowStreamExportable) -> RecordBatchReader: ...
def from_arrow(
cls, input: ArrowArrayExportable | ArrowStreamExportable
) -> RecordBatchReader: ...
@classmethod
def from_arrow_pycapsule(cls, capsule) -> RecordBatchReader:
"""Construct this object from a bare Arrow PyCapsule"""
Expand Down Expand Up @@ -963,7 +971,7 @@ class Table:
new table
"""
@classmethod
def from_arrow(cls, input: ArrowStreamExportable) -> Table:
def from_arrow(cls, input: ArrowArrayExportable | ArrowStreamExportable) -> Table:
"""
Construct this object from an existing Arrow object.
Expand Down
15 changes: 13 additions & 2 deletions pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Display;
use std::sync::Arc;

use arrow::compute::concat;
use arrow::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
UInt64Type, UInt8Type,
Expand All @@ -20,6 +21,7 @@ use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_array_pycapsules;
use crate::ffi::to_array_pycapsules;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
use crate::input::AnyArray;
use crate::interop::numpy::from_numpy::from_numpy;
use crate::interop::numpy::to_numpy::to_numpy;
use crate::PyDataType;
Expand Down Expand Up @@ -228,8 +230,17 @@ impl PyArray {
}

#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyArray) -> PyArrowResult<Self> {
match input {
AnyArray::Array(array) => Ok(array),
AnyArray::Stream(stream) => {
let chunked_array = stream.into_chunked_array()?;
let (chunks, field) = chunked_array.into_inner();
let chunk_refs = chunks.iter().map(|arr| arr.as_ref()).collect::<Vec<_>>();
let concatted = concat(chunk_refs.as_slice())?;
Ok(Self::new(concatted, field))
}
}
}

#[classmethod]
Expand Down
6 changes: 4 additions & 2 deletions pyo3-arrow/src/array_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::ffi::to_python::to_stream_pycapsule;
use crate::ffi::{ArrayIterator, ArrayReader};
use crate::input::AnyArray;
use crate::{PyArray, PyChunkedArray, PyField};

/// A Python-facing Arrow array reader.
Expand Down Expand Up @@ -136,8 +137,9 @@ impl PyArrayReader {
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`), such as a `Table` or `ArrayReader`.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyArray) -> PyArrowResult<Self> {
let reader = input.into_reader()?;
Ok(Self::new(reader))
}

/// Construct this object from a bare Arrow PyCapsule.
Expand Down
5 changes: 3 additions & 2 deletions pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::chunked::ArrayIterator;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::ffi::to_python::to_stream_pycapsule;
use crate::input::AnyArray;
use crate::interop::numpy::to_numpy::chunked_to_numpy;
use crate::{PyArray, PyDataType};

Expand Down Expand Up @@ -287,8 +288,8 @@ impl PyChunkedArray {
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`). All batches will be materialized in memory.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyArray) -> PyArrowResult<Self> {
input.into_chunked_array()
}

/// Construct this object from a bare Arrow PyCapsule
Expand Down
4 changes: 2 additions & 2 deletions pyo3-arrow/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ impl PyDataType {
/// It can be called on anything that exports the Arrow schema interface
/// (`__arrow_c_schema__`).
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
input
}

/// Construct this object from a bare Arrow PyCapsule
Expand Down
4 changes: 2 additions & 2 deletions pyo3-arrow/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ impl PyField {
/// It can be called on anything that exports the Arrow schema interface
/// (`__arrow_c_schema__`).
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
input
}

/// Construct this object from a bare Arrow PyCapsule
Expand Down
26 changes: 12 additions & 14 deletions pyo3-arrow/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use pyo3::prelude::*;
use crate::array_reader::PyArrayReader;
use crate::error::PyArrowResult;
use crate::ffi::{ArrayIterator, ArrayReader};
use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader};
use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyTable};

/// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
/// Arrow object as input.
Expand All @@ -36,6 +36,13 @@ impl AnyRecordBatch {
}
}

pub fn into_table(self) -> PyArrowResult<PyTable> {
let reader = self.into_reader()?;
let schema = reader.schema();
let batches = reader.collect::<Result<_, ArrowError>>()?;
Ok(PyTable::new(batches, schema))
}

pub fn schema(&self) -> PyResult<SchemaRef> {
match self {
Self::RecordBatch(batch) => Ok(batch.as_ref().schema()),
Expand All @@ -53,19 +60,10 @@ pub enum AnyArray {

impl AnyArray {
pub fn into_chunked_array(self) -> PyArrowResult<PyChunkedArray> {
match self {
Self::Array(array) => {
let (array, field) = array.into_inner();
Ok(PyChunkedArray::new(vec![array], field))
}
Self::Stream(stream) => {
let field = stream.field_ref()?;
let chunks = stream
.into_reader()?
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(PyChunkedArray::new(chunks, field))
}
}
let reader = self.into_reader()?;
let field = reader.field();
let chunks = reader.collect::<Result<_, ArrowError>>()?;
Ok(PyChunkedArray::new(chunks, field))
}

pub fn into_reader(self) -> PyResult<Box<dyn ArrayReader + Send>> {
Expand Down
16 changes: 12 additions & 4 deletions pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::Display;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::compute::concat_batches;
use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
use indexmap::IndexMap;
Expand All @@ -14,7 +15,7 @@ use crate::error::PyArrowResult;
use crate::ffi::from_python::utils::import_array_pycapsules;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
use crate::ffi::to_python::to_array_pycapsules;
use crate::input::{FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
use crate::schema::display_schema;
use crate::{PyArray, PyField, PySchema};

Expand Down Expand Up @@ -101,7 +102,7 @@ impl PyRecordBatch {
if data.hasattr("__arrow_c_array__")? {
Ok(Self::from_arrow(
&py.get_type_bound::<PyRecordBatch>(),
data,
data.extract()?,
)?)
} else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
Self::from_pydict(&py.get_type_bound::<PyRecordBatch>(), mapping, metadata)
Expand Down Expand Up @@ -218,8 +219,15 @@ impl PyRecordBatch {
/// Returns:
/// Self
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
match input {
AnyRecordBatch::RecordBatch(rb) => Ok(rb),
AnyRecordBatch::Stream(stream) => {
let (batches, schema) = stream.into_table()?.into_inner();
let single_batch = concat_batches(&schema, batches.iter())?;
Ok(Self::new(single_batch))
}
}
}

/// Construct this object from a bare Arrow PyCapsule
Expand Down
6 changes: 4 additions & 2 deletions pyo3-arrow/src/record_batch_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::chunked::ArrayIterator;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::ffi::to_python::to_stream_pycapsule;
use crate::input::AnyRecordBatch;
use crate::schema::display_schema;
use crate::{PyRecordBatch, PySchema, PyTable};

Expand Down Expand Up @@ -153,8 +154,9 @@ impl PyRecordBatchReader {
/// It can be called on anything that exports the Arrow stream interface
/// (`__arrow_c_stream__`), such as a `Table` or `RecordBatchReader`.
#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
let reader = input.into_reader()?;
Ok(Self::new(reader))
}

/// Construct this object from a bare Arrow PyCapsule.
Expand Down
4 changes: 2 additions & 2 deletions pyo3-arrow/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ impl PySchema {
}

#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
input
}

#[classmethod]
Expand Down
8 changes: 5 additions & 3 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::chunked::ArrayIterator;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::ffi::to_python::to_stream_pycapsule;
use crate::input::{AnyArray, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
use crate::input::{
AnyArray, AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices,
};
use crate::schema::display_schema;
use crate::{PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PySchema};

Expand Down Expand Up @@ -117,8 +119,8 @@ impl PyTable {
}

#[classmethod]
pub fn from_arrow(_cls: &Bound<PyType>, input: &Bound<PyAny>) -> PyResult<Self> {
input.extract()
pub fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
input.into_table()
}

#[classmethod]
Expand Down

0 comments on commit d5a529a

Please sign in to comment.