From 5a3f440ce2f1f858d512907d8c9c67240a224cd3 Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Sat, 12 Oct 2024 13:58:39 -0700 Subject: [PATCH] Generalized scanning and projection. --- graphique/core.py | 35 +++++++++++++---------------------- graphique/interface.py | 22 +++++----------------- tests/test_core.py | 1 - tests/test_dataset.py | 6 ++---- 4 files changed, 20 insertions(+), 44 deletions(-) diff --git a/graphique/core.py b/graphique/core.py index b38efdf..b210eae 100644 --- a/graphique/core.py +++ b/graphique/core.py @@ -53,9 +53,6 @@ class Agg: 'tdigest': pc.TDigestOptions, 'variance': pc.VarianceOptions, } - - associatives = {'all', 'any', 'first', 'last', 'max', 'min', 'one', 'product', 'sum'} - associatives |= {'count'} # transformed to be associative ordered = {'first', 'last'} def __init__(self, name: str, alias: str = '', **options): @@ -597,14 +594,18 @@ class Nodes(ac.Declaration): def __init__(self, name, *args, inputs=None, **options): super().__init__(name, self.option_map[name](*args, **options), inputs) - @classmethod - def scan(cls, dataset: ds.Dataset, columns: Optional[Iterable] = None) -> Self: - """Return source node from a dataset.""" - self = cls('scan', dataset, columns=columns) - expr = dataset._scan_options.get('filter') - if expr is not None: - self = self.apply('filter', expr) - return self if columns is None else self.project(columns) + def scan(self, columns: Iterable[str]) -> Self: + """Return projected source node, supporting datasets and tables.""" + if isinstance(self, ds.Dataset): + expr = self._scan_options.get('filter') + self = Nodes('scan', self, columns=columns) + if expr is not None: + self = self.apply('filter', expr) + elif isinstance(self, pa.Table): + self = Nodes('table_source', self) + if isinstance(columns, Mapping): + return self.apply('project', columns.values(), columns) + return self.apply('project', map(pc.field, columns)) @property def schema(self) -> pa.Schema: @@ -633,12 +634,6 @@ def apply(self, name: str, *args, **options) -> Self: filter = functools.partialmethod(apply, 'filter') - def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self: - """Add `project` node from columns names with optional expressions.""" - if isinstance(columns, Mapping): - return self.apply('project', columns.values(), columns) - return self.apply('project', map(pc.field, columns)) - def group(self, *names, **aggs: tuple) -> Self: """Add `aggregate` node with dictionary support. @@ -653,8 +648,4 @@ def group(self, *names, **aggs: tuple) -> Self: field = self.schema.field(name) if pa.types.is_dictionary(field.type): columns[name] = columns[name].cast(field.type.value_type) - if isinstance(self, ds.Dataset): - self = Nodes.scan(self, columns) - else: - self = self.project(columns) - return self.apply('aggregate', aggregates, names) + return Nodes.scan(self, columns).apply('aggregate', aggregates, names) diff --git a/graphique/interface.py b/graphique/interface.py index ae2d974..78b708b 100644 --- a/graphique/interface.py +++ b/graphique/interface.py @@ -76,14 +76,12 @@ def select(self, info: Info) -> Source: names = list(self.references(info)) if len(names) >= len(self.schema().names): return self.source - if isinstance(self.source, ds.Dataset): - return Nodes.scan(self.source, names) - if isinstance(self.source, Nodes): - return self.source.project(names) if isinstance(self.source, ds.Scanner): schema = self.source.projected_schema return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names) - return self.source.select(names) + if isinstance(self.source, pa.Table): + return self.source.select(names) + return Nodes.scan(self.source, names) def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table: """Return table with only the rows and columns necessary to proceed.""" @@ -143,9 +141,7 @@ def filter(self, info: Info, **queries: Filter) -> Self: source = T.range(source, name, lower, upper, **includes) if len(query.pop('eq', [])) != 1 or query: break - self = type(self)(source) - expr = Expression.from_query(**queries) - return self if expr.to_arrow() is None else self.scan(info, filter=expr) + return type(self)(source).scan(info, filter=Expression.from_query(**queries)) @doc_field def type(self) -> str: @@ -258,8 +254,6 @@ def group( for agg in values: aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func)) source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source - if isinstance(source, pa.Table): - source = ds.dataset(source) source = Nodes.group(source, *by, **aggs) if ordered: source = self.add_metric(info, source.to_table(use_threads=False), mode='group') @@ -461,13 +455,7 @@ def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = scanner = ds.Scanner.from_batches(self.source.to_batches(), **options) return type(self)(self.add_metric(info, scanner.to_table(), mode='batch')) source = self.source if expr is None else self.source.filter(expr) - if isinstance(source, ds.Dataset): - return type(self)(Nodes.scan(source, projection) if columns else source) - if isinstance(source, pa.Table): - if not columns: - return type(self)(source.select(list(projection))) - source = Nodes('table_source', source) - return type(self)(source.project(projection)) + return type(self)(Nodes.scan(source, projection) if columns else source) @doc_field( right="name of right table; must be on root Query type", diff --git a/tests/test_core.py b/tests/test_core.py index 43d5f3e..5061910 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -85,7 +85,6 @@ def test_membership(): def test_nodes(table): dataset = ds.dataset(table).filter(pc.field('state') == 'CA') - assert Nodes.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA'] (column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table() assert column.unique().to_pylist() == ['CA'] table = Nodes.group(dataset, 'county', 'city', counts=([], 'hash_count_all', None)).to_table() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 672e0b3..3fb8776 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -252,11 +252,9 @@ def test_federation(fedclient): data = fedclient.execute( """{ _entities(representations: {__typename: "ZipcodesTable", zipcode: 90001}) { - ... on ZipcodesTable { length row { state } schema { names } } } }""" + ... on ZipcodesTable { length type row { state } } } }""" ) - assert data == { - '_entities': [{'length': 1, 'row': {'state': 'CA'}, 'schema': {'names': ['state']}}] - } + assert data == {'_entities': [{'length': 1, 'type': 'Nodes', 'row': {'state': 'CA'}}]} data = fedclient.execute("""{ states { filter(state: {eq: "CA"}) { columns { indices { takeFrom(field: "zipcodes") { __typename column(name: "state") { length } } } } } } }""") table = data['states']['filter']['columns']['indices']['takeFrom']