Skip to content

Commit

Permalink
Implement take method on Array and RecordBatch (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Jul 31, 2024
1 parent b6cbaa4 commit c48f6dd
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions arro3-compute/python/arro3/compute/_compute.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,5 @@ def struct_field(
Returns:
_description_
"""

def take(values: ArrowArrayExportable, indices: ArrowArrayExportable) -> Array:
    """Select elements of `values` by position.

    Args:
        values: Arrow array to select elements from.
        indices: Arrow array of positions to select. Presumably an integer
            array; the implementation is not visible from this stub, so
            confirm the accepted index types against it.

    Returns:
        A new array holding the selected elements.
    """
    ...
3 changes: 1 addition & 2 deletions arro3-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ fn _compute(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {

m.add_wrapped(wrap_pyfunction!(cast::cast))?;
m.add_wrapped(wrap_pyfunction!(concat::concat))?;
m.add_wrapped(wrap_pyfunction!(take::take))?;

m.add_wrapped(wrap_pyfunction!(list_flatten::list_flatten))?;
m.add_wrapped(wrap_pyfunction!(list_offsets::list_offsets))?;
m.add_wrapped(wrap_pyfunction!(struct_field::struct_field))?;
m.add_wrapped(wrap_pyfunction!(take::take))?;

Ok(())
}
3 changes: 2 additions & 1 deletion arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Array:
Returns:
The sliced array
"""

def take(self, indices: ArrowArrayExportable) -> Array:
    """Select elements of this array by position.

    Args:
        indices: Arrow array of positions to select. Presumably an integer
            array; confirm the accepted index types against the
            implementation.

    Returns:
        A new array holding the selected elements.
    """
    ...
@property
def type(self) -> DataType:
"""The data type of this array."""
Expand Down Expand Up @@ -724,6 +724,7 @@ class RecordBatch:
@property
def shape(self) -> tuple[int, int]: ...
def slice(self, offset: int = 0, length: int | None = None) -> RecordBatch: ...
def take(self, indices: ArrowArrayExportable) -> RecordBatch:
    """Select rows of this record batch by position.

    Args:
        indices: Arrow array of row positions to select. Presumably an
            integer array; confirm the accepted index types against the
            implementation.

    Returns:
        A new record batch holding the selected rows.
    """
    ...
def to_struct_array(self) -> Array: ...
def with_schema(self, schema: ArrowSchemaExportable) -> RecordBatch: ...

Expand Down
5 changes: 5 additions & 0 deletions pyo3-arrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ impl PyArray {
PyArray::new(new_array, self.field().clone()).to_arro3(py)
}

/// Select elements of this array at the positions given by `indices`.
///
/// The selection is delegated to `arrow::compute::take` with no extra
/// options, and the result is wrapped with a clone of this array's field
/// before being exported through `to_arro3`.
///
/// Errors from the take kernel or from the export are propagated.
fn take(&self, py: Python, indices: PyArray) -> PyArrowResult<PyObject> {
    let taken = arrow::compute::take(self.as_ref(), indices.as_ref(), None)?;
    let wrapped = PyArray::new(taken, self.field.clone());
    Ok(wrapped.to_arro3(py)?)
}

/// Copy this array to a `numpy` NDArray
pub fn to_numpy(&self, py: Python) -> PyResult<PyObject> {
self.__array__(py)
Expand Down
7 changes: 6 additions & 1 deletion pyo3-arrow/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt::Display;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::compute::concat_batches;
use arrow::compute::{concat_batches, take_record_batch};
use arrow_array::{Array, ArrayRef, RecordBatch, StructArray};
use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
use indexmap::IndexMap;
Expand Down Expand Up @@ -386,6 +386,11 @@ impl PyRecordBatch {
PyRecordBatch::new(self.0.slice(offset, length)).to_arro3(py)
}

/// Select rows of this record batch at the positions given by `indices`.
///
/// The selection is delegated to `take_record_batch`, and the resulting
/// batch is wrapped and exported through `to_arro3`.
///
/// Errors from the take kernel or from the export are propagated.
fn take(&self, py: Python, indices: PyArray) -> PyArrowResult<PyObject> {
    let selected = take_record_batch(self.as_ref(), indices.as_ref())?;
    let wrapped = PyRecordBatch::new(selected);
    Ok(wrapped.to_arro3(py)?)
}

pub fn to_struct_array(&self, py: Python) -> PyArrowResult<PyObject> {
let struct_array: StructArray = self.0.clone().into();
let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
Expand Down

0 comments on commit c48f6dd

Please sign in to comment.