Skip to content

Commit

Permalink
Infer data type in from_numpy (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jul 31, 2024
1 parent 9d5f2d2 commit 736c35a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 26 deletions.
2 changes: 1 addition & 1 deletion arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Array:
"""Construct this object from bare Arrow PyCapsules"""

@classmethod
def from_numpy(cls, array: np.ndarray, type: ArrowSchemaExportable) -> Array:
def from_numpy(cls, array: np.ndarray) -> Array:
"""Construct an Array from a numpy ndarray"""

def to_numpy(self) -> NDArray:
Expand Down
4 changes: 2 additions & 2 deletions pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ impl PyArray {
#[classmethod]
pub fn from_numpy(
_cls: &Bound<PyType>,
py: Python,
array: Bound<'_, PyAny>,
r#type: PyDataType,
) -> PyArrowResult<Self> {
let mut numpy_array = array;
if numpy_array.hasattr("__array__")? {
numpy_array = numpy_array.call_method0("__array__")?;
};
let numpy_array: &PyUntypedArray = FromPyObject::extract_bound(&numpy_array)?;
let arrow_array = from_numpy(numpy_array, r#type.into_inner())?;
let arrow_array = from_numpy(py, numpy_array)?;
Ok(Self::from_array_ref(arrow_array))
}

Expand Down
56 changes: 33 additions & 23 deletions pyo3-arrow/src/interop/numpy/from_numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use arrow::datatypes::{
UInt64Type, UInt8Type,
};
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
use arrow_schema::DataType;
use numpy::{PyArray1, PyUntypedArray};
use numpy::{dtype_bound, PyArray1, PyArrayDescr, PyUntypedArray};
use pyo3::exceptions::PyValueError;
use pyo3::Python;

use crate::error::PyArrowResult;

pub fn from_numpy(array: &PyUntypedArray, arrow_data_type: DataType) -> PyArrowResult<ArrayRef> {
pub fn from_numpy(py: Python, array: &PyUntypedArray) -> PyArrowResult<ArrayRef> {
macro_rules! numpy_to_arrow {
($rust_type:ty, $arrow_type:ty) => {{
let arr = array.downcast::<PyArray1<$rust_type>>()?;
Expand All @@ -20,25 +20,35 @@ pub fn from_numpy(array: &PyUntypedArray, arrow_data_type: DataType) -> PyArrowR
)))
}};
}

match arrow_data_type {
// DataType::Float16 => numpy_to_arrow!(f16, Float16Type),
DataType::Float32 => numpy_to_arrow!(f32, Float32Type),
DataType::Float64 => numpy_to_arrow!(f64, Float64Type),
DataType::UInt8 => numpy_to_arrow!(u8, UInt8Type),
DataType::UInt16 => numpy_to_arrow!(u16, UInt16Type),
DataType::UInt32 => numpy_to_arrow!(u32, UInt32Type),
DataType::UInt64 => numpy_to_arrow!(u64, UInt64Type),
DataType::Int8 => numpy_to_arrow!(i8, Int8Type),
DataType::Int16 => numpy_to_arrow!(i16, Int16Type),
DataType::Int32 => numpy_to_arrow!(i32, Int32Type),
DataType::Int64 => numpy_to_arrow!(i64, Int64Type),
DataType::Boolean => {
let arr = array.downcast::<PyArray1<bool>>()?;
Ok(Arc::new(BooleanArray::from(arr.to_owned_array().to_vec())))
}
_ => {
Err(PyValueError::new_err(format!("Unsupported data type {}", arrow_data_type)).into())
}
let dtype = array.dtype();
if is_type::<f32>(py, dtype) {
numpy_to_arrow!(f32, Float32Type)
} else if is_type::<f64>(py, dtype) {
numpy_to_arrow!(f64, Float64Type)
} else if is_type::<u8>(py, dtype) {
numpy_to_arrow!(u8, UInt8Type)
} else if is_type::<u16>(py, dtype) {
numpy_to_arrow!(u16, UInt16Type)
} else if is_type::<u32>(py, dtype) {
numpy_to_arrow!(u32, UInt32Type)
} else if is_type::<u64>(py, dtype) {
numpy_to_arrow!(u64, UInt64Type)
} else if is_type::<i8>(py, dtype) {
numpy_to_arrow!(i8, Int8Type)
} else if is_type::<i16>(py, dtype) {
numpy_to_arrow!(i16, Int16Type)
} else if is_type::<i32>(py, dtype) {
numpy_to_arrow!(i32, Int32Type)
} else if is_type::<i64>(py, dtype) {
numpy_to_arrow!(i64, Int64Type)
} else if is_type::<bool>(py, dtype) {
let arr = array.downcast::<PyArray1<bool>>()?;
Ok(Arc::new(BooleanArray::from(arr.to_owned_array().to_vec())))
} else {
Err(PyValueError::new_err(format!("Unsupported data type {}", dtype)).into())
}
}

/// Reports whether `dtype` is equivalent to the NumPy descriptor of the Rust
/// element type `T`.
///
/// The descriptor for `T` is looked up with `dtype_bound::<T>` and the
/// comparison itself is delegated to `PyArrayDescr::is_equiv_to`.
///
/// * `py` - GIL token used to obtain the descriptor for `T`.
/// * `dtype` - descriptor of the array being inspected.
fn is_type<T: numpy::Element>(py: Python, dtype: &PyArrayDescr) -> bool {
    let expected = dtype_bound::<T>(py);
    dtype.is_equiv_to(expected.as_gil_ref())
}
10 changes: 10 additions & 0 deletions tests/core/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np
from arro3.core import Array, DataType


def test_from_numpy():
    """Check that Array.from_numpy yields the Arrow type matching the ndarray dtype."""
    # Each case pairs a numpy dtype with the DataType constructor expected for it.
    cases = (
        (np.uint8, DataType.uint8),
        (np.float64, DataType.float64),
    )
    for numpy_dtype, expected_type in cases:
        values = np.array([1, 2, 3, 4], dtype=numpy_dtype)
        assert Array.from_numpy(values).type == expected_type()

0 comments on commit 736c35a

Please sign in to comment.