Skip to content

Commit

Permalink
Support async iteration of RecordBatchStream (#975)
Browse files Browse the repository at this point in the history
* Support async iteration of RecordBatchStream

* use __anext__

* use await

* fix failing test

* Since we are raising an error instead of returning a None, we can update the type hint.

---------

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
  • Loading branch information
kylebarron and timsaucer authored Jan 9, 2025
1 parent 389164a commit 4b262be
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 19 deletions.
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "43.0.0", optional = true }
Expand All @@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
codegen-units = 1
codegen-units = 1
16 changes: 10 additions & 6 deletions python/datafusion/record_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,24 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None:
"""This constructor is typically not called by the end user."""
self.rbs = record_batch_stream

def next(self) -> RecordBatch | None:
def next(self) -> RecordBatch:
"""See :py:func:`__next__` for the iterator function."""
try:
next_batch = next(self)
except StopIteration:
return None
return next(self)

return next_batch
async def __anext__(self) -> RecordBatch:
"""Async iterator function."""
next_batch = await self.rbs.__anext__()
return RecordBatch(next_batch)

def __next__(self) -> RecordBatch:
"""Iterator function."""
next_batch = next(self.rbs)
return RecordBatch(next_batch)

def __aiter__(self) -> typing_extensions.Self:
"""Async iterator function."""
return self

def __iter__(self) -> typing_extensions.Self:
"""Iterator function."""
return self
4 changes: 2 additions & 2 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df):
batch = stream.next()
assert batch is not None
# there should be no more batches
batch = stream.next()
assert batch is None
with pytest.raises(StopIteration):
stream.next()


def test_repartition(df):
Expand Down
51 changes: 41 additions & 10 deletions src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::utils::wait_for_future;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
use tokio::sync::Mutex;

#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
pub struct PyRecordBatch {
Expand All @@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {

#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
stream: SendableRecordBatchStream,
stream: Arc<Mutex<SendableRecordBatchStream>>,
}

impl PyRecordBatchStream {
pub fn new(stream: SendableRecordBatchStream) -> Self {
Self { stream }
Self {
stream: Arc::new(Mutex::new(stream)),
}
}
}

#[pymethods]
impl PyRecordBatchStream {
fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
let result = self.stream.next();
match wait_for_future(py, result) {
None => Ok(None),
Some(Ok(b)) => Ok(Some(b.into())),
Some(Err(e)) => Err(e.into()),
}
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
let stream = self.stream.clone();
wait_for_future(py, next_stream(stream, true))
}

fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
self.next(py)
}

fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let stream = self.stream.clone();
pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
}

fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
}

async fn next_stream(
stream: Arc<Mutex<SendableRecordBatchStream>>,
sync: bool,
) -> PyResult<PyRecordBatch> {
let mut stream = stream.lock().await;
match stream.next().await {
Some(Ok(batch)) => Ok(batch.into()),
Some(Err(e)) => Err(e.into()),
None => {
// Depending on whether the iteration is sync or not, we raise either a
// StopIteration or a StopAsyncIteration
if sync {
Err(PyStopIteration::new_err("stream exhausted"))
} else {
Err(PyStopAsyncIteration::new_err("stream exhausted"))
}
}
}
}

0 comments on commit 4b262be

Please sign in to comment.