diff --git a/graphique/core.py b/graphique/core.py index 99817cb..8fa2537 100644 --- a/graphique/core.py +++ b/graphique/core.py @@ -677,6 +677,8 @@ def apply(self, name: str, *args, **options) -> Self: """Add a node by name.""" return type(self)(name, *args, inputs=[self], **options) + filter = functools.partialmethod(apply, 'filter') + def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self: """Add `project` node from columns names with optional expressions.""" if isinstance(columns, Mapping): diff --git a/graphique/interface.py b/graphique/interface.py index a0c7955..73bdb00 100644 --- a/graphique/interface.py +++ b/graphique/interface.py @@ -8,7 +8,7 @@ import collections import inspect import itertools -from collections.abc import Callable, Iterable, Iterator, Mapping +from collections.abc import Callable, Iterable, Iterator, Mapping, Sized from datetime import timedelta from typing import Annotated, Optional, Union, no_type_check import pyarrow as pa @@ -18,14 +18,14 @@ from strawberry import Info from strawberry.extensions.utils import get_path_from_info from typing_extensions import Self -from .core import Batch, Column as C, ListChunk, Table as T +from .core import Batch, Column as C, ListChunk, Nodes, Table as T from .inputs import CountAggregate, Cumulative, Diff, Expression, Field, Filter from .inputs import HashAggregates, ListFunction, Pairwise, Projection, Rank from .inputs import ScalarAggregate, TDigestAggregate, VarianceAggregate, links, provisional from .models import Column, doc_field, selections from .scalars import Long -Source = Union[ds.Dataset, ds.Scanner, pa.Table] +Source = Union[ds.Dataset, Nodes, ds.Scanner, pa.Table] def references(field) -> Iterator: @@ -71,23 +71,43 @@ def references(self, info: Info, level: int = 0) -> set: fields = itertools.chain(*[field.selections for field in fields]) return set(itertools.chain(*map(references, fields))) & set(self.schema().names) + def select(self, info: Info) -> Source: + """Return source with only the columns necessary to proceed.""" + names = list(self.references(info)) + if len(names) >= len(self.schema().names): + return self.source + if isinstance(self.source, ds.Dataset): + return Nodes.scan(self.source, names) + if isinstance(self.source, Nodes): + return self.source.project(names) + if isinstance(self.source, ds.Scanner): + schema = self.source.projected_schema + return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names) + return self.source.select(names) + def scanner(self, info: Info, **options) -> ds.Scanner: """Return scanner with only the columns necessary to proceed.""" options.setdefault('columns', list(self.references(info))) - dataset = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source - if isinstance(dataset, ds.Dataset): - return dataset.scanner(**options) - options['schema'] = dataset.projected_schema - return ds.Scanner.from_batches(dataset.to_batches(), **options) - - def select(self, info: Info, length: Optional[int] = None) -> pa.Table: + source = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source + if isinstance(source, ds.Dataset): + return source.scanner(**options) + if isinstance(source, Nodes): + if 'filter' in options: # pragma: no branch + source = source.filter(options['filter']) + if 'columns' in options: # pragma: no branch + source = source.project(options['columns']) + return source.scanner() + options['schema'] = source.projected_schema + return ds.Scanner.from_batches(source.to_batches(), **options) + + def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table: """Return table with only the rows and columns necessary to proceed.""" - if isinstance(self.source, pa.Table): - return self.source.select(self.references(info)) - scanner = self.scanner(info) + source = self.select(info) + if isinstance(source, pa.Table): + return source if length is None: - return self.add_metric(info, scanner.to_table(), mode='read') - return self.add_metric(info, scanner.head(length), mode='head') + return self.add_metric(info, source.to_table(), mode='read') + return self.add_metric(info, source.head(length), mode='head') @classmethod @no_type_check @@ -99,12 +119,12 @@ def resolve_reference(cls, info: Info, **keys) -> Self: def columns(self, info: Info) -> dict: """fields for each column""" - table = self.select(info) + table = self.to_table(info) return {name: Column.cast(table[name]) for name in table.schema.names} def row(self, info: Info, index: int = 0) -> dict: """Return scalar values at index.""" - table = self.select(info, index + 1 if index >= 0 else None) + table = self.to_table(info, index + 1 if index >= 0 else None) row = {} for name in table.schema.names: scalar = table[name][index] @@ -169,11 +189,6 @@ def optional(self) -> Optional[Self]: """ return self - @staticmethod - def add_context(info: Info, key: str, **data): # pragma: no cover - """Add data to context with path info.""" - info.context.setdefault(key, []).append(dict(data, path=get_path_from_info(info))) - @staticmethod def add_metric(info: Info, table: pa.Table, **data): """Add memory usage and other metrics to context with path info.""" @@ -184,7 +199,7 @@ def add_metric(info: Info, table: pa.Table, **data): @doc_field def length(self) -> Long: """number of rows""" - return len(self.source) if hasattr(self.source, '__len__') else self.source.count_rows() + return len(self.source) if isinstance(self.source, Sized) else self.source.count_rows() @doc_field def any(self, info: Info, length: Long = 1) -> bool: @@ -192,7 +207,7 @@ def any(self, info: Info, length: Long = 1) -> bool: May be significantly faster than `length` for out-of-core data. """ - table = self.select(info, length) + table = self.to_table(info, length) return len(table) >= length @doc_field @@ -225,7 +240,7 @@ def slice( self, info: Info, offset: Long = 0, length: Optional[Long] = None, reverse: bool = False ) -> Self: """Return zero-copy slice of table.""" - table = self.select(info, length and (offset + length if offset >= 0 else None)) + table = self.to_table(info, length and (offset + length if offset >= 0 else None)) table = table[offset:][:length] # `slice` bug: ARROW-15412 return type(self)(table[::-1] if reverse else table) @@ -252,7 +267,7 @@ def group( refs = {agg.name for values in aggs.values() for agg in values} fragments = set(T.fragment_keys(self.source)) if isinstance(source, ds.Scanner): - source = self.select(info) + source = self.to_table(info) if fragments and set(by) <= fragments: if set(by) == fragments: return type(self)(self.fragments(info, counts, aggregate)) @@ -310,7 +325,7 @@ def runs( Differs from `group` by relying on adjacency, and is typically faster. Other columns are transformed into list columns. See `column` and `tables` to further access lists. """ - table = self.select(info) + table = self.to_table(info) predicates = {} for diff in map(dict, split): name = diff.pop('name') @@ -339,7 +354,7 @@ def sort( """ kwargs = dict(length=length, null_placement=null_placement) if isinstance(self.source, pa.Table) or length is None: - table = self.select(info) + table = self.to_table(info) else: expr, by = T.rank_keys(self.source, length, *by, dense=False) scanner = self.scanner(info, filter=expr) @@ -360,7 +375,7 @@ def rank(self, info: Info, by: list[str], max: int = 1) -> Self: if not by: return type(self)(source) if not isinstance(source, ds.Dataset): - source = self.select(info) + source = self.to_table(info) return type(self)(T.rank(source, max, *by)) @staticmethod @@ -459,7 +474,7 @@ def aggregate( variance: doc_argument(list[VarianceAggregate], func=pc.variance) = [], ) -> Self: """Return table with scalar aggregate functions applied to list columns.""" - table = self.select(info) + table = self.to_table(info) columns = T.columns(table) agg_fields: dict = collections.defaultdict(dict) keys: tuple = 'approximate_median', 'count', 'count_distinct', 'distinct', 'first', 'last' @@ -529,7 +544,7 @@ def join( ) -> Self: """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type.""" left, right = ( - root.source if isinstance(root.source, ds.Dataset) else root.select(info) + root.source if isinstance(root.source, ds.Dataset) else root.to_table(info) for root in (self, getattr(info.root_value, right)) ) table = left.join( @@ -548,14 +563,14 @@ def join( @doc_field def take(self, info: Info, indices: list[Long]) -> Self: """Select rows from indices.""" - table = self.scanner(info).take(indices) + table = self.select(info).take(indices) return type(self)(self.add_metric(info, table, mode='take')) @doc_field def drop_null(self, info: Info) -> Self: """Remove missing values from referenced columns in the table.""" if isinstance(self.source, pa.Table): - return type(self)(pc.drop_null(self.select(info))) + return type(self)(pc.drop_null(self.to_table(info))) scanner = self.scanner(info) batches = map(pc.drop_null, scanner.to_batches()) scanner = ds.Scanner.from_batches(batches, schema=scanner.projected_schema) diff --git a/graphique/service.py b/graphique/service.py index 94f4141..a6baee0 100644 --- a/graphique/service.py +++ b/graphique/service.py @@ -12,6 +12,7 @@ import pyarrow as pa import pyarrow.dataset as ds from starlette.config import Config +from graphique.core import Nodes from graphique.inputs import Expression from graphique import GraphQL @@ -32,7 +33,7 @@ if FILTERS is not None: root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow()) elif COLUMNS: - root = root.scanner(columns=COLUMNS) + root = Nodes.scan(root, columns=COLUMNS) if FEDERATED: app = GraphQL.federated({FEDERATED: root}, debug=DEBUG)