diff --git a/graphique/core.py b/graphique/core.py index e304c6a..99817cb 100644 --- a/graphique/core.py +++ b/graphique/core.py @@ -554,11 +554,10 @@ def get_fragments(self) -> Iterator[ds.Fragment]: def fragment_keys(self) -> list: """Filtered partitioned datasets may not have fragments.""" - try: + with contextlib.suppress(AttributeError, ValueError): Table.get_fragments(self) - except (AttributeError, ValueError): - return [] - return self.partitioning.schema.names + return self.partitioning.schema.names + return [] def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple: """Return expression and unmatched fields for partitioned dataset which filters by rank. diff --git a/graphique/interface.py b/graphique/interface.py index 3dc192b..a0c7955 100644 --- a/graphique/interface.py +++ b/graphique/interface.py @@ -25,7 +25,7 @@ from .models import Column, doc_field, selections from .scalars import Long -Root = Union[ds.Dataset, ds.Scanner, pa.Table] +Source = Union[ds.Dataset, ds.Scanner, pa.Table] def references(field) -> Iterator: @@ -61,8 +61,8 @@ class Schema: @strawberry.interface(description="an arrow dataset, scanner, or table") class Dataset: - def __init__(self, table: Root): - self.table = table + def __init__(self, source: Source): + self.source = source def references(self, info: Info, level: int = 0) -> set: """Return set of every possible future column reference.""" @@ -74,7 +74,7 @@ def references(self, info: Info, level: int = 0) -> set: def scanner(self, info: Info, **options) -> ds.Scanner: """Return scanner with only the columns necessary to proceed.""" options.setdefault('columns', list(self.references(info))) - dataset = ds.dataset(self.table) if isinstance(self.table, pa.Table) else self.table + dataset = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source if isinstance(dataset, ds.Dataset): return dataset.scanner(**options) options['schema'] = dataset.projected_schema @@ -82,8 +82,8 @@ def scanner(self, info: Info, **options) -> ds.Scanner: def select(self, info: Info, length: Optional[int] = None) -> pa.Table: """Return table with only the rows and columns necessary to proceed.""" - if isinstance(self.table, pa.Table): - return self.table.select(self.references(info)) + if isinstance(self.source, pa.Table): + return self.source.select(self.references(info)) scanner = self.scanner(info) if length is None: return self.add_metric(info, scanner.to_table(), mode='read') @@ -118,16 +118,16 @@ def filter(self, info: Info, **queries: Filter) -> Self: See `scan(filter: ...)` for more advanced queries. Additional feature: sorted tables support binary search """ - table = self.table + source = self.source prev = info.path.prev - search = isinstance(table, pa.Table) and (prev is None or prev.typename == 'Query') + search = isinstance(source, pa.Table) and (prev is None or prev.typename == 'Query') for name in self.schema().index if search else []: - assert not table[name].null_count, f"search requires non-null column: {name}" + assert not source[name].null_count, f"search requires non-null column: {name}" query = dict(queries.pop(name)) if 'eq' in query: - table = T.is_in(table, name, *query['eq']) + source = T.is_in(source, name, *query['eq']) if 'ne' in query: - table = T.not_equal(table, name, query['ne']) + source = T.not_equal(source, name, query['ne']) lower, upper = query.get('gt'), query.get('lt') includes = {'include_lower': False, 'include_upper': False} if 'ge' in query and (lower is None or query['ge'] > lower): @@ -135,24 +135,24 @@ def filter(self, info: Info, **queries: Filter) -> Self: if 'le' in query and (upper is None or query['le'] > upper): upper, includes['include_upper'] = query['le'], True if {lower, upper} != {None}: - table = T.range(table, name, lower, upper, **includes) + source = T.range(source, name, lower, upper, **includes) if len(query.pop('eq', [])) != 1 or query: break - self = type(self)(table) + self = type(self)(source) expr = Expression.from_query(**queries) return self if expr.to_arrow() is None else self.scan(info, filter=expr) @doc_field def type(self) -> str: """[arrow type](https://arrow.apache.org/docs/python/api/dataset.html#classes)""" - return type(self.table).__name__ + return type(self.source).__name__ @doc_field def schema(self) -> Schema: """dataset schema""" - table = self.table - schema = table.projected_schema if isinstance(table, ds.Scanner) else table.schema - partitioning = getattr(table, 'partitioning', None) + source = self.source + schema = source.projected_schema if isinstance(source, ds.Scanner) else source.schema + partitioning = getattr(source, 'partitioning', None) index = (schema.pandas_metadata or {}).get('index_columns', []) return Schema( names=schema.names, @@ -184,7 +184,7 @@ def add_metric(info: Info, table: pa.Table, **data): @doc_field def length(self) -> Long: """number of rows""" - return len(self.table) if hasattr(self.table, '__len__') else self.table.count_rows() + return len(self.source) if hasattr(self.source, '__len__') else self.source.count_rows() @doc_field def any(self, info: Info, length: Long = 1) -> bool: @@ -198,7 +198,7 @@ def any(self, info: Info, length: Long = 1) -> bool: @doc_field def size(self) -> Optional[Long]: """buffer size in bytes; null if table is not loaded""" - return getattr(self.table, 'nbytes', None) + return getattr(self.source, 'nbytes', None) @doc_field( name="column name(s); multiple names access nested struct fields", @@ -248,24 +248,24 @@ def group( See `column` for accessing any column which has changed type. See `tables` to split on any aggregated list columns. """ - table, aggs = self.table, dict(aggregate) + source, aggs = self.source, dict(aggregate) refs = {agg.name for values in aggs.values() for agg in values} - fragments = set(T.fragment_keys(self.table)) - if isinstance(table, ds.Scanner): - table = self.select(info) + fragments = set(T.fragment_keys(self.source)) + if isinstance(source, ds.Scanner): + source = self.select(info) if fragments and set(by) <= fragments: if set(by) == fragments: return type(self)(self.fragments(info, counts, aggregate)) if fragments.isdisjoint(refs) and set(aggs) <= Field.associatives: - table = self.fragments(info, counts, aggregate) + source = self.fragments(info, counts, aggregate) aggs.setdefault('sum', []).extend(Field(agg.alias) for agg in aggs.pop('count', [])) if counts: aggs['sum'].append(Field(counts)) counts = '' for agg in itertools.chain(*aggs.values()): agg.name = agg.alias - loaded = isinstance(table, pa.Table) - table = T.group(table, *by, counts=counts, ordered=ordered, **aggs) + loaded = isinstance(source, pa.Table) + table = T.group(source, *by, counts=counts, ordered=ordered, **aggs) return type(self)(table if loaded else self.add_metric(info, table, mode='group')) def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {}) -> pa.Table: # type: ignore @@ -273,13 +273,13 @@ def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {} Requires a partitioned dataset. Faster and less memory intensive than `group`. """ - schema = self.table.partitioning.schema # requires a Dataset + schema = self.source.partitioning.schema # requires a Dataset aggs = dict(aggregate) names = self.references(info, level=1) names.update(agg.name for value in aggs.values() for agg in value) projection = {name: pc.field(name) for name in names - set(schema.names)} columns = collections.defaultdict(list) - for fragment in T.get_fragments(self.table): + for fragment in T.get_fragments(self.source): row = ds.get_partition_keys(fragment.partition_expression) if projection: table = fragment.to_table(columns=projection) @@ -338,10 +338,10 @@ def sort( Optimized for length == 1; matches min or max values. """ kwargs = dict(length=length, null_placement=null_placement) - if isinstance(self.table, pa.Table) or length is None: + if isinstance(self.source, pa.Table) or length is None: table = self.select(info) else: - expr, by = T.rank_keys(self.table, length, *by, dense=False) + expr, by = T.rank_keys(self.source, length, *by, dense=False) scanner = self.scanner(info, filter=expr) if not by: return type(self)(self.add_metric(info, scanner.head(length), mode='head')) @@ -355,13 +355,13 @@ def sort( ) def rank(self, info: Info, by: list[str], max: int = 1) -> Self: """Return table selected by maximum dense rank.""" - expr, by = T.rank_keys(self.table, max, *by) - table = self.table if expr is None else self.table.filter(expr) + expr, by = T.rank_keys(self.source, max, *by) + source = self.source if expr is None else self.source.filter(expr) if not by: - return type(self)(table) - if not isinstance(table, ds.Dataset): - table = self.select(info) - return type(self)(T.rank(table, max, *by)) + return type(self)(source) + if not isinstance(source, ds.Dataset): + source = self.select(info) + return type(self)(T.rank(source, max, *by)) @staticmethod def apply_list(table: Batch, list_: ListFunction) -> Batch: @@ -500,10 +500,10 @@ def oneshot(cls, info: Info, scanner: ds.Scanner) -> Union[ds.Scanner, pa.Table] def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self: # type: ignore """Select rows and project columns without memory usage.""" expr = filter.to_arrow() - if expr is not None and not columns and isinstance(self.table, ds.Dataset): - return type(self)(self.table.filter(expr)) + if expr is not None and not columns and isinstance(self.source, ds.Dataset): + return type(self)(self.source.filter(expr)) scanner = self.scanner(info, filter=expr, columns=self.project(info, columns)) - if isinstance(self.table, ds.Scanner): + if isinstance(self.source, ds.Scanner): scanner = self.oneshot(info, scanner) return type(self)(scanner) @@ -529,7 +529,7 @@ def join( ) -> Self: """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type.""" left, right = ( - root.table if isinstance(root.table, ds.Dataset) else root.select(info) + root.source if isinstance(root.source, ds.Dataset) else root.select(info) for root in (self, getattr(info.root_value, right)) ) table = left.join( @@ -554,7 +554,7 @@ def take(self, info: Info, indices: list[Long]) -> Self: @doc_field def drop_null(self, info: Info) -> Self: """Remove missing values from referenced columns in the table.""" - if isinstance(self.table, pa.Table): + if isinstance(self.source, pa.Table): return type(self)(pc.drop_null(self.select(info))) scanner = self.scanner(info) batches = map(pc.drop_null, scanner.to_batches()) diff --git a/graphique/middleware.py b/graphique/middleware.py index b9523a6..970b9f3 100644 --- a/graphique/middleware.py +++ b/graphique/middleware.py @@ -13,7 +13,7 @@ from strawberry.extensions import tracing from strawberry.utils.str_converters import to_camel_case from .inputs import Filter -from .interface import Dataset, Root +from .interface import Dataset, Source from .models import Column, doc_field from .scalars import Long, py_type, scalar_map @@ -48,7 +48,7 @@ class GraphQL(strawberry.asgi.GraphQL): options = dict(types=Column.registry.values(), scalar_overrides=scalar_map) - def __init__(self, root: Root, debug: bool = False, **kwargs): + def __init__(self, root: Source, debug: bool = False, **kwargs): options: dict = dict(self.options, extensions=(MetricsExtension,) * bool(debug)) if type(root).__name__ == 'Query': self.root_value = root @@ -63,7 +63,7 @@ async def get_root_value(self, request): return self.root_value @classmethod - def federated(cls, roots: Mapping[str, Root], keys: Mapping[str, Iterable] = {}, **kwargs): + def federated(cls, roots: Mapping[str, Source], keys: Mapping[str, Iterable] = {}, **kwargs): """Construct GraphQL app with multiple federated datasets. Args: @@ -77,7 +77,7 @@ def federated(cls, roots: Mapping[str, Root], keys: Mapping[str, Iterable] = {}, return cls(strawberry.type(Query)(**root_values), **kwargs) -def implemented(root: Root, name: str = '', keys: Iterable = ()): +def implemented(root: Source, name: str = '', keys: Iterable = ()): """Return type which extends the Dataset interface with knowledge of the schema.""" schema = root.projected_schema if isinstance(root, ds.Scanner) else root.schema types = {field.name: py_type(field.type) for field in schema} diff --git a/graphique/service.py b/graphique/service.py index 3ac9f0b..94f4141 100644 --- a/graphique/service.py +++ b/graphique/service.py @@ -9,6 +9,7 @@ import json from pathlib import Path +import pyarrow as pa import pyarrow.dataset as ds from starlette.config import Config from graphique.inputs import Expression @@ -25,6 +26,9 @@ if isinstance(COLUMNS, dict): COLUMNS = {alias: ds.field(name) for alias, name in COLUMNS.items()} +elif COLUMNS: + root = root.replace_schema(pa.schema(map(root.schema.field, COLUMNS), root.schema.metadata)) + COLUMNS = None if FILTERS is not None: root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow()) elif COLUMNS: