Generalize source data.

Supports schema selection of a dataset without scanning.

coady committed Sep 19, 2024
1 parent 5d81523 commit c2d6c51
Showing 4 changed files with 52 additions and 49 deletions.
7 changes: 3 additions & 4 deletions graphique/core.py
@@ -554,11 +554,10 @@ def get_fragments(self) -> Iterator[ds.Fragment]:
 
     def fragment_keys(self) -> list:
         """Filtered partitioned datasets may not have fragments."""
-        try:
+        with contextlib.suppress(AttributeError, ValueError):
             Table.get_fragments(self)
-        except (AttributeError, ValueError):
-            return []
-        return self.partitioning.schema.names
+            return self.partitioning.schema.names
+        return []
 
     def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple:
         """Return expression and unmatched fields for partitioned dataset which filters by rank.
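Here `contextlib.suppress` guards both the fragment scan and the partitioning lookup, so either failure falls through to the empty default. A minimal sketch of that control flow, using toy stand-ins rather than the real `Table` mixin:

    import contextlib

    def fragment_keys(source) -> list:
        with contextlib.suppress(AttributeError, ValueError):
            source.get_fragments()  # filtered partitioned datasets raise ValueError here
            return source.partitioning.schema.names
        return []

    class Partitioned:  # hypothetical stand-in, not a graphique class
        class partitioning:
            class schema:
                names = ['year', 'month']
        def get_fragments(self):
            return iter(())

    print(fragment_keys(Partitioned()))  # ['year', 'month']
    print(fragment_keys(object()))       # [] since object() has no get_fragments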
82 changes: 41 additions & 41 deletions graphique/interface.py
@@ -25,7 +25,7 @@
 from .models import Column, doc_field, selections
 from .scalars import Long
 
-Root = Union[ds.Dataset, ds.Scanner, pa.Table]
+Source = Union[ds.Dataset, ds.Scanner, pa.Table]
 
 
 def references(field) -> Iterator:
@@ -61,8 +61,8 @@ class Schema:
 
 @strawberry.interface(description="an arrow dataset, scanner, or table")
 class Dataset:
-    def __init__(self, table: Root):
-        self.table = table
+    def __init__(self, source: Source):
+        self.source = source
 
     def references(self, info: Info, level: int = 0) -> set:
         """Return set of every possible future column reference."""
@@ -74,16 +74,16 @@ def references(self, info: Info, level: int = 0) -> set:
     def scanner(self, info: Info, **options) -> ds.Scanner:
         """Return scanner with only the columns necessary to proceed."""
         options.setdefault('columns', list(self.references(info)))
-        dataset = ds.dataset(self.table) if isinstance(self.table, pa.Table) else self.table
+        dataset = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source
         if isinstance(dataset, ds.Dataset):
             return dataset.scanner(**options)
         options['schema'] = dataset.projected_schema
         return ds.Scanner.from_batches(dataset.to_batches(), **options)
 
     def select(self, info: Info, length: Optional[int] = None) -> pa.Table:
         """Return table with only the rows and columns necessary to proceed."""
-        if isinstance(self.table, pa.Table):
-            return self.table.select(self.references(info))
+        if isinstance(self.source, pa.Table):
+            return self.source.select(self.references(info))
         scanner = self.scanner(info)
         if length is None:
             return self.add_metric(info, scanner.to_table(), mode='read')
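A `Dataset` hands back a fresh scanner with pruned columns, while an existing `Scanner` is single-use, so narrowing it further means re-wrapping its batch stream under the projected schema, as `scanner` does above. A sketch of the distinction, assuming an in-memory table:

    import pyarrow as pa
    import pyarrow.dataset as ds

    dataset = ds.dataset(pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']}))
    scanner = dataset.scanner(columns=['x'])  # datasets project lazily

    # A scanner cannot be scanned twice; a second projection wraps its batches.
    rewrapped = ds.Scanner.from_batches(scanner.to_batches(), schema=scanner.projected_schema)
    print(rewrapped.to_table().column_names)  # ['x']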
@@ -118,41 +118,41 @@ def filter(self, info: Info, **queries: Filter) -> Self:
         See `scan(filter: ...)` for more advanced queries. Additional feature: sorted tables
         support binary search
         """
-        table = self.table
+        source = self.source
         prev = info.path.prev
-        search = isinstance(table, pa.Table) and (prev is None or prev.typename == 'Query')
+        search = isinstance(source, pa.Table) and (prev is None or prev.typename == 'Query')
         for name in self.schema().index if search else []:
-            assert not table[name].null_count, f"search requires non-null column: {name}"
+            assert not source[name].null_count, f"search requires non-null column: {name}"
             query = dict(queries.pop(name))
             if 'eq' in query:
-                table = T.is_in(table, name, *query['eq'])
+                source = T.is_in(source, name, *query['eq'])
             if 'ne' in query:
-                table = T.not_equal(table, name, query['ne'])
+                source = T.not_equal(source, name, query['ne'])
             lower, upper = query.get('gt'), query.get('lt')
             includes = {'include_lower': False, 'include_upper': False}
             if 'ge' in query and (lower is None or query['ge'] > lower):
                 lower, includes['include_lower'] = query['ge'], True
             if 'le' in query and (upper is None or query['le'] < upper):
                 upper, includes['include_upper'] = query['le'], True
             if {lower, upper} != {None}:
-                table = T.range(table, name, lower, upper, **includes)
+                source = T.range(source, name, lower, upper, **includes)
             if len(query.pop('eq', [])) != 1 or query:
                 break
-        self = type(self)(table)
+        self = type(self)(source)
         expr = Expression.from_query(**queries)
         return self if expr.to_arrow() is None else self.scan(info, filter=expr)
 
     @doc_field
     def type(self) -> str:
         """[arrow type](https://arrow.apache.org/docs/python/api/dataset.html#classes)"""
-        return type(self.table).__name__
+        return type(self.source).__name__
 
     @doc_field
     def schema(self) -> Schema:
         """dataset schema"""
-        table = self.table
-        schema = table.projected_schema if isinstance(table, ds.Scanner) else table.schema
-        partitioning = getattr(table, 'partitioning', None)
+        source = self.source
+        schema = source.projected_schema if isinstance(source, ds.Scanner) else source.schema
+        partitioning = getattr(source, 'partitioning', None)
         index = (schema.pandas_metadata or {}).get('index_columns', [])
         return Schema(
             names=schema.names,
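The bound folding above keeps whichever of `gt`/`ge` is the stricter lower bound and whichever of `lt`/`le` is the stricter upper bound, tracking inclusiveness separately. A standalone walk-through with sample values:

    query = {'gt': 3, 'ge': 5, 'lt': 9, 'le': 9}

    lower, upper = query.get('gt'), query.get('lt')
    includes = {'include_lower': False, 'include_upper': False}
    if 'ge' in query and (lower is None or query['ge'] > lower):
        lower, includes['include_lower'] = query['ge'], True  # ge=5 is stricter than gt=3
    if 'le' in query and (upper is None or query['le'] < upper):
        upper, includes['include_upper'] = query['le'], True  # le=9 is not stricter than lt=9

    print(lower, upper, includes)
    # 5 9 {'include_lower': True, 'include_upper': False}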
@@ -184,7 +184,7 @@ def add_metric(info: Info, table: pa.Table, **data):
     @doc_field
     def length(self) -> Long:
         """number of rows"""
-        return len(self.table) if hasattr(self.table, '__len__') else self.table.count_rows()
+        return len(self.source) if hasattr(self.source, '__len__') else self.source.count_rows()
 
     @doc_field
     def any(self, info: Info, length: Long = 1) -> bool:
@@ -198,7 +198,7 @@ def any(self, info: Info, length: Long = 1) -> bool:
     @doc_field
     def size(self) -> Optional[Long]:
         """buffer size in bytes; null if table is not loaded"""
-        return getattr(self.table, 'nbytes', None)
+        return getattr(self.source, 'nbytes', None)
 
     @doc_field(
         name="column name(s); multiple names access nested struct fields",
@@ -248,38 +248,38 @@ def group(
         See `column` for accessing any column which has changed type. See `tables` to split on any
         aggregated list columns.
         """
-        table, aggs = self.table, dict(aggregate)
+        source, aggs = self.source, dict(aggregate)
         refs = {agg.name for values in aggs.values() for agg in values}
-        fragments = set(T.fragment_keys(self.table))
-        if isinstance(table, ds.Scanner):
-            table = self.select(info)
+        fragments = set(T.fragment_keys(self.source))
+        if isinstance(source, ds.Scanner):
+            source = self.select(info)
         if fragments and set(by) <= fragments:
             if set(by) == fragments:
                 return type(self)(self.fragments(info, counts, aggregate))
             if fragments.isdisjoint(refs) and set(aggs) <= Field.associatives:
-                table = self.fragments(info, counts, aggregate)
+                source = self.fragments(info, counts, aggregate)
                 aggs.setdefault('sum', []).extend(Field(agg.alias) for agg in aggs.pop('count', []))
                 if counts:
                     aggs['sum'].append(Field(counts))
                     counts = ''
         for agg in itertools.chain(*aggs.values()):
             agg.name = agg.alias
-        loaded = isinstance(table, pa.Table)
-        table = T.group(table, *by, counts=counts, ordered=ordered, **aggs)
+        loaded = isinstance(source, pa.Table)
+        table = T.group(source, *by, counts=counts, ordered=ordered, **aggs)
         return type(self)(table if loaded else self.add_metric(info, table, mode='group'))
 
     def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {}) -> pa.Table:  # type: ignore
         """Return table from scanning fragments and grouping by partitions.
 
         Requires a partitioned dataset. Faster and less memory intensive than `group`.
         """
-        schema = self.table.partitioning.schema  # requires a Dataset
+        schema = self.source.partitioning.schema  # requires a Dataset
         aggs = dict(aggregate)
         names = self.references(info, level=1)
         names.update(agg.name for value in aggs.values() for agg in value)
         projection = {name: pc.field(name) for name in names - set(schema.names)}
         columns = collections.defaultdict(list)
-        for fragment in T.get_fragments(self.table):
+        for fragment in T.get_fragments(self.source):
             row = ds.get_partition_keys(fragment.partition_expression)
             if projection:
                 table = fragment.to_table(columns=projection)
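The `fragments` fast path relies on the group keys being encoded in each fragment's partition expression, so they can be recovered without reading any file data. A sketch against a hypothetical hive-style layout:

    import pyarrow.dataset as ds

    # hypothetical layout: data/year=2024/part-0.parquet, ...
    dataset = ds.dataset('data', partitioning='hive')
    for fragment in dataset.get_fragments():
        keys = ds.get_partition_keys(fragment.partition_expression)  # e.g. {'year': 2024}
        print(keys, fragment.count_rows())  # keys come from the path, not a scan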
@@ -338,10 +338,10 @@ def sort(
         Optimized for length == 1; matches min or max values.
         """
         kwargs = dict(length=length, null_placement=null_placement)
-        if isinstance(self.table, pa.Table) or length is None:
+        if isinstance(self.source, pa.Table) or length is None:
             table = self.select(info)
         else:
-            expr, by = T.rank_keys(self.table, length, *by, dense=False)
+            expr, by = T.rank_keys(self.source, length, *by, dense=False)
             scanner = self.scanner(info, filter=expr)
             if not by:
                 return type(self)(self.add_metric(info, scanner.head(length), mode='head'))
@@ -355,13 +355,13 @@
     )
     def rank(self, info: Info, by: list[str], max: int = 1) -> Self:
         """Return table selected by maximum dense rank."""
-        expr, by = T.rank_keys(self.table, max, *by)
-        table = self.table if expr is None else self.table.filter(expr)
+        expr, by = T.rank_keys(self.source, max, *by)
+        source = self.source if expr is None else self.source.filter(expr)
         if not by:
-            return type(self)(table)
-        if not isinstance(table, ds.Dataset):
-            table = self.select(info)
-        return type(self)(T.rank(table, max, *by))
+            return type(self)(source)
+        if not isinstance(source, ds.Dataset):
+            source = self.select(info)
+        return type(self)(T.rank(source, max, *by))
 
     @staticmethod
     def apply_list(table: Batch, list_: ListFunction) -> Batch:
@@ -500,10 +500,10 @@ def oneshot(cls, info: Info, scanner: ds.Scanner) -> Union[ds.Scanner, pa.Table]
     def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self:  # type: ignore
         """Select rows and project columns without memory usage."""
         expr = filter.to_arrow()
-        if expr is not None and not columns and isinstance(self.table, ds.Dataset):
-            return type(self)(self.table.filter(expr))
+        if expr is not None and not columns and isinstance(self.source, ds.Dataset):
+            return type(self)(self.source.filter(expr))
         scanner = self.scanner(info, filter=expr, columns=self.project(info, columns))
-        if isinstance(self.table, ds.Scanner):
+        if isinstance(self.source, ds.Scanner):
             scanner = self.oneshot(info, scanner)
         return type(self)(scanner)
 
@@ -529,7 +529,7 @@ def join(
     ) -> Self:
         """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type."""
         left, right = (
-            root.table if isinstance(root.table, ds.Dataset) else root.select(info)
+            root.source if isinstance(root.source, ds.Dataset) else root.select(info)
             for root in (self, getattr(info.root_value, right))
         )
         table = left.join(
@@ -554,7 +554,7 @@ def take(self, info: Info, indices: list[Long]) -> Self:
     @doc_field
     def drop_null(self, info: Info) -> Self:
         """Remove missing values from referenced columns in the table."""
-        if isinstance(self.table, pa.Table):
+        if isinstance(self.source, pa.Table):
             return type(self)(pc.drop_null(self.select(info)))
         scanner = self.scanner(info)
         batches = map(pc.drop_null, scanner.to_batches())
8 changes: 4 additions & 4 deletions graphique/middleware.py
@@ -13,7 +13,7 @@
 from strawberry.extensions import tracing
 from strawberry.utils.str_converters import to_camel_case
 from .inputs import Filter
-from .interface import Dataset, Root
+from .interface import Dataset, Source
 from .models import Column, doc_field
 from .scalars import Long, py_type, scalar_map
 
@@ -48,7 +48,7 @@ class GraphQL(strawberry.asgi.GraphQL):
 
     options = dict(types=Column.registry.values(), scalar_overrides=scalar_map)
 
-    def __init__(self, root: Root, debug: bool = False, **kwargs):
+    def __init__(self, root: Source, debug: bool = False, **kwargs):
         options: dict = dict(self.options, extensions=(MetricsExtension,) * bool(debug))
         if type(root).__name__ == 'Query':
             self.root_value = root
@@ -63,7 +63,7 @@ async def get_root_value(self, request):
         return self.root_value
 
     @classmethod
-    def federated(cls, roots: Mapping[str, Root], keys: Mapping[str, Iterable] = {}, **kwargs):
+    def federated(cls, roots: Mapping[str, Source], keys: Mapping[str, Iterable] = {}, **kwargs):
         """Construct GraphQL app with multiple federated datasets.
 
         Args:
@@ -77,7 +77,7 @@ def federated(cls, roots: Mapping[str, Root], keys: Mapping[str, Iterable] = {},
         return cls(strawberry.type(Query)(**root_values), **kwargs)
 
 
-def implemented(root: Root, name: str = '', keys: Iterable = ()):
+def implemented(root: Source, name: str = '', keys: Iterable = ()):
     """Return type which extends the Dataset interface with knowledge of the schema."""
     schema = root.projected_schema if isinstance(root, ds.Scanner) else root.schema
     types = {field.name: py_type(field.type) for field in schema}
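With `GraphQL.__init__` accepting any `Source`, the ASGI app can be built from a dataset, scanner, or table directly. A minimal sketch with a hypothetical data path:

    import pyarrow.dataset as ds
    from graphique.middleware import GraphQL

    app = GraphQL(ds.dataset('data', partitioning='hive'))  # serve with any ASGI server

    # Federated variant: one dataset per field on the root Query type.
    # app = GraphQL.federated({'events': ds.dataset('events')}, keys={'events': ['id']})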
4 changes: 4 additions & 0 deletions graphique/service.py
@@ -9,6 +9,7 @@
 
 import json
 from pathlib import Path
+import pyarrow as pa
 import pyarrow.dataset as ds
 from starlette.config import Config
 from graphique.inputs import Expression
@@ -25,6 +26,9 @@
 
 if isinstance(COLUMNS, dict):
     COLUMNS = {alias: ds.field(name) for alias, name in COLUMNS.items()}
+elif COLUMNS:
+    root = root.replace_schema(pa.schema(map(root.schema.field, COLUMNS), root.schema.metadata))
+    COLUMNS = None
 if FILTERS is not None:
     root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow())
 elif COLUMNS:
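The new `elif` branch carries the headline change: with list-valued COLUMNS and no FILTERS, the dataset's schema is narrowed via `replace_schema`, so column selection costs no scan. An equivalent standalone sketch, with a hypothetical path and column list:

    import pyarrow as pa
    import pyarrow.dataset as ds

    root = ds.dataset('data')  # hypothetical path
    columns = ['x', 'y']       # hypothetical selection

    fields = map(root.schema.field, columns)
    root = root.replace_schema(pa.schema(fields, root.schema.metadata))
    print(root.schema.names)  # ['x', 'y']; nothing has been read yet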
