Skip to content

Commit

Permalink
Rechunk table and chunked array (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron authored Aug 13, 2024
1 parent 16210f5 commit f2a30b1
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 12 deletions.
28 changes: 28 additions & 0 deletions arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ class ChunkedArray:
@property
def num_chunks(self) -> int:
"""Number of underlying chunks."""
def rechunk(self, *, max_chunksize: int | None = None) -> ChunkedArray:
"""Rechunk a ChunkedArray with a maximum number of rows per chunk.
Args:
max_chunksize: The maximum number of rows per internal array. Defaults to None, which rechunks into a single array.
Returns:
The rechunked ChunkedArray.
"""
def slice(self, offset: int = 0, length: int | None = None) -> ChunkedArray:
"""Compute zero-copy slice of this ChunkedArray
Expand Down Expand Up @@ -1394,6 +1403,15 @@ class Table:
Due to the definition of a table, all columns have the same number of rows.
"""
def rechunk(self, *, max_chunksize: int | None = None) -> Table:
"""Rechunk a table with a maximum number of rows per chunk.
Args:
max_chunksize: The maximum number of rows per internal RecordBatch. Defaults to None, which rechunks into a single batch.
Returns:
The rechunked table.
"""
def remove_column(self, i: int) -> Table:
"""Create new Table with the indicated column removed.
Expand Down Expand Up @@ -1450,6 +1468,16 @@ class Table:
Returns:
(number of rows, number of columns)
"""
def slice(self, offset: int = 0, length: int | None = None) -> Table:
"""Compute zero-copy slice of this table.
Args:
offset: Defaults to 0.
length: Defaults to None.
Returns:
The sliced table
"""
def to_batches(self) -> list[RecordBatch]:
"""Convert Table to a list of RecordBatch objects.
Expand Down
36 changes: 25 additions & 11 deletions pyo3-arrow/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,12 @@ impl PyChunkedArray {
let chunks = chunk_lengths
.iter()
.map(|chunk_length| {
let sliced_chunks = self.slice(offset, *chunk_length)?;
let arr_refs = sliced_chunks.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
let sliced_chunked_array = self.slice(offset, *chunk_length)?;
let arr_refs = sliced_chunked_array
.chunks
.iter()
.map(|a| a.as_ref())
.collect::<Vec<_>>();
let sliced_concatted = concat(&arr_refs)?;
offset += chunk_length;
Ok(sliced_concatted)
Expand All @@ -124,11 +128,7 @@ impl PyChunkedArray {
Ok(PyChunkedArray::try_new(chunks, self.field.clone())?)
}

pub(crate) fn slice(
&self,
mut offset: usize,
mut length: usize,
) -> PyArrowResult<Vec<ArrayRef>> {
pub(crate) fn slice(&self, mut offset: usize, mut length: usize) -> PyArrowResult<Self> {
if offset + length > self.length() {
return Err(
PyValueError::new_err("offset + length may not exceed length of array").into(),
Expand All @@ -148,7 +148,7 @@ impl PyChunkedArray {
continue;
}

let take_count = length.min(chunk.len());
let take_count = length.min(chunk.len() - offset);
let sliced_chunk = chunk.slice(offset, take_count);
sliced_chunks.push(sliced_chunk);

Expand All @@ -162,7 +162,7 @@ impl PyChunkedArray {
}
}

Ok(sliced_chunks)
Ok(Self::try_new(sliced_chunks, self.field.clone())?)
}

/// Export this to a Python `arro3.core.ChunkedArray`.
Expand Down Expand Up @@ -376,6 +376,20 @@ impl PyChunkedArray {
self.chunks.len()
}

#[pyo3(signature = (*, max_chunksize=None))]
#[pyo3(name = "rechunk")]
fn rechunk_py(&self, py: Python, max_chunksize: Option<usize>) -> PyArrowResult<PyObject> {
let max_chunksize = max_chunksize.unwrap_or(self.len());
let mut chunk_lengths = vec![];
let mut offset = 0;
while offset < self.len() {
let chunk_length = max_chunksize.min(self.len() - offset);
offset += chunk_length;
chunk_lengths.push(chunk_length);
}
Ok(self.rechunk(chunk_lengths)?.to_arro3(py)?)
}

#[pyo3(signature = (offset=0, length=None))]
#[pyo3(name = "slice")]
fn slice_py(
Expand All @@ -385,8 +399,8 @@ impl PyChunkedArray {
length: Option<usize>,
) -> PyArrowResult<PyObject> {
let length = length.unwrap_or_else(|| self.len() - offset);
let sliced_chunks = self.slice(offset, length)?;
Ok(PyChunkedArray::try_new(sliced_chunks, self.field.clone())?.to_arro3(py)?)
let sliced_chunked_array = self.slice(offset, length)?;
Ok(sliced_chunked_array.to_arro3(py)?)
}

fn to_numpy(&self, py: Python) -> PyResult<PyObject> {
Expand Down
95 changes: 94 additions & 1 deletion pyo3-arrow/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,74 @@ impl PyTable {
.call1(PyTuple::new_bound(py, vec![self.into_py(py)]))?;
Ok(pyarrow_obj.to_object(py))
}

pub(crate) fn rechunk(&self, chunk_lengths: Vec<usize>) -> PyArrowResult<Self> {
let total_chunk_length = chunk_lengths.iter().sum::<usize>();
if total_chunk_length != self.num_rows() {
return Err(
PyValueError::new_err("Chunk lengths do not add up to table length").into(),
);
}

// If the desired rechunking is the existing chunking, return early
let matches_existing_chunking = chunk_lengths
.iter()
.zip(self.batches())
.all(|(length, batch)| *length == batch.num_rows());
if matches_existing_chunking {
return Ok(Self::try_new(self.batches.clone(), self.schema.clone())?);
}

let mut offset = 0;
let batches = chunk_lengths
.iter()
.map(|chunk_length| {
let sliced_table = self.slice(offset, *chunk_length)?;
let sliced_concatted = concat_batches(&self.schema, sliced_table.batches.iter())?;
offset += chunk_length;
Ok(sliced_concatted)
})
.collect::<PyArrowResult<Vec<_>>>()?;

Ok(Self::try_new(batches, self.schema.clone())?)
}

pub(crate) fn slice(&self, mut offset: usize, mut length: usize) -> PyArrowResult<Self> {
if offset + length > self.num_rows() {
return Err(
PyValueError::new_err("offset + length may not exceed length of array").into(),
);
}

let mut sliced_batches: Vec<RecordBatch> = vec![];
for chunk in self.batches() {
if chunk.num_rows() == 0 {
continue;
}

// If the offset is greater than the len of this chunk, don't include any rows from
// this chunk
if offset >= chunk.num_rows() {
offset -= chunk.num_rows();
continue;
}

let take_count = length.min(chunk.num_rows() - offset);
let sliced_chunk = chunk.slice(offset, take_count);
sliced_batches.push(sliced_chunk);

length -= take_count;

// If we've selected all rows, exit
if length == 0 {
break;
} else {
offset = 0;
}
}

Ok(Self::try_new(sliced_batches, self.schema.clone())?)
}
}

impl Display for PyTable {
Expand Down Expand Up @@ -401,7 +469,19 @@ impl PyTable {
.fold(0, |acc, batch| acc + batch.num_rows())
}

// fn rechunk(&self, py: Python, max_chunksize: usize) {}
#[pyo3(signature = (*, max_chunksize=None))]
#[pyo3(name = "rechunk")]
fn rechunk_py(&self, py: Python, max_chunksize: Option<usize>) -> PyArrowResult<PyObject> {
let max_chunksize = max_chunksize.unwrap_or(self.num_rows());
let mut chunk_lengths = vec![];
let mut offset = 0;
while offset < self.num_rows() {
let chunk_length = max_chunksize.min(self.num_rows() - offset);
offset += chunk_length;
chunk_lengths.push(chunk_length);
}
Ok(self.rechunk(chunk_lengths)?.to_arro3(py)?)
}

fn remove_column(&self, py: Python, i: usize) -> PyArrowResult<PyObject> {
let mut fields = self.schema.fields().to_vec();
Expand Down Expand Up @@ -507,6 +587,19 @@ impl PyTable {
(self.num_rows(), self.num_columns())
}

#[pyo3(signature = (offset=0, length=None))]
#[pyo3(name = "slice")]
fn slice_py(
&self,
py: Python,
offset: usize,
length: Option<usize>,
) -> PyArrowResult<PyObject> {
let length = length.unwrap_or_else(|| self.num_rows() - offset);
let sliced_chunked_array = self.slice(offset, length)?;
Ok(sliced_chunked_array.to_arro3(py)?)
}

fn to_batches(&self, py: Python) -> PyResult<Vec<PyObject>> {
self.batches
.iter()
Expand Down
23 changes: 23 additions & 0 deletions tests/core/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,36 @@
from arro3.core import Array, DataType, Table


def test_constructor():
arr = Array([1, 2, 3], DataType.int16())
assert pa.array(arr) == pa.array([1, 2, 3], pa.int16())

arr = Array((1, 2, 3), DataType.int16())
assert pa.array(arr) == pa.array([1, 2, 3], pa.int16())

arr = Array([1, 2, 3], DataType.float64())
assert pa.array(arr) == pa.array([1, 2, 3], pa.float64())

arr = Array(["1", "2", "3"], DataType.string())
assert pa.array(arr) == pa.array(["1", "2", "3"], pa.string())

arr = Array([b"1", b"2", b"3"], DataType.binary())
assert pa.array(arr) == pa.array([b"1", b"2", b"3"], pa.binary())

# arr = Array([b"1", b"2", b"3"], DataType.binary(1))
# assert pa.array(arr) == pa.array([b"1", b"2", b"3"], pa.binary(1))


def test_from_numpy():
arr = np.array([1, 2, 3, 4], dtype=np.uint8)
assert Array.from_numpy(arr).type == DataType.uint8()

arr = np.array([1, 2, 3, 4], dtype=np.float64)
assert Array.from_numpy(arr).type == DataType.float64()

# arr = np.array([b"1", b"2", b"3"], np.object_)
# Array.from_numpy(arr)


def test_extension_array_meta_persists():
arr = pa.array([1, 2, 3])
Expand Down
27 changes: 27 additions & 0 deletions tests/core/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,30 @@ def test_table_from_batches_empty_columns_with_len():
no_columns = df[[]]
pa_table = pa.Table.from_pandas(no_columns)
_table = Table.from_batches(pa_table.to_batches())


def test_rechunk():
a = pa.chunked_array([[1, 2, 3, 4]])
b = pa.chunked_array([["a", "b", "c", "d"]])
table = Table.from_pydict({"a": a, "b": b})

rechunked1 = table.rechunk(max_chunksize=1)
assert rechunked1.chunk_lengths == [1, 1, 1, 1]

rechunked2 = rechunked1.rechunk(max_chunksize=2)
assert rechunked2.chunk_lengths == [2, 2]
assert rechunked2.rechunk().chunk_lengths == [4]


def test_slice():
a = pa.chunked_array([[1, 2], [3, 4]])
b = pa.chunked_array([["a", "b"], ["c", "d"]])
table = Table.from_pydict({"a": a, "b": b})

sliced1 = table.slice(0, 1)
assert sliced1.num_rows == 1
assert sliced1.chunk_lengths == [1]

sliced2 = table.slice(1, 2)
assert sliced2.num_rows == 2
assert sliced2.chunk_lengths == [1, 1]

0 comments on commit f2a30b1

Please sign in to comment.