diff --git a/graphique/core.py b/graphique/core.py index eef2762..e304c6a 100644 --- a/graphique/core.py +++ b/graphique/core.py @@ -429,7 +429,7 @@ def group( if isinstance(self, pa.Table): self = ds.dataset(self) use_threads = not ordered and Agg.ordered.isdisjoint(funcs) - return Declaration.group(self, *names, **aggs).to_table(use_threads) + return Nodes.group(self, *names, **aggs).to_table(use_threads) def aggregate(self, counts: str = '', **funcs: Sequence[Agg]) -> dict: """Return aggregated scalars as a row of data.""" @@ -523,7 +523,7 @@ def min_max(self, *names: str, **options) -> Batch: def rank(self, k: int, *names: str) -> Self: """Return table filtered by values within dense rank, similar to `select_k_unstable`.""" keys = dict(map(sort_key, names)) - table = Declaration.group(self, *keys).to_table() + table = Nodes.group(self, *keys).to_table() table = table.take(pc.select_k_unstable(table, k, keys.items())) exprs = [] for key, order in keys.items(): @@ -625,8 +625,11 @@ def size(self) -> str: return f'{size:n} {prefix}B' -class Declaration(ac.Declaration): - """[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration.""" +class Nodes(ac.Declaration): + """[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration. + + Provides a `Scanner` interface with no "oneshot" limitation. + """ option_map = { 'table_source': ac.TableSourceNodeOptions, @@ -642,7 +645,7 @@ def __init__(self, name, *args, inputs=None, **options): super().__init__(name, self.option_map[name](*args, **options), inputs) @classmethod - def scan(cls, dataset: ds.Dataset, columns=None) -> Self: + def scan(cls, dataset: ds.Dataset, columns: Optional[Iterable] = None) -> Self: """Return source node from a dataset.""" self = cls('scan', dataset, columns=columns) expr = dataset._scan_options.get('filter') @@ -650,24 +653,53 @@ def scan(cls, dataset: ds.Dataset, columns=None) -> Self: self = self.apply('filter', expr) return self if columns is None else self.project(columns) - def project(self, columns: Iterable) -> Self: + @property + def schema(self) -> pa.Schema: + """projected schema""" + with self.to_reader() as reader: + return reader.schema + + def scanner(self, **options) -> ds.Scanner: + return ds.Scanner.from_batches(self.to_reader(**options)) + + def count_rows(self) -> int: + """Count matching rows.""" + return self.scanner().count_rows() + + def head(self, num_rows: int, **options) -> pa.Table: + """Load the first N rows.""" + return self.scanner(**options).head(num_rows) + + def take(self, indices: Iterable[int], **options) -> pa.Table: + """Select rows by index.""" + return self.scanner(**options).take(indices) + + def apply(self, name: str, *args, **options) -> Self: + """Add a node by name.""" + return type(self)(name, *args, inputs=[self], **options) + + def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self: + """Add `project` node from columns names with optional expressions.""" if isinstance(columns, Mapping): return self.apply('project', columns.values(), columns) return self.apply('project', map(pc.field, columns)) - @classmethod - def group(cls, dataset: ds.Dataset, *names, **aggs: tuple) -> Self: - """Return aggregate node from a projected dataset.""" + def group(self, *names, **aggs: tuple) -> Self: + """Add `aggregate` node with dictionary support. + + Also supports datasets because aggregation determines the projection. + """ aggregates, targets = [], set(names) for name, (target, _, _) in aggs.items(): aggregates.append(aggs[name] + (name,)) targets.update([target] if isinstance(target, str) else target) columns = {name: pc.field(name) for name in targets} for name in columns: - field = dataset.schema.field(name) + field = self.schema.field(name) if pa.types.is_dictionary(field.type): columns[name] = columns[name].cast(field.type.value_type) - return cls.scan(dataset, columns).apply('aggregate', aggregates, names) - - def apply(self, name: str, *args, **options) -> Self: - return type(self)(name, *args, inputs=[self], **options) + if isinstance(self, ds.Dataset): + self = Nodes.scan(self, columns) + else: + self = self.project(columns) + return self.apply('aggregate', aggregates, names) diff --git a/tests/test_core.py b/tests/test_core.py index 1317e10..540b78e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,7 +2,7 @@ import pyarrow.compute as pc import pyarrow.dataset as ds import pytest -from graphique.core import Agg, Declaration, ListChunk, Column as C, Table as T +from graphique.core import Agg, ListChunk, Nodes, Column as C, Table as T from graphique.scalars import parse_duration, duration_isoformat @@ -90,16 +90,20 @@ def test_membership(): assert C.index(array, 1, start=2) == -1 -def test_declaration(table): +def test_nodes(table): dataset = ds.dataset(table).filter(pc.field('state') == 'CA') - assert Declaration.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA'] - (column,) = Declaration.scan(dataset, columns={'_': pc.field('state')}).to_table() + assert Nodes.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA'] + (column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table() assert column.unique().to_pylist() == ['CA'] - table = Declaration.group(dataset, 'county', 'city', counts=Agg.count_all).to_table() + table = Nodes.group(dataset, 'county', 'city', counts=Agg.count_all).to_table() assert len(table) == 1241 assert pc.sum(table['counts']).as_py() == 2647 - table = Declaration.scan(dataset, columns=['state']).to_table() - assert table.schema.names == ['state'] + scanner = Nodes.scan(dataset, columns=['state']) + assert scanner.schema.names == ['state'] + assert scanner.group('state').to_table() == pa.table({'state': ['CA']}) + assert scanner.count_rows() == 2647 + assert scanner.head(3) == pa.table({'state': ['CA'] * 3}) + assert scanner.take([0, 2]) == pa.table({'state': ['CA'] * 2}) def test_group(table):