Skip to content

Commit

Permalink
Nodes with a Scanner interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
coady committed Sep 8, 2024
1 parent 663932f commit 9de5006
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 21 deletions.
60 changes: 46 additions & 14 deletions graphique/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def group(
if isinstance(self, pa.Table):
self = ds.dataset(self)
use_threads = not ordered and Agg.ordered.isdisjoint(funcs)
return Declaration.group(self, *names, **aggs).to_table(use_threads)
return Nodes.group(self, *names, **aggs).to_table(use_threads)

def aggregate(self, counts: str = '', **funcs: Sequence[Agg]) -> dict:
"""Return aggregated scalars as a row of data."""
Expand Down Expand Up @@ -523,7 +523,7 @@ def min_max(self, *names: str, **options) -> Batch:
def rank(self, k: int, *names: str) -> Self:
"""Return table filtered by values within dense rank, similar to `select_k_unstable`."""
keys = dict(map(sort_key, names))
table = Declaration.group(self, *keys).to_table()
table = Nodes.group(self, *keys).to_table()
table = table.take(pc.select_k_unstable(table, k, keys.items()))
exprs = []
for key, order in keys.items():
Expand Down Expand Up @@ -625,8 +625,11 @@ def size(self) -> str:
return f'{size:n} {prefix}B'


class Declaration(ac.Declaration):
"""[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration."""
class Nodes(ac.Declaration):
"""[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration.
Provides a `Scanner` interface with no "oneshot" limitation.
"""

option_map = {
'table_source': ac.TableSourceNodeOptions,
Expand All @@ -642,32 +645,61 @@ def __init__(self, name, *args, inputs=None, **options):
super().__init__(name, self.option_map[name](*args, **options), inputs)

@classmethod
def scan(cls, dataset: ds.Dataset, columns=None) -> Self:
def scan(cls, dataset: ds.Dataset, columns: Optional[Iterable] = None) -> Self:
"""Return source node from a dataset."""
self = cls('scan', dataset, columns=columns)
expr = dataset._scan_options.get('filter')
if expr is not None:
self = self.apply('filter', expr)
return self if columns is None else self.project(columns)

def project(self, columns: Iterable) -> Self:
@property
def schema(self) -> pa.Schema:
    """Return the projected schema, as reported by a reader over this node.

    The reader is opened only for the lookup and is closed when the
    `with` block exits.
    """
    with self.to_reader() as batches:
        result = batches.schema
    return result

def scanner(self, **options) -> ds.Scanner:
    """Return a `Scanner` built from a batch reader over this node.

    Keyword `options` are forwarded unchanged to `to_reader`.
    """
    reader = self.to_reader(**options)
    return ds.Scanner.from_batches(reader)

def count_rows(self) -> int:
    """Return the row count reported by a new scanner over this node."""
    scanner = self.scanner()
    return scanner.count_rows()

def head(self, num_rows: int, **options) -> pa.Table:
    """Return the first `num_rows` rows as a table.

    Keyword `options` are forwarded to `scanner`.
    """
    scanner = self.scanner(**options)
    return scanner.head(num_rows)

def take(self, indices: Iterable[int], **options) -> pa.Table:
    """Return the rows at the given `indices` as a table.

    Keyword `options` are forwarded to `scanner`.
    """
    scanner = self.scanner(**options)
    return scanner.take(indices)

def apply(self, name: str, *args, **options) -> Self:
    """Return a new node called `name` that takes this node as its only input.

    The result has the same type as `self`; `args` and `options` are passed
    through to that type's constructor.
    """
    node_type = type(self)
    return node_type(name, *args, inputs=[self], **options)

def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self:
    """Add a `project` node from column names, optionally with expressions.

    A mapping contributes its values as the first argument and the mapping
    itself as the second (presumably the output names; confirm against the
    `project` entry of `option_map`). Any other iterable is treated as
    column names, each wrapped with `pc.field`.
    """
    if not isinstance(columns, Mapping):
        return self.apply('project', map(pc.field, columns))
    expressions = columns.values()
    return self.apply('project', expressions, columns)

@classmethod
def group(cls, dataset: ds.Dataset, *names, **aggs: tuple) -> Self:
"""Return aggregate node from a projected dataset."""
def group(self, *names, **aggs: tuple) -> Self:
"""Add `aggregate` node with dictionary support.
Also supports datasets because aggregation determines the projection.
"""
aggregates, targets = [], set(names)
for name, (target, _, _) in aggs.items():
aggregates.append(aggs[name] + (name,))
targets.update([target] if isinstance(target, str) else target)
columns = {name: pc.field(name) for name in targets}
for name in columns:
field = dataset.schema.field(name)
field = self.schema.field(name)
if pa.types.is_dictionary(field.type):
columns[name] = columns[name].cast(field.type.value_type)
return cls.scan(dataset, columns).apply('aggregate', aggregates, names)

def apply(self, name: str, *args, **options) -> Self:
return type(self)(name, *args, inputs=[self], **options)
if isinstance(self, ds.Dataset):
self = Nodes.scan(self, columns)
else:
self = self.project(columns)
return self.apply('aggregate', aggregates, names)
18 changes: 11 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pytest
from graphique.core import Agg, Declaration, ListChunk, Column as C, Table as T
from graphique.core import Agg, ListChunk, Nodes, Column as C, Table as T
from graphique.scalars import parse_duration, duration_isoformat


Expand Down Expand Up @@ -90,16 +90,20 @@ def test_membership():
assert C.index(array, 1, start=2) == -1


def test_declaration(table):
def test_nodes(table):
dataset = ds.dataset(table).filter(pc.field('state') == 'CA')
assert Declaration.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA']
(column,) = Declaration.scan(dataset, columns={'_': pc.field('state')}).to_table()
assert Nodes.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA']
(column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table()
assert column.unique().to_pylist() == ['CA']
table = Declaration.group(dataset, 'county', 'city', counts=Agg.count_all).to_table()
table = Nodes.group(dataset, 'county', 'city', counts=Agg.count_all).to_table()
assert len(table) == 1241
assert pc.sum(table['counts']).as_py() == 2647
table = Declaration.scan(dataset, columns=['state']).to_table()
assert table.schema.names == ['state']
scanner = Nodes.scan(dataset, columns=['state'])
assert scanner.schema.names == ['state']
assert scanner.group('state').to_table() == pa.table({'state': ['CA']})
assert scanner.count_rows() == 2647
assert scanner.head(3) == pa.table({'state': ['CA'] * 3})
assert scanner.take([0, 2]) == pa.table({'state': ['CA'] * 2})


def test_group(table):
Expand Down

0 comments on commit 9de5006

Please sign in to comment.