From 81bf8c833605020178ca5e92e1c2f0c87b69c9dd Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 29 Jul 2024 21:01:43 -0400 Subject: [PATCH] Implement cast method on Array and ChunkedArray (#79) --- arro3-core/python/arro3/core/_core.pyi | 13 +++++++++++++ pyo3-arrow/src/array.rs | 7 +++++++ pyo3-arrow/src/chunked.rs | 11 +++++++++++ 3 files changed, 31 insertions(+) diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index 14107e8..91a4a14 100644 --- a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -49,6 +49,13 @@ class Array: def to_numpy(self) -> NDArray: """Return a numpy copy of this array.""" + def cast(self, target_type: ArrowSchemaExportable) -> Array: + """Cast array values to another data type + + Args: + target_type: Type to cast array to. + """ + def slice(self, offset: int = 0, length: int | None = None) -> Array: """Compute zero-copy slice of this array. @@ -100,6 +107,12 @@ class ChunkedArray: @classmethod def from_arrow_pycapsule(cls, capsule) -> ChunkedArray: """Construct this object from a bare Arrow PyCapsule""" + def cast(self, target_type: ArrowSchemaExportable) -> ChunkedArray: + """Cast array values to another data type + + Args: + target_type: Type to cast array to. + """ def chunk(self, i: int) -> Array: ... @property def chunks(self) -> list[Array]: ... diff --git a/pyo3-arrow/src/array.rs b/pyo3-arrow/src/array.rs index 84faa8e..08b84d0 100644 --- a/pyo3-arrow/src/array.rs +++ b/pyo3-arrow/src/array.rs @@ -257,6 +257,13 @@ impl PyArray { Ok(Self::from_array_ref(arrow_array)) } + fn cast(&self, py: Python, target_type: PyDataType) -> PyArrowResult { + let target_type = target_type.into_inner(); + let new_array = arrow::compute::cast(self.as_ref(), &target_type)?; + let new_field = self.field.as_ref().clone().with_data_type(target_type); + Ok(PyArray::new(new_array, new_field.into()).to_arro3(py)?) + } + #[pyo3(signature = (offset=0, length=None))] pub fn slice(&self, py: Python, offset: usize, length: Option) -> PyResult { let length = length.unwrap_or_else(|| self.array.len() - offset); diff --git a/pyo3-arrow/src/chunked.rs b/pyo3-arrow/src/chunked.rs index 68f9e4e..c215d37 100644 --- a/pyo3-arrow/src/chunked.rs +++ b/pyo3-arrow/src/chunked.rs @@ -309,6 +309,17 @@ impl PyChunkedArray { Ok(PyChunkedArray::new(chunks, field)) } + fn cast(&self, py: Python, target_type: PyDataType) -> PyArrowResult { + let target_type = target_type.into_inner(); + let new_chunks = self + .chunks + .iter() + .map(|chunk| arrow::compute::cast(&chunk, &target_type)) + .collect::, ArrowError>>()?; + let new_field = self.field.as_ref().clone().with_data_type(target_type); + Ok(PyChunkedArray::new(new_chunks, new_field.into()).to_arro3(py)?) + } + pub fn chunk(&self, py: Python, i: usize) -> PyResult { let field = self.field().clone(); let array = self