Skip to content

Commit

Permalink
Fixed numpy array interface (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jul 31, 2024
1 parent c48f6dd commit 57ff015
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
11 changes: 9 additions & 2 deletions arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Array:
obj: A sequence of input objects.
type: Explicit type to attempt to coerce to.
"""
def __array__(self) -> NDArray: ...
def __array__(self, dtype=None, copy=None) -> NDArray: ...
def __arrow_c_array__(
self, requested_schema: object | None = None
) -> tuple[object, object]: ...
Expand Down Expand Up @@ -101,7 +101,7 @@ class ChunkedArray:
arrays: Sequence[ArrowArrayExportable],
type: ArrowSchemaExportable | None = None,
) -> None: ...
def __array__(self) -> NDArray: ...
def __array__(self, dtype=None, copy=None) -> NDArray: ...
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...
def __eq__(self, other) -> bool: ...
def __len__(self) -> int: ...
Expand Down Expand Up @@ -1007,6 +1007,13 @@ class Table:
Returns:
_description_
"""
# Build a Table from a sequence of record batches.
#
# `schema` is keyword-only. According to the Rust implementation of this method
# (pyo3-arrow/src/table.rs), when `batches` is empty a schema must be supplied, otherwise
# the call fails with "schema must be passed for an empty list of batches". When `schema`
# is omitted for a non-empty sequence, the schema of the first batch is used.
@classmethod
def from_batches(
cls,
batches: Sequence[ArrowArrayExportable],
*,
schema: ArrowSchemaExportable | None = None,
) -> Table: ...
@overload
@classmethod
def from_pydict(
Expand Down
11 changes: 9 additions & 2 deletions pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,14 @@ impl PyArray {

/// An implementation of the Array interface, for interoperability with numpy and other
/// array libraries.
pub fn __array__(&self, py: Python) -> PyResult<PyObject> {
// Both arguments are optional on the Python side and default to `None`.
#[pyo3(signature = (dtype=None, copy=None))]
// `dtype` and `copy` are accepted but never read: the body below always performs the same
// conversion regardless of what the caller passes. The lint is silenced rather than renaming
// the parameters because the names are referenced by the signature attribute above.
// NOTE(review): presumably these exist so NumPy can call `__array__(dtype=..., copy=...)`
// without a TypeError — confirm whether a non-`None` value should be honoured or rejected.
#[allow(unused_variables)]
pub fn __array__(
&self,
py: Python,
dtype: Option<PyObject>,
copy: Option<PyObject>,
) -> PyResult<PyObject> {
// Delegate to the shared Arrow -> numpy conversion for the wrapped array.
to_numpy(py, &self.array)
}

Expand Down Expand Up @@ -289,7 +296,7 @@ impl PyArray {

/// Copy this array to a `numpy` NDArray
pub fn to_numpy(&self, py: Python) -> PyResult<PyObject> {
self.__array__(py)
self.__array__(py, None, None)
}

#[getter]
Expand Down
11 changes: 9 additions & 2 deletions pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,14 @@ impl PyChunkedArray {

/// An implementation of the Array interface, for interoperability with numpy and other
/// array libraries.
pub fn __array__(&self, py: Python) -> PyResult<PyObject> {
#[pyo3(signature = (dtype=None, copy=None))]
#[allow(unused_variables)]
pub fn __array__(
&self,
py: Python,
dtype: Option<PyObject>,
copy: Option<PyObject>,
) -> PyResult<PyObject> {
let chunk_refs = self
.chunks
.iter()
Expand Down Expand Up @@ -386,7 +393,7 @@ impl PyChunkedArray {

/// Copy this array to a `numpy` NDArray
pub fn to_numpy(&self, py: Python) -> PyResult<PyObject> {
self.__array__(py)
self.__array__(py, None, None)
}

pub fn r#type(&self, py: Python) -> PyResult<PyObject> {
Expand Down
76 changes: 73 additions & 3 deletions pyo3-arrow/src/interop/numpy/to_numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use arrow::datatypes::*;
use arrow_array::Array;
use arrow_schema::DataType;
use numpy::ToPyArray;
use pyo3::exceptions::PyValueError;
use pyo3::types::PyAnyMethods;
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::types::{PyAnyMethods, PyBytes, PyDict, PyList, PyString, PyTuple};
use pyo3::{intern, PyObject, PyResult, Python, ToPyObject};

pub fn to_numpy(py: Python, arr: &dyn Array) -> PyResult<PyObject> {
Expand Down Expand Up @@ -39,7 +39,77 @@ pub fn to_numpy(py: Python, arr: &dyn Array) -> PyResult<PyObject> {
let bools = arr.as_boolean().values().iter().collect::<Vec<_>>();
bools.to_pyarray_bound(py).to_object(py)
}
_ => todo!(),
// For other data types we create Python objects and then create an object-typed numpy
// array
DataType::Binary => {
let mut py_bytes = Vec::with_capacity(arr.len());
arr.as_binary::<i32>()
.iter()
.for_each(|x| py_bytes.push(PyBytes::new_bound(py, x.unwrap())));
let py_list = PyList::new_bound(py, py_bytes);
let numpy_mod = py.import_bound(intern!(py, "numpy"))?;
let kwargs = PyDict::new_bound(py);
kwargs.set_item("dtype", numpy_mod.getattr(intern!(py, "object_"))?)?;
let np_arr = numpy_mod.call_method(
intern!(py, "array"),
PyTuple::new_bound(py, vec![py_list]),
Some(&kwargs),
)?;
np_arr.into()
}
DataType::LargeBinary => {
let mut py_bytes = Vec::with_capacity(arr.len());
arr.as_binary::<i64>()
.iter()
.for_each(|x| py_bytes.push(PyBytes::new_bound(py, x.unwrap())));
let py_list = PyList::new_bound(py, py_bytes);
let numpy_mod = py.import_bound(intern!(py, "numpy"))?;
let kwargs = PyDict::new_bound(py);
kwargs.set_item("dtype", numpy_mod.getattr(intern!(py, "object_"))?)?;
let np_arr = numpy_mod.call_method(
intern!(py, "array"),
PyTuple::new_bound(py, vec![py_list]),
Some(&kwargs),
)?;
np_arr.into()
}
DataType::Utf8 => {
let mut py_bytes = Vec::with_capacity(arr.len());
arr.as_string::<i32>()
.iter()
.for_each(|x| py_bytes.push(PyString::new_bound(py, x.unwrap())));
let py_list = PyList::new_bound(py, py_bytes);
let numpy_mod = py.import_bound(intern!(py, "numpy"))?;
let kwargs = PyDict::new_bound(py);
kwargs.set_item("dtype", numpy_mod.getattr(intern!(py, "object_"))?)?;
let np_arr = numpy_mod.call_method(
intern!(py, "array"),
PyTuple::new_bound(py, vec![py_list]),
Some(&kwargs),
)?;
np_arr.into()
}
DataType::LargeUtf8 => {
let mut py_bytes = Vec::with_capacity(arr.len());
arr.as_string::<i64>()
.iter()
.for_each(|x| py_bytes.push(PyString::new_bound(py, x.unwrap())));
let py_list = PyList::new_bound(py, py_bytes);
let numpy_mod = py.import_bound(intern!(py, "numpy"))?;
let kwargs = PyDict::new_bound(py);
kwargs.set_item("dtype", numpy_mod.getattr(intern!(py, "object_"))?)?;
let np_arr = numpy_mod.call_method(
intern!(py, "array"),
PyTuple::new_bound(py, vec![py_list]),
Some(&kwargs),
)?;
np_arr.into()
}
dt => {
return Err(PyNotImplementedError::new_err(format!(
"Unsupported type in to_numpy {dt}"
)))
}
};
Ok(result)
}
Expand Down
24 changes: 24 additions & 0 deletions pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ impl PyTable {
Ok(Self::new(batches, schema))
}

/// Construct a table from a sequence of record batches.
///
/// `schema` is keyword-only on the Python side. When it is omitted, the schema of the first
/// batch is used for the table.
///
/// Fails with a `ValueError`-based error if `batches` is empty and no `schema` was passed,
/// because there is then nothing to infer the schema from.
#[classmethod]
#[pyo3(signature = (batches, *, schema=None))]
pub fn from_batches(
    _cls: &Bound<PyType>,
    batches: Vec<PyRecordBatch>,
    schema: Option<PySchema>,
) -> PyArrowResult<Self> {
    let batches = batches
        .into_iter()
        .map(|batch| batch.into_inner())
        .collect::<Vec<_>>();

    // Resolve the schema in one place: an explicit schema always wins, otherwise fall back to
    // the first batch. Matching on both values avoids an `unwrap()` on `first()` and avoids
    // building the error / cloning the batch schema when they are not needed.
    let schema = match (schema, batches.first()) {
        (Some(schema), _) => schema.into_inner(),
        (None, Some(first)) => first.schema(),
        (None, None) => {
            return Err(PyValueError::new_err(
                "schema must be passed for an empty list of batches",
            )
            .into())
        }
    };

    Ok(Self::new(batches, schema))
}

#[classmethod]
#[pyo3(signature = (mapping, *, schema=None, metadata=None))]
pub fn from_pydict(
Expand Down

0 comments on commit 57ff015

Please sign in to comment.