From f2a30b1f7b72710aeaa0efeccffbfb3a58bab9f8 Mon Sep 17 00:00:00 2001
From: Kyle Barron
Date: Tue, 13 Aug 2024 19:45:33 -0400
Subject: [PATCH] Rechunk table and chunked array (#126)

---
 arro3-core/python/arro3/core/_core.pyi | 28 ++++++++
 pyo3-arrow/src/chunked.rs              | 36 +++++++---
 pyo3-arrow/src/table.rs                | 95 +++++++++++++++++++++++++-
 tests/core/test_array.py               | 23 +++++++
 tests/core/test_table.py               | 27 ++++++++
 5 files changed, 197 insertions(+), 12 deletions(-)

diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi
index 66d3f87..eb8d9d4 100644
--- a/arro3-core/python/arro3/core/_core.pyi
+++ b/arro3-core/python/arro3/core/_core.pyi
@@ -223,6 +223,15 @@ class ChunkedArray:
     @property
     def num_chunks(self) -> int:
         """Number of underlying chunks."""
+    def rechunk(self, *, max_chunksize: int | None = None) -> ChunkedArray:
+        """Rechunk a ChunkedArray with a maximum number of rows per chunk.
+
+        Args:
+            max_chunksize: The maximum number of rows per internal array. Defaults to None, which rechunks into a single array.
+
+        Returns:
+            The rechunked ChunkedArray.
+        """
     def slice(self, offset: int = 0, length: int | None = None) -> ChunkedArray:
         """Compute zero-copy slice of this ChunkedArray
@@ -1394,6 +1403,15 @@ class Table:

         Due to the definition of a table, all columns have the same number of rows.
         """
+    def rechunk(self, *, max_chunksize: int | None = None) -> Table:
+        """Rechunk a table with a maximum number of rows per chunk.
+
+        Args:
+            max_chunksize: The maximum number of rows per internal RecordBatch. Defaults to None, which rechunks into a single batch.
+
+        Returns:
+            The rechunked table.
+        """
     def remove_column(self, i: int) -> Table:
         """Create new Table with the indicated column removed.
@@ -1450,6 +1468,16 @@ class Table:

         Returns:
             (number of rows, number of columns)
         """
+    def slice(self, offset: int = 0, length: int | None = None) -> Table:
+        """Compute a zero-copy slice of this table.
+
+        Args:
+            offset: The starting row index of the slice. Defaults to 0.
+            length: The number of rows in the slice. Defaults to None, which slices to the end of the table.
+
+        Returns:
+            The sliced table.
+        """
     def to_batches(self) -> list[RecordBatch]:
         """Convert Table to a list of RecordBatch objects.

diff --git a/pyo3-arrow/src/chunked.rs b/pyo3-arrow/src/chunked.rs
index e81d9b2..96235dd 100644
--- a/pyo3-arrow/src/chunked.rs
+++ b/pyo3-arrow/src/chunked.rs
@@ -113,8 +113,12 @@ impl PyChunkedArray {
         let chunks = chunk_lengths
             .iter()
             .map(|chunk_length| {
-                let sliced_chunks = self.slice(offset, *chunk_length)?;
-                let arr_refs = sliced_chunks.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
+                let sliced_chunked_array = self.slice(offset, *chunk_length)?;
+                let arr_refs = sliced_chunked_array
+                    .chunks
+                    .iter()
+                    .map(|a| a.as_ref())
+                    .collect::<Vec<_>>();
                 let sliced_concatted = concat(&arr_refs)?;
                 offset += chunk_length;
                 Ok(sliced_concatted)
             })
             .collect::<PyArrowResult<Vec<_>>>()?;

         Ok(PyChunkedArray::try_new(chunks, self.field.clone())?)
     }

-    pub(crate) fn slice(
-        &self,
-        mut offset: usize,
-        mut length: usize,
-    ) -> PyArrowResult<Vec<ArrayRef>> {
+    pub(crate) fn slice(&self, mut offset: usize, mut length: usize) -> PyArrowResult<Self> {
         if offset + length > self.length() {
             return Err(
                 PyValueError::new_err("offset + length may not exceed length of array").into(),
             );
@@ -148,7 +148,7 @@ impl PyChunkedArray {
                 continue;
             }

-            let take_count = length.min(chunk.len());
+            let take_count = length.min(chunk.len() - offset);
             let sliced_chunk = chunk.slice(offset, take_count);
             sliced_chunks.push(sliced_chunk);

@@ -162,7 +162,7 @@ impl PyChunkedArray {
             }
         }

-        Ok(sliced_chunks)
+        Ok(Self::try_new(sliced_chunks, self.field.clone())?)
     }

     /// Export this to a Python `arro3.core.ChunkedArray`.
@@ -376,6 +376,20 @@ impl PyChunkedArray {
         self.chunks.len()
     }

+    #[pyo3(signature = (*, max_chunksize=None))]
+    #[pyo3(name = "rechunk")]
+    fn rechunk_py(&self, py: Python, max_chunksize: Option<usize>) -> PyArrowResult<PyObject> {
+        let max_chunksize = max_chunksize.unwrap_or(self.len());
+        let mut chunk_lengths = vec![];
+        let mut offset = 0;
+        while offset < self.len() {
+            let chunk_length = max_chunksize.min(self.len() - offset);
+            offset += chunk_length;
+            chunk_lengths.push(chunk_length);
+        }
+        Ok(self.rechunk(chunk_lengths)?.to_arro3(py)?)
+    }
+
     #[pyo3(signature = (offset=0, length=None))]
     #[pyo3(name = "slice")]
     fn slice_py(
@@ -385,8 +399,8 @@ impl PyChunkedArray {
         length: Option<usize>,
     ) -> PyArrowResult<PyObject> {
         let length = length.unwrap_or_else(|| self.len() - offset);
-        let sliced_chunks = self.slice(offset, length)?;
-        Ok(PyChunkedArray::try_new(sliced_chunks, self.field.clone())?.to_arro3(py)?)
+        let sliced_chunked_array = self.slice(offset, length)?;
+        Ok(sliced_chunked_array.to_arro3(py)?)
     }

     fn to_numpy(&self, py: Python) -> PyResult<PyObject> {
diff --git a/pyo3-arrow/src/table.rs b/pyo3-arrow/src/table.rs
index 456d7ef..13b8ba2 100644
--- a/pyo3-arrow/src/table.rs
+++ b/pyo3-arrow/src/table.rs
@@ -82,6 +82,74 @@ impl PyTable {
             .call1(PyTuple::new_bound(py, vec![self.into_py(py)]))?;
         Ok(pyarrow_obj.to_object(py))
     }
+
+    pub(crate) fn rechunk(&self, chunk_lengths: Vec<usize>) -> PyArrowResult<Self> {
+        let total_chunk_length = chunk_lengths.iter().sum::<usize>();
+        if total_chunk_length != self.num_rows() {
+            return Err(
+                PyValueError::new_err("Chunk lengths do not add up to table length").into(),
+            );
+        }
+
+        // If the desired rechunking is the existing chunking, return early
+        let matches_existing_chunking = chunk_lengths
+            .iter()
+            .zip(self.batches())
+            .all(|(length, batch)| *length == batch.num_rows());
+        if matches_existing_chunking {
+            return Ok(Self::try_new(self.batches.clone(), self.schema.clone())?);
+        }
+
+        let mut offset = 0;
+        let batches = chunk_lengths
+            .iter()
+            .map(|chunk_length| {
+                let sliced_table = self.slice(offset, *chunk_length)?;
+                let sliced_concatted = concat_batches(&self.schema, sliced_table.batches.iter())?;
+                offset += chunk_length;
+                Ok(sliced_concatted)
+            })
+            .collect::<PyArrowResult<Vec<_>>>()?;
+
+        Ok(Self::try_new(batches, self.schema.clone())?)
+    }
+
+    pub(crate) fn slice(&self, mut offset: usize, mut length: usize) -> PyArrowResult<Self> {
+        if offset + length > self.num_rows() {
+            return Err(
+                PyValueError::new_err("offset + length may not exceed length of array").into(),
+            );
+        }
+
+        let mut sliced_batches: Vec<RecordBatch> = vec![];
+        for chunk in self.batches() {
+            if chunk.num_rows() == 0 {
+                continue;
+            }
+
+            // If the offset is greater than the len of this chunk, don't include any rows from
+            // this chunk
+            if offset >= chunk.num_rows() {
+                offset -= chunk.num_rows();
+                continue;
+            }
+
+            let take_count = length.min(chunk.num_rows() - offset);
+            let sliced_chunk = chunk.slice(offset, take_count);
+            sliced_batches.push(sliced_chunk);
+
+            length -= take_count;
+
+            // If we've selected all rows, exit
+            if length == 0 {
+                break;
+            } else {
+                offset = 0;
+            }
+        }
+
+        Ok(Self::try_new(sliced_batches, self.schema.clone())?)
+    }
 }

 impl Display for PyTable {
@@ -401,7 +469,19 @@ impl PyTable {
             .fold(0, |acc, batch| acc + batch.num_rows())
     }

-    // fn rechunk(&self, py: Python, max_chunksize: usize) {}
+    #[pyo3(signature = (*, max_chunksize=None))]
+    #[pyo3(name = "rechunk")]
+    fn rechunk_py(&self, py: Python, max_chunksize: Option<usize>) -> PyArrowResult<PyObject> {
+        let max_chunksize = max_chunksize.unwrap_or(self.num_rows());
+        let mut chunk_lengths = vec![];
+        let mut offset = 0;
+        while offset < self.num_rows() {
+            let chunk_length = max_chunksize.min(self.num_rows() - offset);
+            offset += chunk_length;
+            chunk_lengths.push(chunk_length);
+        }
+        Ok(self.rechunk(chunk_lengths)?.to_arro3(py)?)
+    }

     fn remove_column(&self, py: Python, i: usize) -> PyArrowResult<PyObject> {
         let mut fields = self.schema.fields().to_vec();
@@ -507,6 +587,19 @@ impl PyTable {
         (self.num_rows(), self.num_columns())
     }

+    #[pyo3(signature = (offset=0, length=None))]
+    #[pyo3(name = "slice")]
+    fn slice_py(
+        &self,
+        py: Python,
+        offset: usize,
+        length: Option<usize>,
+    ) -> PyArrowResult<PyObject> {
+        let length = length.unwrap_or_else(|| self.num_rows() - offset);
+        let sliced_table = self.slice(offset, length)?;
+        Ok(sliced_table.to_arro3(py)?)
+    }
+
     fn to_batches(&self, py: Python) -> PyResult<Vec<PyObject>> {
         self.batches
             .iter()
diff --git a/tests/core/test_array.py b/tests/core/test_array.py
index a89395e..f00a88c 100644
--- a/tests/core/test_array.py
+++ b/tests/core/test_array.py
@@ -3,6 +3,26 @@
 from arro3.core import Array, DataType, Table


+def test_constructor():
+    arr = Array([1, 2, 3], DataType.int16())
+    assert pa.array(arr) == pa.array([1, 2, 3], pa.int16())
+
+    arr = Array((1, 2, 3), DataType.int16())
+    assert pa.array(arr) == pa.array([1, 2, 3], pa.int16())
+
+    arr = Array([1, 2, 3], DataType.float64())
+    assert pa.array(arr) == pa.array([1, 2, 3], pa.float64())
+
+    arr = Array(["1", "2", "3"], DataType.string())
+    assert pa.array(arr) == pa.array(["1", "2", "3"], pa.string())
+
+    arr = Array([b"1", b"2", b"3"], DataType.binary())
+    assert pa.array(arr) == pa.array([b"1", b"2", b"3"], pa.binary())
+
+    # arr = Array([b"1", b"2", b"3"], DataType.binary(1))
+    # assert pa.array(arr) == pa.array([b"1", b"2", b"3"], pa.binary(1))
+
+
 def test_from_numpy():
     arr = np.array([1, 2, 3, 4], dtype=np.uint8)
     assert Array.from_numpy(arr).type == DataType.uint8()
@@ -10,6 +30,9 @@ def test_from_numpy():
     arr = np.array([1, 2, 3, 4], dtype=np.float64)
     assert Array.from_numpy(arr).type == DataType.float64()

+    # arr = np.array([b"1", b"2", b"3"], np.object_)
+    # Array.from_numpy(arr)
+

 def test_extension_array_meta_persists():
     arr = pa.array([1, 2, 3])
diff --git a/tests/core/test_table.py b/tests/core/test_table.py
index 8a7e7d2..5c430c4 100644
--- a/tests/core/test_table.py
+++ b/tests/core/test_table.py
@@ -55,3 +55,30 @@ def test_table_from_batches_empty_columns_with_len():
     no_columns = df[[]]
     pa_table = pa.Table.from_pandas(no_columns)
     _table = Table.from_batches(pa_table.to_batches())
+
+
+def test_rechunk():
+    a = pa.chunked_array([[1, 2, 3, 4]])
+    b = pa.chunked_array([["a", "b", "c", "d"]])
+    table = Table.from_pydict({"a": a, "b": b})
+
+    rechunked1 = table.rechunk(max_chunksize=1)
+    assert rechunked1.chunk_lengths == [1, 1, 1, 1]
+
+    rechunked2 = rechunked1.rechunk(max_chunksize=2)
+    assert rechunked2.chunk_lengths == [2, 2]
+    assert rechunked2.rechunk().chunk_lengths == [4]
+
+
+def test_slice():
+    a = pa.chunked_array([[1, 2], [3, 4]])
+    b = pa.chunked_array([["a", "b"], ["c", "d"]])
+    table = Table.from_pydict({"a": a, "b": b})
+
+    sliced1 = table.slice(0, 1)
+    assert sliced1.num_rows == 1
+    assert sliced1.chunk_lengths == [1]
+
+    sliced2 = table.slice(1, 2)
+    assert sliced2.num_rows == 2
+    assert sliced2.chunk_lengths == [1, 1]
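
A minimal usage sketch of the new `rechunk` and `slice` Python APIs, mirroring tests/core/test_table.py above. It assumes pyarrow is installed and imported as `pa`; every method and property used here appears in this patch or in its tests (`Table.from_pydict`, `chunk_lengths`, `num_rows`).

    import pyarrow as pa

    from arro3.core import Table

    # Build a table with a single 4-row chunk per column.
    a = pa.chunked_array([[1, 2, 3, 4]])
    b = pa.chunked_array([["a", "b", "c", "d"]])
    table = Table.from_pydict({"a": a, "b": b})

    # Split into record batches of at most 2 rows each.
    rechunked = table.rechunk(max_chunksize=2)
    assert rechunked.chunk_lengths == [2, 2]

    # With no max_chunksize, everything is concatenated into a single batch.
    assert rechunked.rechunk().chunk_lengths == [4]

    # Zero-copy slice of 2 rows starting at row 1.
    sliced = table.slice(1, 2)
    assert sliced.num_rows == 2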