From 770d17b4436d33756f0009c56d65fb5e1994150c Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Fri, 8 Sep 2023 16:32:44 -0700 Subject: [PATCH] feat: Async back-pressure on adding to a source --- Cargo.lock | 11 ++ Cargo.toml | 1 + crates/sparrow-merge/Cargo.toml | 1 + crates/sparrow-merge/src/in_memory_batches.rs | 31 +++-- .../sparrow-runtime/src/key_hash_inverse.rs | 21 ---- .../sparrow-runtime/src/prepare/preparer.rs | 16 +-- crates/sparrow-session/src/table.rs | 6 +- python/Cargo.lock | 11 ++ python/pysrc/kaskada/_ffi.pyi | 3 +- python/pysrc/kaskada/sources/arrow.py | 115 +++++++++++++++--- python/pysrc/kaskada/sources/source.py | 2 +- python/pytests/csv_string_source_test.py | 9 +- python/src/table.rs | 32 +++-- 13 files changed, 179 insertions(+), 80 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d8eea26a..10d9cc725 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -421,6 +421,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "async-broadcast" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c48ccdbf6ca6b121e0f586cbc0e73ae440e56c67c30fa0873b4e110d9c26d2b" +dependencies = [ + "event-listener", + "futures-core", +] + [[package]] name = "async-channel" version = "1.9.0" @@ -4598,6 +4608,7 @@ dependencies = [ "arrow-ord", "arrow-schema", "arrow-select", + "async-broadcast", "async-stream", "bit-set", "derive_more", diff --git a/Cargo.toml b/Cargo.toml index 1a6abd684..758966c74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ arrow-ord = { version = "43.0.0" } arrow-schema = { version = "43.0.0", features = ["serde"] } arrow-select = { version = "43.0.0" } arrow-string = { version = "43.0.0" } +async-broadcast = "0.5.1" async-once-cell = "0.5.3" async-stream = "0.3.4" async-trait = "0.1.68" diff --git a/crates/sparrow-merge/Cargo.toml b/crates/sparrow-merge/Cargo.toml index 71ba75939..4af9ae59a 100644 --- a/crates/sparrow-merge/Cargo.toml +++ b/crates/sparrow-merge/Cargo.toml @@ -19,6 +19,7 @@ arrow-array.workspace = true arrow-csv = { workspace = true, optional = true } arrow-schema.workspace = true arrow-select.workspace = true +async-broadcast.workspace = true async-stream.workspace = true bit-set.workspace = true derive_more.workspace = true diff --git a/crates/sparrow-merge/src/in_memory_batches.rs b/crates/sparrow-merge/src/in_memory_batches.rs index d5207716f..19a43707b 100644 --- a/crates/sparrow-merge/src/in_memory_batches.rs +++ b/crates/sparrow-merge/src/in_memory_batches.rs @@ -22,10 +22,10 @@ impl error_stack::Context for Error {} pub struct InMemoryBatches { retained: bool, current: RwLock, - updates: tokio::sync::broadcast::Sender<(usize, RecordBatch)>, + sender: async_broadcast::Sender<(usize, RecordBatch)>, /// A subscriber that is never used -- it exists only to keep the sender /// alive. - _subscriber: tokio::sync::broadcast::Receiver<(usize, RecordBatch)>, + _receiver: async_broadcast::InactiveReceiver<(usize, RecordBatch)>, } #[derive(Debug)] @@ -62,20 +62,24 @@ impl Current { impl InMemoryBatches { pub fn new(retained: bool, schema: SchemaRef) -> Self { - let (updates, _subscriber) = tokio::sync::broadcast::channel(10); + let (mut sender, receiver) = async_broadcast::broadcast(10); + + // Don't wait for a receiver. If no-one receives, `send` will fail. 
+ sender.set_await_active(false); + let current = RwLock::new(Current::new(schema.clone())); Self { retained, current, - updates, - _subscriber, + sender, + _receiver: receiver.deactivate(), } } /// Add a batch, merging it into the in-memory version. /// /// Publishes the new batch to the subscribers. - pub fn add_batch(&self, batch: RecordBatch) -> error_stack::Result<(), Error> { + pub async fn add_batch(&self, batch: RecordBatch) -> error_stack::Result<(), Error> { if batch.num_rows() == 0 { return Ok(()); } @@ -89,10 +93,11 @@ impl InMemoryBatches { write.version }; - self.updates - .send((new_version, batch)) - .into_report() - .change_context(Error::Add)?; + let send_result = self.sender.broadcast((new_version, batch)).await; + if send_result.is_err() { + assert!(!self.sender.is_closed()); + tracing::info!("No-one subscribed for new batch"); + } Ok(()) } @@ -107,7 +112,7 @@ impl InMemoryBatches { let read = self.current.read().unwrap(); (read.version, read.batch.clone()) }; - let mut recv = self.updates.subscribe(); + let mut recv = self.sender.new_receiver(); async_stream::try_stream! { tracing::info!("Starting subscriber with version {version}"); @@ -124,11 +129,11 @@ impl InMemoryBatches { tracing::warn!("Ignoring old version {recv_version}"); } } - Err(tokio::sync::broadcast::error::RecvError::Closed) => { + Err(async_broadcast::RecvError::Closed) => { tracing::info!("Sender closed."); break; }, - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { + Err(async_broadcast::RecvError::Overflowed(_)) => { Err(Error::ReceiverLagged)?; } } diff --git a/crates/sparrow-runtime/src/key_hash_inverse.rs b/crates/sparrow-runtime/src/key_hash_inverse.rs index 055bacf59..69c3b4b71 100644 --- a/crates/sparrow-runtime/src/key_hash_inverse.rs +++ b/crates/sparrow-runtime/src/key_hash_inverse.rs @@ -321,27 +321,6 @@ impl ThreadSafeKeyHashInverse { } } - pub fn blocking_add( - &self, - keys: &dyn Array, - key_hashes: &UInt64Array, - ) -> error_stack::Result<(), Error> { - error_stack::ensure!( - keys.len() == key_hashes.len(), - Error::MismatchedLengths { - keys: keys.len(), - key_hashes: key_hashes.len() - } - ); - let has_new_keys = self.key_map.blocking_read().has_new_keys(key_hashes); - - if has_new_keys { - self.key_map.blocking_write().add(keys, key_hashes) - } else { - Ok(()) - } - } - /// Stores the KeyHashInverse to the compute store. /// /// This method is thread-safe and acquires the read-lock. diff --git a/crates/sparrow-runtime/src/prepare/preparer.rs b/crates/sparrow-runtime/src/prepare/preparer.rs index 40859a882..91b2841dc 100644 --- a/crates/sparrow-runtime/src/prepare/preparer.rs +++ b/crates/sparrow-runtime/src/prepare/preparer.rs @@ -1,3 +1,4 @@ +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use arrow::array::{ArrayRef, UInt64Array}; @@ -31,7 +32,7 @@ pub struct Preparer { prepared_schema: SchemaRef, time_column_name: String, subsort_column_name: Option, - next_subsort: u64, + next_subsort: AtomicU64, key_column_name: String, time_multiplier: Option, } @@ -51,7 +52,7 @@ impl Preparer { prepared_schema, time_column_name, subsort_column_name, - next_subsort: prepare_hash, + next_subsort: prepare_hash.into(), key_column_name, time_multiplier, }) @@ -66,10 +67,7 @@ impl Preparer { /// - This computes and adds the key columns. /// - This sorts the batch by time, subsort and key hash. /// - This adds or casts columns as needed. - /// - /// Self is mutated as necessary to ensure the `subsort` column is increasing, if - /// it is added. 
- pub fn prepare_batch(&mut self, batch: RecordBatch) -> error_stack::Result { + pub fn prepare_batch(&self, batch: RecordBatch) -> error_stack::Result { let time = get_required_column(&batch, &self.time_column_name)?; let time = cast_to_timestamp(time, self.time_multiplier)?; @@ -80,8 +78,10 @@ impl Preparer { .into_report() .change_context_lazy(|| Error::ConvertSubsort(subsort.data_type().clone()))? } else { - let subsort: UInt64Array = (self.next_subsort..).take(num_rows).collect(); - self.next_subsort += num_rows as u64; + let subsort_start = self + .next_subsort + .fetch_add(num_rows as u64, Ordering::SeqCst); + let subsort: UInt64Array = (subsort_start..).take(num_rows).collect(); Arc::new(subsort) }; diff --git a/crates/sparrow-session/src/table.rs b/crates/sparrow-session/src/table.rs index e39fbba65..cf0a7c2be 100644 --- a/crates/sparrow-session/src/table.rs +++ b/crates/sparrow-session/src/table.rs @@ -66,7 +66,7 @@ impl Table { self.preparer.schema() } - pub fn add_data(&mut self, batch: RecordBatch) -> error_stack::Result<(), Error> { + pub async fn add_data(&self, batch: RecordBatch) -> error_stack::Result<(), Error> { let prepared = self .preparer .prepare_batch(batch) @@ -75,11 +75,13 @@ impl Table { let key_hashes = prepared.column(2).as_primitive(); let keys = prepared.column(self.key_column); self.key_hash_inverse - .blocking_add(keys.as_ref(), key_hashes) + .add(keys.as_ref(), key_hashes) + .await .change_context(Error::Prepare)?; self.in_memory_batches .add_batch(prepared) + .await .change_context(Error::Prepare)?; Ok(()) } diff --git a/python/Cargo.lock b/python/Cargo.lock index 4434298d3..7ca902638 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -393,6 +393,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "async-broadcast" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c48ccdbf6ca6b121e0f586cbc0e73ae440e56c67c30fa0873b4e110d9c26d2b" +dependencies = [ + "event-listener", + "futures-core", +] + [[package]] name = "async-channel" version = "1.9.0" @@ -3850,6 +3860,7 @@ dependencies = [ "arrow-csv", "arrow-schema", "arrow-select", + "async-broadcast", "async-stream", "bit-set", "derive_more", diff --git a/python/pysrc/kaskada/_ffi.pyi b/python/pysrc/kaskada/_ffi.pyi index 7da90e133..13c09757c 100644 --- a/python/pysrc/kaskada/_ffi.pyi +++ b/python/pysrc/kaskada/_ffi.pyi @@ -45,7 +45,7 @@ class Expr: def execute(self, options: Optional[_ExecutionOptions] = None) -> Execution: ... def grouping(self) -> Optional[str]: ... -class Table(Expr): +class Table: def __init__( self, session: Session, @@ -61,6 +61,7 @@ class Table(Expr): @property def name(self) -> str: ... def add_pyarrow(self, data: pa.RecordBatch) -> None: ... + def expr(self) -> Expr: ... class Udf(object): def __init__(self, result_ty: str, result_fn: Callable[..., pa.Array]) -> None: ... diff --git a/python/pysrc/kaskada/sources/arrow.py b/python/pysrc/kaskada/sources/arrow.py index 36fb7a57b..6a012cf82 100644 --- a/python/pysrc/kaskada/sources/arrow.py +++ b/python/pysrc/kaskada/sources/arrow.py @@ -66,20 +66,18 @@ class PyDict(Source): def __init__( self, - rows: dict | list[dict], *, time_column: str, key_column: str, + schema: pa.Schema, retained: bool = True, subsort_column: Optional[str] = None, - schema: Optional[pa.Schema] = None, grouping_name: Optional[str] = None, time_unit: Optional[TimeUnit] = None, ) -> None: """Create a source reading from rows represented as dicts. Args: - rows: One or more rows represented as dicts. 
time_column: The name of the column containing the time. key_column: The name of the column containing the key. retained: Whether added rows should be retained for queries. @@ -96,8 +94,6 @@ def __init__( time_unit: The unit of the time column. One of `ns`, `us`, `ms`, or `s`. If not specified (and not specified in the data), nanosecond will be assumed. """ - if schema is None: - schema = pa.Table.from_pylist(rows).schema super().__init__( retained=retained, schema=schema, @@ -109,15 +105,61 @@ def __init__( ) self._convert_options = pyarrow.csv.ConvertOptions(column_types=schema) - self.add_rows(rows) - def add_rows(self, rows: dict | list[dict]) -> None: + @staticmethod + async def create( + *, + time_column: str, + key_column: str, + retained: bool = True, + rows: Optional[dict | list[dict]] = None, + subsort_column: Optional[str] = None, + schema: Optional[pa.Schema] = None, + grouping_name: Optional[str] = None, + time_unit: Optional[TimeUnit] = None) -> PyDict: + """Create a source reading from rows represented as dicts. + + Args: + time_column: The name of the column containing the time. + key_column: The name of the column containing the key. + retained: Whether added rows should be retained for queries. + If True, rows (both provided to the constructor and added later) will be retained + for interactive queries. If False, rows will be discarded after being sent to any + running materializations. Consider setting this to False when the source will only + be used for materialization to avoid unnecessary memory usage. + subsort_column: The name of the column containing the subsort. + If not provided, the subsort will be assigned by the system. + schema: The schema to use. If not provided, it will be inferred from the input. + grouping_name: The name of the group associated with each key. + This is used to ensure implicit joins are only performed between data grouped + by the same entity. + time_unit: The unit of the time column. One of `ns`, `us`, `ms`, or `s`. + If not specified (and not specified in the data), nanosecond will be assumed. + """ + if schema is None: + if rows is None: + raise ValueError("Must provide schema or rows") + schema = pa.Table.from_pylist(rows).schema + source = PyDict( + time_column=time_column, + key_column=key_column, + retained=retained, + subsort_column=subsort_column, + schema=schema, + grouping_name=grouping_name, + time_unit=time_unit, + ) + if rows: + await source.add_rows(rows) + return source + + async def add_rows(self, rows: dict | list[dict]) -> None: """Add data to the source.""" if isinstance(rows, dict): rows = [rows] table = pa.Table.from_pylist(rows, schema=self._schema) for batch in table.to_batches(): - self._ffi_table.add_pyarrow(batch) + await self._ffi_table.add_pyarrow(batch) # TODO: We should be able to go straight from CSV to PyArrow, but @@ -127,19 +169,17 @@ class CsvString(Source): def __init__( self, - csv_string: str | BytesIO, *, + schema: pa.Schema, time_column: str, key_column: str, subsort_column: Optional[str] = None, - schema: Optional[pa.Schema] = None, grouping_name: Optional[str] = None, time_unit: Optional[TimeUnit] = None, ) -> None: """Create a CSV String Source. Args: - csv_string: The CSV string to start from. time_column: The name of the column containing the time. key_column: The name of the column containing the key. subsort_column: The name of the column containing the subsort. @@ -151,11 +191,6 @@ def __init__( time_unit: The unit of the time column. One of `ns`, `us`, `ms`, or `s`. 
If not specified (and not specified in the data), nanosecond will be assumed. """ - if isinstance(csv_string, str): - csv_string = BytesIO(csv_string.encode("utf-8")) - if schema is None: - schema = pa.csv.read_csv(csv_string).schema - csv_string.seek(0) super().__init__( schema=schema, time_column=time_column, @@ -169,15 +204,57 @@ def __init__( column_types=schema, strings_can_be_null=True, ) - self.add_string(csv_string) - def add_string(self, csv_string: str | BytesIO) -> None: + @staticmethod + async def create( + csv_string: Optional[str | BytesIO] = None, + *, + schema: Optional[pa.Schema] = None, + time_column: str, + key_column: str, + subsort_column: Optional[str] = None, + grouping_name: Optional[str] = None, + time_unit: Optional[TimeUnit] = None, + ) -> CsvString: + """Create a CSV String Source with data. + + Args: + csv_string: The CSV string to start from. + time_column: The name of the column containing the time. + key_column: The name of the column containing the key. + subsort_column: The name of the column containing the subsort. + If not provided, the subsort will be assigned by the system. + schema: The schema to use. If not provided, it will be inferred from the input. + grouping_name: The name of the group associated with each key. + This is used to ensure implicit joins are only performed between data grouped + by the same entity. + time_unit: The unit of the time column. One of `ns`, `us`, `ms`, or `s`. + If not specified (and not specified in the data), nanosecond will be assumed. + """ + if isinstance(csv_string, str): + csv_string = BytesIO(csv_string.encode("utf-8")) + if schema is None: + if csv_string is None: + raise ValueError("Must provide schema or csv_string") + schema = pa.csv.read_csv(csv_string).schema + csv_string.seek(0) + source = CsvString( + schema=schema, + time_column=time_column, + key_column=key_column, + subsort_column=subsort_column, + grouping_name=grouping_name, + time_unit=time_unit) + await source.add_string(csv_string) + return source + + async def add_string(self, csv_string: str | BytesIO) -> None: """Add data to the source.""" if isinstance(csv_string, str): csv_string = BytesIO(csv_string.encode("utf-8")) content = pa.csv.read_csv(csv_string, convert_options=self._convert_options) for batch in content.to_batches(): - self._ffi_table.add_pyarrow(batch) + await self._ffi_table.add_pyarrow(batch) class JsonlString(Source): diff --git a/python/pysrc/kaskada/sources/source.py b/python/pysrc/kaskada/sources/source.py index 2ceddfb69..873e96772 100644 --- a/python/pysrc/kaskada/sources/source.py +++ b/python/pysrc/kaskada/sources/source.py @@ -68,7 +68,7 @@ def fix_field(field: pa.Field) -> pa.Field: grouping_name, time_unit, ) - super().__init__(ffi_table) + super().__init__(ffi_table.expr()) self._schema = schema self._ffi_table = ffi_table diff --git a/python/pytests/csv_string_source_test.py b/python/pytests/csv_string_source_test.py index 72677e38c..362fe52ef 100644 --- a/python/pytests/csv_string_source_test.py +++ b/python/pytests/csv_string_source_test.py @@ -1,7 +1,8 @@ import kaskada as kd +import pytest - -def test_read_csv(golden) -> None: +@pytest.mark.asyncio +async def test_read_csv(golden) -> None: content1 = "\n".join( [ "time,key,m,n", @@ -24,12 +25,12 @@ def test_read_csv(golden) -> None: "1996-12-19T17:40:02,A,,", ] ) - source = kd.sources.CsvString( + source = await kd.sources.CsvString.create( content1, time_column="time", key_column="key", ) golden.jsonl(source) - source.add_string(content2) + await 
source.add_string(content2) golden.jsonl(source) diff --git a/python/src/table.rs b/python/src/table.rs index 376973a49..4edfa1b62 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -10,11 +10,12 @@ use crate::error::Result; use crate::expr::Expr; use crate::session::Session; -#[pyclass(extends=Expr, subclass)] +#[pyclass] pub(crate) struct Table { #[pyo3(get)] name: String, - rust_table: RustTable, + rust_table: Arc, + session: Session, } #[pymethods] @@ -33,7 +34,7 @@ impl Table { subsort_column: Option<&str>, grouping_name: Option<&str>, time_unit: Option<&str>, - ) -> Result<(Self, Expr)> { + ) -> Result { let raw_schema = Arc::new(schema.0); let rust_table = session.rust_session()?.add_table( @@ -47,23 +48,32 @@ impl Table { time_unit, )?; - let rust_expr = rust_table.expr.clone(); - let table = Table { name, rust_table }; - let expr = Expr { rust_expr, session }; - Ok((table, expr)) + let table = Table { name, rust_table: Arc::new(rust_table), session }; + Ok(table) } - /// Add PyArrow data to the given table. + fn expr(&self) -> Expr { + let rust_expr = self.rust_table.expr.clone(); + Expr { rust_expr, session: self.session.clone() } + } + /// Add PyArrow data to the given table. /// /// TODO: Support other kinds of data: /// - pyarrow RecordBatchReader /// - Parquet file URLs /// - Python generators? /// TODO: Error handling - fn add_pyarrow(&mut self, data: &PyAny) -> Result<()> { + fn add_pyarrow<'py>(&self, data: &'py PyAny, py: Python<'py>) -> Result<&'py PyAny> { let data = RecordBatch::from_pyarrow(data)?; - self.rust_table.add_data(data)?; - Ok(()) + + let rust_table = self.rust_table.clone(); + Ok(pyo3_asyncio::tokio::future_into_py(py, async move { + let result = rust_table.add_data(data).await; + Python::with_gil(|py| { + result.unwrap(); + Ok(py.None()) + }) + })?) } }
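
Notes (appended, not part of the patch). The core of this change is swapping tokio's broadcast channel for `async-broadcast` so that `add_batch` can apply back-pressure: `broadcast(...).await` suspends when the bounded buffer is full and a receiver is active, `set_await_active(false)` makes the send fail instead of blocking when nobody is subscribed, and the deactivated receiver keeps the channel open without buffering. Below is a minimal, self-contained sketch of that pattern, assuming `async-broadcast = "0.5"` and `tokio`; it is illustrative only and not code from this repository.

```rust
// Sketch (not part of the patch): bounded broadcast with producer back-pressure.
#[tokio::main]
async fn main() {
    let (mut tx, rx) = async_broadcast::broadcast::<u64>(4);
    // If there are no active receivers, fail the send instead of waiting for one.
    tx.set_await_active(false);
    // Keep the channel open without consuming or buffering messages.
    let _inactive = rx.deactivate();

    // A subscriber created later still sees everything sent after this point.
    let mut sub = tx.new_receiver();
    let reader = tokio::spawn(async move {
        loop {
            match sub.recv().await {
                Ok(v) => println!("got {v}"),
                Err(async_broadcast::RecvError::Closed) => break,
                Err(async_broadcast::RecvError::Overflowed(n)) => {
                    eprintln!("receiver lagged by {n} messages");
                }
            }
        }
    });

    for v in 0..16u64 {
        // Suspends here (back-pressure) whenever the 4-slot buffer is full.
        if tx.broadcast(v).await.is_err() {
            // Only reachable when no receiver is active.
            println!("no active receivers for {v}");
        }
    }
    drop(tx); // closing the last sender ends the subscriber loop
    reader.await.unwrap();
}
```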
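
The `Preparer` change replaces a `&mut self` counter with an `AtomicU64` so `prepare_batch` can take `&self` (and therefore be called through a shared reference from the new async path) while still handing each batch a contiguous, strictly increasing block of subsort values. A minimal sketch of that idea, using a plain `Vec<u64>` instead of an Arrow `UInt64Array` to stay dependency-free:

```rust
// Sketch (not part of the patch): reserving subsort ranges from shared state.
use std::sync::atomic::{AtomicU64, Ordering};

struct Preparer {
    next_subsort: AtomicU64,
}

impl Preparer {
    fn new(start: u64) -> Self {
        Self { next_subsort: AtomicU64::new(start) }
    }

    /// Claim `num_rows` consecutive subsort values. Safe from `&self` on many
    /// threads because `fetch_add` reserves the whole range in one step.
    fn claim_subsorts(&self, num_rows: usize) -> Vec<u64> {
        let start = self.next_subsort.fetch_add(num_rows as u64, Ordering::SeqCst);
        (start..).take(num_rows).collect()
    }
}

fn main() {
    let p = Preparer::new(100);
    assert_eq!(p.claim_subsorts(3), vec![100, 101, 102]);
    assert_eq!(p.claim_subsorts(2), vec![103, 104]);
}
```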
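
On the Python side, `Table::add_pyarrow` now returns an awaitable rather than blocking, which is why the sources and tests switch to `async`/`await`. The bridge is `pyo3_asyncio::tokio::future_into_py`, which wraps a Rust future in a Python coroutine. A rough sketch of that shape, with hypothetical `MySource` / `Inner` types standing in for the real table (assumes `pyo3`, `pyo3-asyncio` with the `tokio-runtime` feature, and `tokio`):

```rust
// Sketch (not part of the patch): exposing an async Rust method to Python.
use std::sync::Arc;

use pyo3::prelude::*;

struct Inner;

impl Inner {
    async fn add(&self, value: u64) {
        // Stand-in for the real async work (e.g. an awaited bounded send).
        let _ = value;
    }
}

#[pyclass]
struct MySource {
    inner: Arc<Inner>,
}

#[pymethods]
impl MySource {
    #[new]
    fn new() -> Self {
        Self { inner: Arc::new(Inner) }
    }

    /// Returns a coroutine; Python callers must `await source.add(v)`.
    fn add<'py>(&self, py: Python<'py>, value: u64) -> PyResult<&'py PyAny> {
        let inner = self.inner.clone();
        pyo3_asyncio::tokio::future_into_py(py, async move {
            inner.add(value).await;
            // Resolve the awaitable with `None`, mirroring the patch.
            Python::with_gil(|py| Ok(py.None()))
        })
    }
}

#[pymodule]
fn my_source(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<MySource>()?;
    Ok(())
}
```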