feat: Configurable retention on PyDict sources (#744)
This renames the `PyList` source to `PyDict`.

It also allows in-memory retention to be disabled for any source, but only
exposes the option on the `PyDict` source, which is often used with
materializations.
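
As a usage sketch, a minimal example following the updated `examples/event-api/server.py` in this change; the row contents below are illustrative only:

```python
import time

import kaskada as kd

kd.init_session()

# A source that only feeds materializations: disable in-memory retention so
# rows are discarded once they have been sent to any running materializations.
events = kd.sources.PyDict(
    rows=[{"ts": time.time(), "user": "user_1", "request_id": "12345678-1234-5678-1234-567812345678"}],
    time_column="ts",
    key_column="user",
    time_unit="s",
    retained=False,
)

# The default (retained=True) keeps rows in memory for interactive queries.
```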
bjchambers authored Sep 8, 2023
1 parent 2d261b5 commit d0b0f85
Showing 26 changed files with 147 additions and 58 deletions.
69 changes: 48 additions & 21 deletions crates/sparrow-merge/src/in_memory_batches.rs
@@ -20,21 +20,53 @@ impl error_stack::Context for Error {}
/// Struct for managing in-memory batches.
#[derive(Debug)]
pub struct InMemoryBatches {
pub schema: SchemaRef,
current: RwLock<(usize, RecordBatch)>,
retained: bool,
current: RwLock<Current>,
updates: tokio::sync::broadcast::Sender<(usize, RecordBatch)>,
/// A subscriber that is never used -- it exists only to keep the sender
/// alive.
_subscriber: tokio::sync::broadcast::Receiver<(usize, RecordBatch)>,
}

impl InMemoryBatches {
#[derive(Debug)]
struct Current {
schema: SchemaRef,
version: usize,
batch: RecordBatch,
}

impl Current {
pub fn new(schema: SchemaRef) -> Self {
let (updates, _subscriber) = tokio::sync::broadcast::channel(10);
let merged = RecordBatch::new_empty(schema.clone());
let batch = RecordBatch::new_empty(schema.clone());
Self {
schema,
current: RwLock::new((0, merged)),
version: 0,
batch,
}
}

pub fn add_batch(&mut self, batch: &RecordBatch) -> error_stack::Result<(), Error> {
if self.batch.num_rows() == 0 {
self.batch = batch.clone();
} else {
// This assumes that cloning the old batch is cheap.
// If it isn't, we could replace it with an empty batch (`std::mem::replace`),
// put it in an option, or allow `homogeneous_merge` to take `&RecordBatch`.
self.batch = homogeneous_merge(&self.schema, vec![self.batch.clone(), batch.clone()])
.into_report()
.change_context(Error::Add)?;
}
Ok(())
}
}

impl InMemoryBatches {
pub fn new(retained: bool, schema: SchemaRef) -> Self {
let (updates, _subscriber) = tokio::sync::broadcast::channel(10);
let current = RwLock::new(Current::new(schema.clone()));
Self {
retained,
current,
updates,
_subscriber,
}
@@ -50,19 +50,11 @@ impl InMemoryBatches {

let new_version = {
let mut write = self.current.write().map_err(|_| Error::Add)?;
let (version, old) = &*write;
let version = *version;

let merged = if old.num_rows() == 0 {
batch.clone()
} else {
homogeneous_merge(&self.schema, vec![old.clone(), batch.clone()])
.into_report()
.change_context(Error::Add)?
};

*write = (version + 1, merged);
version + 1
if self.retained {
write.add_batch(&batch)?;
}
write.version += 1;
write.version
};

self.updates
@@ -79,7 +103,10 @@ impl InMemoryBatches {
pub fn subscribe(
&self,
) -> impl Stream<Item = error_stack::Result<RecordBatch, Error>> + 'static {
let (mut version, merged) = self.current.read().unwrap().clone();
let (mut version, merged) = {
let read = self.current.read().unwrap();
(read.version, read.batch.clone())
};
let mut recv = self.updates.subscribe();

async_stream::try_stream! {
@@ -111,6 +138,6 @@ impl InMemoryBatches {

/// Retrieve the current in-memory batch.
pub fn current(&self) -> RecordBatch {
self.current.read().unwrap().1.clone()
self.current.read().unwrap().batch.clone()
}
}
16 changes: 10 additions & 6 deletions crates/sparrow-runtime/src/execute/operation/scan.rs
@@ -194,12 +194,16 @@ impl ScanOperation {
.boxed()
} else {
let batch = in_memory.current();
futures::stream::once(async move {
Batch::try_new_from_batch(batch)
.into_report()
.change_context(Error::internal_msg("invalid input"))
})
.boxed()
if batch.num_rows() != 0 {
futures::stream::once(async move {
Batch::try_new_from_batch(batch)
.into_report()
.change_context(Error::internal_msg("invalid input"))
})
.boxed()
} else {
futures::stream::empty().boxed()
}
};
return Ok(Box::new(Self {
projected_schema,
21 changes: 19 additions & 2 deletions crates/sparrow-session/src/session.rs
@@ -87,6 +87,7 @@ impl Session {
name: &str,
schema: SchemaRef,
time_column_name: &str,
retained: bool,
subsort_column_name: Option<&str>,
key_column_name: &str,
grouping_name: Option<&str>,
@@ -144,7 +145,14 @@
})
.clone();

Table::new(table_info, key_hash_inverse, key_column, expr, time_unit)
Table::new(
table_info,
key_hash_inverse,
key_column,
expr,
retained,
time_unit,
)
}

pub fn add_cast(
@@ -575,7 +583,16 @@ mod tests {
Field::new("b", DataType::Int64, true),
]));
let table = session
.add_table("table", schema, "time", None, "key", Some("user"), None)
.add_table(
"table",
schema,
"time",
true,
None,
"key",
Some("user"),
None,
)
.unwrap();

let field_name = session
3 changes: 2 additions & 1 deletion crates/sparrow-session/src/table.rs
@@ -26,6 +26,7 @@ impl Table {
key_hash_inverse: Arc<ThreadSafeKeyHashInverse>,
key_column: usize,
expr: Expr,
retained: bool,
time_unit: Option<&str>,
) -> error_stack::Result<Self, Error> {
let prepared_fields: Fields = KEY_FIELDS
@@ -37,7 +38,7 @@
let prepare_hash = 0;

assert!(table_info.in_memory.is_none());
let in_memory_batches = Arc::new(InMemoryBatches::new(prepared_schema.clone()));
let in_memory_batches = Arc::new(InMemoryBatches::new(retained, prepared_schema.clone()));
table_info.in_memory = Some(in_memory_batches.clone());

let preparer = Preparer::new(
17 changes: 9 additions & 8 deletions examples/event-api/server.py
@@ -6,16 +6,17 @@

async def main():
kd.init_session()

start = time.time()
requestmap = dict()

# Initialize event source with historical data
events = kd.sources.PyList(
# Initialize event source with schema from historical data.
events = kd.sources.PyDict(
rows = [{"ts": start, "user": "user_1", "request_id": "12345678-1234-5678-1234-567812345678"}],
time_column = "ts",
key_column = "user",
time_unit = "s"
time_unit = "s",
retained=False,
)

# Compute features over events
@@ -32,11 +33,11 @@ async def main():
async def handle_http(req: web.Request) -> web.Response:
data = await req.json()

# Add the current time to the event
data["ts"] = time.time()

# Create a future so the aggregated result can be returned in the API response
request_id = str(uuid.uuid4())
requestmap[request_id] = asyncio.Future()
data["request_id"] = request_id

@@ -59,7 +60,7 @@ async def handle_http(req: web.Request) -> web.Response:
await runner.setup()
site = web.TCPSite(runner, 'localhost', 8080)
await site.start()


# Handle each conversation as it occurs
print(f"Waiting for events...")
@@ -80,7 +81,7 @@ async def handle_http(req: web.Request) -> web.Response:
fut.set_result(row["response"])

except Exception as e:
print(f"Failed to handle live event from Kaskada: {e}")
print(f"Failed to handle live event from Kaskada: {e}")

# Wait for web server to terminate gracefully
await runner.cleanup()
1 change: 0 additions & 1 deletion python/docs/source/conf.py
@@ -103,7 +103,6 @@
# TODO: Version switcher.
# This would require hosting multiple versions of the docs.
# https://pydata-sphinx-theme.readthedocs.io/en/stable/user_guide/version-dropdown.html

}

templates_path = ["_templates"]
2 changes: 1 addition & 1 deletion python/docs/source/index.md
@@ -59,7 +59,7 @@ import kaskada as kd
kd.init_session()

# Bootstrap from historical data
messages = kd.sources.PyList(
messages = kd.sources.PyDict(
rows = pyarrow.parquet.read_table("./messages.parquet")
.to_pylist(),
time_column = "ts",
2 changes: 1 addition & 1 deletion python/docs/source/reference/sources.md
@@ -12,5 +12,5 @@
JsonlString
Pandas
Parquet
PyList
PyDict
```
1 change: 1 addition & 0 deletions python/pysrc/kaskada/_ffi.pyi
@@ -53,6 +53,7 @@ class Table(Expr):
time_column: str,
key_column: str,
schema: pa.Schema,
retained: bool,
subsort_column: Optional[str],
grouping_name: Optional[str],
time_unit: Optional[str],
4 changes: 3 additions & 1 deletion python/pysrc/kaskada/_timestream.py
@@ -89,7 +89,9 @@ def _literal(value: LiteralValue, session: _ffi.Session) -> Timestream:
seconds = int(us / 1_000_000)
# Get the leftover nanoseconds
nanoseconds = int((us % 1_000_000) * 1_000)
return Timestream(_ffi.Expr.literal_timedelta(session, seconds, nanoseconds))
return Timestream(
_ffi.Expr.literal_timedelta(session, seconds, nanoseconds)
)
else:
return Timestream(_ffi.Expr.literal(session, value))

4 changes: 2 additions & 2 deletions python/pysrc/kaskada/sources/__init__.py
@@ -1,6 +1,6 @@
"""Sources of data for Kaskada queries."""
from .arrow import CsvString, JsonlString, Pandas, Parquet, PyList
from .arrow import CsvString, JsonlString, Pandas, Parquet, PyDict
from .source import Source


__all__ = ["Source", "CsvString", "Pandas", "JsonlString", "PyList", "Parquet"]
__all__ = ["Source", "CsvString", "Pandas", "JsonlString", "PyDict", "Parquet"]
9 changes: 8 additions & 1 deletion python/pysrc/kaskada/sources/arrow.py
@@ -61,7 +61,7 @@ def add_data(self, data: pd.DataFrame) -> None:
self._ffi_table.add_pyarrow(batch)


class PyList(Source):
class PyDict(Source):
"""Source reading data from lists of dicts."""

def __init__(
@@ -70,6 +70,7 @@ def __init__(
*,
time_column: str,
key_column: str,
retained: bool = True,
subsort_column: Optional[str] = None,
schema: Optional[pa.Schema] = None,
grouping_name: Optional[str] = None,
@@ -81,6 +82,11 @@ def __init__(
rows: One or more rows represented as dicts.
time_column: The name of the column containing the time.
key_column: The name of the column containing the key.
retained: Whether added rows should be retained for queries.
If True, rows (both provided to the constructor and added later) will be retained
for interactive queries. If False, rows will be discarded after being sent to any
running materializations. Consider setting this to False when the source will only
be used for materialization to avoid unnecessary memory usage.
subsort_column: The name of the column containing the subsort.
If not provided, the subsort will be assigned by the system.
schema: The schema to use. If not provided, it will be inferred from the input.
@@ -93,6 +99,7 @@ def __init__(
if schema is None:
schema = pa.Table.from_pylist(rows).schema
super().__init__(
retained=retained,
schema=schema,
time_column=time_column,
key_column=key_column,
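A minimal sketch contrasting the two settings described in the `retained` docstring above; the column names are made up, and the commented `add_rows` call is hypothetical (the method for adding rows later is not shown in this diff):

```python
import kaskada as kd

kd.init_session()

rows = [{"time": "2021-01-01T00:00:00", "key": "A", "m": 1}]

# Default: rows (from the constructor and any added later) are kept in memory,
# so interactive queries can see them.
interactive = kd.sources.PyDict(rows, time_column="time", key_column="key")

# Materialization-only: rows are forwarded to running materializations and then
# discarded, avoiding unnecessary memory usage.
feed = kd.sources.PyDict(rows, time_column="time", key_column="key", retained=False)

# Hypothetical follow-up (not shown in this diff): rows added later would be
# queryable on `interactive` but not retained on `feed`.
# interactive.add_rows([{"time": "2021-01-01T01:00:00", "key": "A", "m": 2}])
```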
2 changes: 2 additions & 0 deletions python/pysrc/kaskada/sources/source.py
@@ -25,6 +25,7 @@ def __init__(
schema: pa.Schema,
time_column: str,
key_column: str,
retained: bool = True,
subsort_column: Optional[str] = None,
grouping_name: Optional[str] = None,
time_unit: Optional[TimeUnit] = None,
@@ -62,6 +63,7 @@ def fix_field(field: pa.Field) -> pa.Field:
time_column,
key_column,
schema,
retained,
subsort_column,
grouping_name,
time_unit,
4 changes: 2 additions & 2 deletions python/pytests/aggregation/sum_test.py
@@ -19,8 +19,8 @@ def source() -> kd.sources.CsvString:


@pytest.fixture(scope="module")
def source_spread_across_days() -> kd.sources.PyList:
return kd.sources.PyList(
def source_spread_across_days() -> kd.sources.PyDict:
return kd.sources.PyDict(
rows=[
{"time": "2021-01-01T00:00:00", "key": "A", "m": 1, "n": 2},
{"time": "2021-01-01T01:10:01", "key": "A", "m": 3, "n": 4},
8 changes: 6 additions & 2 deletions python/pytests/execution_test.py
@@ -107,12 +107,16 @@ def test_history(golden, source_int64) -> None:
golden.jsonl(query.to_pandas(kd.results.History()))
golden.jsonl(
query.to_pandas(
kd.results.History(since=datetime.fromisoformat("1996-12-19T16:39:59+00:00"))
kd.results.History(
since=datetime.fromisoformat("1996-12-19T16:39:59+00:00")
)
)
)
golden.jsonl(
query.to_pandas(
kd.results.History(until=datetime.fromisoformat("1996-12-20T12:00:00+00:00"))
kd.results.History(
until=datetime.fromisoformat("1996-12-20T12:00:00+00:00")
)
)
)
golden.jsonl(
2 changes: 1 addition & 1 deletion python/pytests/flatten_test.py
@@ -2,7 +2,7 @@


def test_flatten(golden) -> None:
source = kd.sources.PyList(
source = kd.sources.PyDict(
[
{"time": "1996-12-19T16:39:57", "user": "A", "m": [[5]]},
{"time": "1996-12-19T17:39:57", "user": "A", "m": []},