diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index 58d58c5..a10922e 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -24,7 +24,7 @@ class Array: def __len__(self) -> int: ... def __repr__(self) -> str: ... @classmethod - def from_arrow(cls, input: ArrowArrayExportable) -> Array: + def from_arrow(cls, input: ArrowArrayExportable | ArrowStreamExportable) -> Array: """ Construct this object from an existing Arrow object. @@ -75,7 +75,9 @@ class ArrayReader: def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... def __repr__(self) -> str: ... @classmethod - def from_arrow(cls, input: ArrowStreamExportable) -> ArrayReader: ... + def from_arrow( + cls, input: ArrowArrayExportable | ArrowStreamExportable + ) -> ArrayReader: ... @classmethod def from_arrow_pycapsule(cls, capsule) -> ArrayReader: """Construct this object from a bare Arrow PyCapsule""" @@ -103,7 +105,9 @@ class ChunkedArray: def __len__(self) -> int: ... def __repr__(self) -> str: ... @classmethod - def from_arrow(cls, input: ArrowStreamExportable) -> ChunkedArray: ... + def from_arrow( + cls, input: ArrowArrayExportable | ArrowStreamExportable + ) -> ChunkedArray: ... @classmethod def from_arrow_pycapsule(cls, capsule) -> ChunkedArray: """Construct this object from a bare Arrow PyCapsule""" @@ -630,7 +634,9 @@ class RecordBatch: New RecordBatch """ @classmethod - def from_arrow(cls, input: ArrowArrayExportable) -> RecordBatch: ... + def from_arrow( + cls, input: ArrowArrayExportable | ArrowStreamExportable + ) -> RecordBatch: ... @classmethod def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> RecordBatch: """Construct this object from bare Arrow PyCapsules""" @@ -720,7 +726,9 @@ class RecordBatchReader: def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ... def __repr__(self) -> str: ... @classmethod - def from_arrow(cls, input: ArrowStreamExportable) -> RecordBatchReader: ... + def from_arrow( + cls, input: ArrowArrayExportable | ArrowStreamExportable + ) -> RecordBatchReader: ... @classmethod def from_arrow_pycapsule(cls, capsule) -> RecordBatchReader: """Construct this object from a bare Arrow PyCapsule""" @@ -963,7 +971,7 @@ class Table: new table """ @classmethod - def from_arrow(cls, input: ArrowStreamExportable) -> Table: + def from_arrow(cls, input: ArrowArrayExportable | ArrowStreamExportable) -> Table: """ Construct this object from an existing Arrow object. diff --git a/pyo3-arrow/src/array.rs b/pyo3-arrow/src/array.rs index 08b84d0..7adcbad 100644 --- a/pyo3-arrow/src/array.rs +++ b/pyo3-arrow/src/array.rs @@ -1,6 +1,7 @@ use std::fmt::Display; use std::sync::Arc; +use arrow::compute::concat; use arrow::datatypes::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -20,6 +21,7 @@ use crate::error::PyArrowResult; use crate::ffi::from_python::utils::import_array_pycapsules; use crate::ffi::to_array_pycapsules; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array; +use crate::input::AnyArray; use crate::interop::numpy::from_numpy::from_numpy; use crate::interop::numpy::to_numpy::to_numpy; use crate::PyDataType; @@ -228,8 +230,17 @@ impl PyArray { } #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyArray) -> PyArrowResult { + match input { + AnyArray::Array(array) => Ok(array), + AnyArray::Stream(stream) => { + let chunked_array = stream.into_chunked_array()?; + let (chunks, field) = chunked_array.into_inner(); + let chunk_refs = chunks.iter().map(|arr| arr.as_ref()).collect::>(); + let concatted = concat(chunk_refs.as_slice())?; + Ok(Self::new(concatted, field)) + } + } } #[classmethod] diff --git a/pyo3-arrow/src/array_reader.rs b/pyo3-arrow/src/array_reader.rs index 18b925d..85ff2e1 100644 --- a/pyo3-arrow/src/array_reader.rs +++ b/pyo3-arrow/src/array_reader.rs @@ -12,6 +12,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream; use crate::ffi::to_python::to_stream_pycapsule; use crate::ffi::{ArrayIterator, ArrayReader}; +use crate::input::AnyArray; use crate::{PyArray, PyChunkedArray, PyField}; /// A Python-facing Arrow array reader. @@ -136,8 +137,9 @@ impl PyArrayReader { /// It can be called on anything that exports the Arrow stream interface /// (`__arrow_c_stream__`), such as a `Table` or `ArrayReader`. #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyArray) -> PyArrowResult { + let reader = input.into_reader()?; + Ok(Self::new(reader)) } /// Construct this object from a bare Arrow PyCapsule. diff --git a/pyo3-arrow/src/chunked.rs b/pyo3-arrow/src/chunked.rs index 36a57dd..e0270cf 100644 --- a/pyo3-arrow/src/chunked.rs +++ b/pyo3-arrow/src/chunked.rs @@ -15,6 +15,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule; use crate::ffi::to_python::chunked::ArrayIterator; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream; use crate::ffi::to_python::to_stream_pycapsule; +use crate::input::AnyArray; use crate::interop::numpy::to_numpy::chunked_to_numpy; use crate::{PyArray, PyDataType}; @@ -287,8 +288,8 @@ impl PyChunkedArray { /// It can be called on anything that exports the Arrow stream interface /// (`__arrow_c_stream__`). All batches will be materialized in memory. #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyArray) -> PyArrowResult { + input.into_chunked_array() } /// Construct this object from a bare Arrow PyCapsule diff --git a/pyo3-arrow/src/datatypes.rs b/pyo3-arrow/src/datatypes.rs index a5799b4..5db6fc5 100644 --- a/pyo3-arrow/src/datatypes.rs +++ b/pyo3-arrow/src/datatypes.rs @@ -131,8 +131,8 @@ impl PyDataType { /// It can be called on anything that exports the Arrow schema interface /// (`__arrow_c_schema__`). #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: Self) -> Self { + input } /// Construct this object from a bare Arrow PyCapsule diff --git a/pyo3-arrow/src/field.rs b/pyo3-arrow/src/field.rs index 6b1a60e..fc3cc0c 100644 --- a/pyo3-arrow/src/field.rs +++ b/pyo3-arrow/src/field.rs @@ -131,8 +131,8 @@ impl PyField { /// It can be called on anything that exports the Arrow schema interface /// (`__arrow_c_schema__`). #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: Self) -> Self { + input } /// Construct this object from a bare Arrow PyCapsule diff --git a/pyo3-arrow/src/input.rs b/pyo3-arrow/src/input.rs index 8e893af..45288f4 100644 --- a/pyo3-arrow/src/input.rs +++ b/pyo3-arrow/src/input.rs @@ -15,7 +15,7 @@ use pyo3::prelude::*; use crate::array_reader::PyArrayReader; use crate::error::PyArrowResult; use crate::ffi::{ArrayIterator, ArrayReader}; -use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader}; +use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyTable}; /// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either /// Arrow object as input. @@ -36,6 +36,13 @@ impl AnyRecordBatch { } } + pub fn into_table(self) -> PyArrowResult { + let reader = self.into_reader()?; + let schema = reader.schema(); + let batches = reader.collect::>()?; + Ok(PyTable::new(batches, schema)) + } + pub fn schema(&self) -> PyResult { match self { Self::RecordBatch(batch) => Ok(batch.as_ref().schema()), @@ -53,19 +60,10 @@ pub enum AnyArray { impl AnyArray { pub fn into_chunked_array(self) -> PyArrowResult { - match self { - Self::Array(array) => { - let (array, field) = array.into_inner(); - Ok(PyChunkedArray::new(vec![array], field)) - } - Self::Stream(stream) => { - let field = stream.field_ref()?; - let chunks = stream - .into_reader()? - .collect::, ArrowError>>()?; - Ok(PyChunkedArray::new(chunks, field)) - } - } + let reader = self.into_reader()?; + let field = reader.field(); + let chunks = reader.collect::>()?; + Ok(PyChunkedArray::new(chunks, field)) } pub fn into_reader(self) -> PyResult> { diff --git a/pyo3-arrow/src/record_batch.rs b/pyo3-arrow/src/record_batch.rs index ce23a5d..3e1e36c 100644 --- a/pyo3-arrow/src/record_batch.rs +++ b/pyo3-arrow/src/record_batch.rs @@ -2,6 +2,7 @@ use std::fmt::Display; use std::sync::Arc; use arrow::array::AsArray; +use arrow::compute::concat_batches; use arrow_array::{Array, ArrayRef, RecordBatch, StructArray}; use arrow_schema::{DataType, Field, Schema, SchemaBuilder}; use indexmap::IndexMap; @@ -14,7 +15,7 @@ use crate::error::PyArrowResult; use crate::ffi::from_python::utils::import_array_pycapsules; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array; use crate::ffi::to_python::to_array_pycapsules; -use crate::input::{FieldIndexInput, MetadataInput, NameOrField, SelectIndices}; +use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices}; use crate::schema::display_schema; use crate::{PyArray, PyField, PySchema}; @@ -101,7 +102,7 @@ impl PyRecordBatch { if data.hasattr("__arrow_c_array__")? { Ok(Self::from_arrow( &py.get_type_bound::(), - data, + data.extract()?, )?) } else if let Ok(mapping) = data.extract::>() { Self::from_pydict(&py.get_type_bound::(), mapping, metadata) @@ -218,8 +219,15 @@ impl PyRecordBatch { /// Returns: /// Self #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyRecordBatch) -> PyArrowResult { + match input { + AnyRecordBatch::RecordBatch(rb) => Ok(rb), + AnyRecordBatch::Stream(stream) => { + let (batches, schema) = stream.into_table()?.into_inner(); + let single_batch = concat_batches(&schema, batches.iter())?; + Ok(Self::new(single_batch)) + } + } } /// Construct this object from a bare Arrow PyCapsule diff --git a/pyo3-arrow/src/record_batch_reader.rs b/pyo3-arrow/src/record_batch_reader.rs index fb6a590..8719220 100644 --- a/pyo3-arrow/src/record_batch_reader.rs +++ b/pyo3-arrow/src/record_batch_reader.rs @@ -13,6 +13,7 @@ use crate::ffi::from_python::utils::import_stream_pycapsule; use crate::ffi::to_python::chunked::ArrayIterator; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream; use crate::ffi::to_python::to_stream_pycapsule; +use crate::input::AnyRecordBatch; use crate::schema::display_schema; use crate::{PyRecordBatch, PySchema, PyTable}; @@ -153,8 +154,9 @@ impl PyRecordBatchReader { /// It can be called on anything that exports the Arrow stream interface /// (`__arrow_c_stream__`), such as a `Table` or `RecordBatchReader`. #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyRecordBatch) -> PyArrowResult { + let reader = input.into_reader()?; + Ok(Self::new(reader)) } /// Construct this object from a bare Arrow PyCapsule. diff --git a/pyo3-arrow/src/schema.rs b/pyo3-arrow/src/schema.rs index 1610ffb..80a571d 100644 --- a/pyo3-arrow/src/schema.rs +++ b/pyo3-arrow/src/schema.rs @@ -136,8 +136,8 @@ impl PySchema { } #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: Self) -> Self { + input } #[classmethod] diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs index 2c5a4c4..f885b15 100644 --- a/pyo3-arrow/src/table.rs +++ b/pyo3-arrow/src/table.rs @@ -17,7 +17,9 @@ use crate::ffi::from_python::utils::import_stream_pycapsule; use crate::ffi::to_python::chunked::ArrayIterator; use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream; use crate::ffi::to_python::to_stream_pycapsule; -use crate::input::{AnyArray, FieldIndexInput, MetadataInput, NameOrField, SelectIndices}; +use crate::input::{ + AnyArray, AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices, +}; use crate::schema::display_schema; use crate::{PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PySchema}; @@ -117,8 +119,8 @@ impl PyTable { } #[classmethod] - pub fn from_arrow(_cls: &Bound, input: &Bound) -> PyResult { - input.extract() + pub fn from_arrow(_cls: &Bound, input: AnyRecordBatch) -> PyArrowResult { + input.into_table() } #[classmethod]