Skip to content

Commit

Permalink
Fix constructors for nested arrays (#97)
Browse files Browse the repository at this point in the history
* Fix constructors

* Add struct constructor test
  • Loading branch information
kylebarron authored Jul 31, 2024
1 parent 35211f5 commit 2a243ed
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 20 deletions.
54 changes: 34 additions & 20 deletions arro3-core/src/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,30 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use pyo3_arrow::error::PyArrowResult;
use pyo3_arrow::{PyArray, PyField};
use pyo3_arrow::{PyArray, PyDataType, PyField};

// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// `PyField`-based statements and the `PyDataType`-based statements that follow
// them are the before/after versions of the same code, shown back to back.
/// Construct a fixed-size list array from a flat `values` array.
///
/// `list_size` is forwarded to `FixedSizeListArray::try_new`. `type`, when
/// given, supplies the outer data type; otherwise a `FixedSizeList` type is
/// derived from the field of `values` and `list_size`.
///
/// Raises `ValueError` if the supplied `type` is not a fixed size list.
#[pyfunction]
#[pyo3(signature=(values, list_size, *, r#type=None))]
pub(crate) fn fixed_size_list_array(
py: Python,
values: PyArray,
list_size: i32,
r#type: Option<PyField>,
r#type: Option<PyDataType>,
) -> PyArrowResult<PyObject> {
let (values_array, values_field) = values.into_inner();
// Before: the optional `type` was taken as a field and handed to `try_new` as-is.
let field = r#type.map(|f| f.into_inner()).unwrap_or_else(|| {
Arc::new(Field::new_fixed_size_list(
"",
values_field,
list_size,
true,
))
});

let array = FixedSizeListArray::try_new(field.clone(), list_size, values_array, None)?;
Ok(PyArray::new(Arc::new(array), field).to_arro3(py)?)
// After: the optional `type` is the outer list data type; default it from the values field.
let list_data_type = r#type
.map(|t| t.into_inner())
.unwrap_or_else(|| DataType::FixedSizeList(values_field.clone(), list_size));
// `try_new` is given the child field, so pull it out of the outer type.
// NOTE(review): the size stored in a caller-supplied type is discarded by the `_`
// pattern and `list_size` is used instead — confirm the two are expected to agree.
let inner_field = match &list_data_type {
DataType::FixedSizeList(inner_field, _) => inner_field,
_ => {
return Err(
PyValueError::new_err("Expected fixed size list as the outer data type").into(),
)
}
};
let array = FixedSizeListArray::try_new(inner_field.clone(), list_size, values_array, None)?;
// The returned array is wrapped in an unnamed, nullable field carrying the outer list type.
Ok(PyArray::new(Arc::new(array), Field::new("", list_data_type, true).into()).to_arro3(py)?)
}

#[pyfunction]
Expand All @@ -39,7 +41,7 @@ pub(crate) fn list_array(
py: Python,
offsets: PyArray,
values: PyArray,
r#type: Option<PyField>,
r#type: Option<PyDataType>,
) -> PyArrowResult<PyObject> {
let (values_array, values_field) = values.into_inner();
let (offsets_array, _) = offsets.into_inner();
Expand All @@ -52,30 +54,42 @@ pub(crate) fn list_array(
)
}
};
let field = r#type.map(|f| f.into_inner()).unwrap_or_else(|| {
let list_data_type = r#type.map(|t| t.into_inner()).unwrap_or_else(|| {
if large_offsets {
Arc::new(Field::new_large_list("item", values_field, true))
DataType::LargeList(values_field.clone())
} else {
Arc::new(Field::new_list("item", values_field, true))
DataType::List(values_field.clone())
}
});
// Pull the child field out of the outer list type; it is what the
// `ListArray::try_new` / `LargeListArray::try_new` calls below receive.
// Any other outer type is rejected. The message names the types this
// function actually accepts (it previously said "fixed size list", copied
// from `fixed_size_list_array`).
let inner_field = match &list_data_type {
    DataType::List(inner_field) | DataType::LargeList(inner_field) => inner_field,
    _ => {
        return Err(PyValueError::new_err(
            "Expected list or large list as the outer data type",
        )
        .into())
    }
};

let list_array: ArrayRef = if large_offsets {
Arc::new(LargeListArray::try_new(
field.clone(),
inner_field.clone(),
OffsetBuffer::new(offsets_array.as_primitive::<Int64Type>().values().clone()),
values_array,
None,
)?)
} else {
Arc::new(ListArray::try_new(
field.clone(),
inner_field.clone(),
OffsetBuffer::new(offsets_array.as_primitive::<Int32Type>().values().clone()),
values_array,
None,
)?)
};
Ok(PyArray::new(Arc::new(list_array), field).to_arro3(py)?)
Ok(PyArray::new(
Arc::new(list_array),
Field::new("", list_data_type, true).into(),
)
.to_arro3(py)?)
}

#[pyfunction]
Expand Down
65 changes: 65 additions & 0 deletions tests/core/test_constructors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import pyarrow as pa
from arro3.core import (
Array,
DataType,
fixed_size_list_array,
Field,
list_array,
struct_array,
)
from arro3.compute import list_offsets


def test_fixed_size_list_array():
    """Wrapping six float64 values with ``list_size=2`` yields a fixed-size list type."""
    values = Array.from_numpy(np.array([1, 2, 3, 4, 5, 6], dtype=np.float64))
    result = fixed_size_list_array(values, 2)

    # Check the result through pyarrow's view of the array.
    exported = pa.array(result)
    assert pa.types.is_fixed_size_list(exported.type)
    assert exported.type.list_size == 2


def test_fixed_size_list_array_with_type():
    """An explicit list type passed via ``type=`` keeps its named child field."""
    child_field = Field("inner", DataType.float64())
    values = Array.from_numpy(np.array([1, 2, 3, 4, 5, 6], dtype=np.float64))

    result = fixed_size_list_array(values, 2, type=DataType.list(child_field, 2))

    exported = pa.array(result)
    assert pa.types.is_fixed_size_list(exported.type)
    assert exported.type.list_size == 2
    # The child field name given in the explicit type must still be present.
    assert exported.type.field(0).name == "inner"


def test_list_array():
    """Int32 offsets plus flat float64 values build a list array with the same offsets."""
    values = Array.from_numpy(np.array([1, 2, 3, 4, 5, 6], dtype=np.float64))
    offsets = Array.from_numpy(np.array([0, 2, 5, 6], dtype=np.int32))

    result = list_array(offsets, values)

    exported = pa.array(result)
    assert pa.types.is_list(exported.type)
    # The offsets read back from the list array must equal the ones passed in.
    assert list_offsets(result) == offsets


def test_list_array_with_type():
    """``list_array`` with an explicit list type keeps the offsets and the child field name."""
    values = Array.from_numpy(np.array([1, 2, 3, 4, 5, 6], dtype=np.float64))
    offsets = Array.from_numpy(np.array([0, 2, 5, 6], dtype=np.int32))
    child_field = Field("inner", DataType.float64())

    result = list_array(offsets, values, type=DataType.list(child_field))

    exported = pa.array(result)
    assert pa.types.is_list(exported.type)
    assert list_offsets(result) == offsets
    # The child field name given in the explicit type must still be present.
    assert exported.type.field(0).name == "inner"


def test_struct_array():
    """``struct_array`` built from two pyarrow arrays exposes both field names in order."""
    ints = pa.array([1, 2, 3, 4])
    strings = pa.array(["a", "b", "c", "d"])
    fields = [Field("a", ints.type), Field("b", strings.type)]

    result = struct_array([ints, strings], fields=fields)

    struct_type = pa.array(result).type
    assert pa.types.is_struct(struct_type)
    assert struct_type.field(0).name == "a"
    assert struct_type.field(1).name == "b"

0 comments on commit 2a243ed

Please sign in to comment.