Skip to content

Commit

Permalink
Minimize scanner usage.
Browse files Browse the repository at this point in the history
Supported as a source, but not used internally for scanning.
  • Loading branch information
coady committed Sep 29, 2024
1 parent 574c0b7 commit 70fb2b4
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 41 deletions.
59 changes: 26 additions & 33 deletions graphique/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,6 @@ def select(self, info: Info) -> Source:
return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
return self.source.select(names)

def scanner(self, info: Info, **options) -> ds.Scanner:
    """Return scanner with only the columns necessary to proceed."""
    # default to just the columns referenced by the query, to minimize reads
    options.setdefault('columns', list(self.references(info)))
    src = self.source
    if isinstance(src, pa.Table):
        src = ds.dataset(src)
    if isinstance(src, ds.Dataset):
        return src.scanner(**options)
    if not isinstance(src, Nodes):
        # source is already a scanner: re-wrap its batches with the requested options
        options['schema'] = src.projected_schema
        return ds.Scanner.from_batches(src.to_batches(), **options)
    # Nodes: apply filter before projection, then scan
    if 'filter' in options:  # pragma: no branch
        src = src.filter(options['filter'])
    if 'columns' in options:  # pragma: no branch
        src = src.project(options['columns'])
    return src.scanner()

def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
"""Return table with only the rows and columns necessary to proceed."""
source = self.select(info)
Expand Down Expand Up @@ -228,8 +213,12 @@ def column(
This is typically only needed for aliased or casted columns.
If the column is in the schema, `columns` can be used instead.
"""
expr = Expression(name=name, cast=cast, safe=safe).to_arrow() # type: ignore
return Column.cast(*self.scanner(info, columns={'': expr}).to_table())
if isinstance(self.source, pa.Table) and len(name) == 1:
column = self.source.column(*name)
return Column.cast(column.cast(cast, safe) if cast else column)
column = Projection(alias='_', name=name, cast=cast, safe=safe) # type: ignore
source = self.scan(info, Expression(), [column]).source
return Column.cast(*(source if isinstance(source, pa.Table) else source.to_table()))

@doc_field(
offset="number of rows to skip; negative value skips from the end",
Expand Down Expand Up @@ -357,10 +346,12 @@ def sort(
table = self.to_table(info)
else:
expr, by = T.rank_keys(self.source, length, *by, dense=False)
scanner = self.scanner(info, filter=expr)
if expr is not None:
self = type(self)(self.source.filter(expr))
source = self.select(info)
if not by:
return type(self)(self.add_metric(info, scanner.head(length), mode='head'))
table = T.map_batch(scanner, T.sort, *by, **kwargs)
return type(self)(self.add_metric(info, source.head(length), mode='head'))
table = T.map_batch(source, T.sort, *by, **kwargs)
self.add_metric(info, table, mode='batch')
return type(self)(T.sort(table, *by, **kwargs)) # type: ignore

Expand Down Expand Up @@ -492,24 +483,26 @@ def aggregate(

aggregate.deprecation_reason = ListFunction.deprecation

def project(self, info: Info, columns: list[Projection]) -> dict:
    """Return projected columns, including all references from below fields."""
    projection = {}
    # fields referenced one level below are projected through unchanged
    for name in self.references(info, level=1):
        projection[name] = pc.field(name)
    # explicit projections override references; alias wins over the dotted name
    for col in columns:
        projection[col.alias or '.'.join(col.name)] = col.to_arrow()
    if '' in projection:
        raise ValueError(f"projected columns need a name or alias: {projection['']}")
    return projection

@doc_field(filter="selected rows", columns="projected columns")
def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self: # type: ignore
"""Select rows and project columns without memory usage."""
expr = filter.to_arrow()
if expr is not None and not columns and isinstance(self.source, ds.Dataset):
return type(self)(self.source.filter(expr))
scanner = self.scanner(info, filter=expr, columns=self.project(info, columns))
projection = {name: pc.field(name) for name in self.references(info, level=1)}
projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
if '' in projection:
raise ValueError(f"projected columns need a name or alias: {projection['']}")
if isinstance(self.source, ds.Scanner):
scanner = self.add_metric(info, scanner.to_table(), mode='batch')
return type(self)(scanner)
options = dict(schema=self.source.projected_schema, filter=expr, columns=projection)
scanner = ds.Scanner.from_batches(self.source.to_batches(), **options)
return type(self)(self.add_metric(info, scanner.to_table(), mode='batch'))
source = self.source if expr is None else self.source.filter(expr)
if isinstance(source, ds.Dataset):
return type(self)(Nodes.scan(source, projection) if columns else source)
if isinstance(source, pa.Table):
if not columns:
return type(self)(source.select(list(projection)))
source = Nodes('table_source', source)
return type(self)(source.project(projection))

@doc_field(
right="name of right table; must be on root Query type",
Expand Down
3 changes: 1 addition & 2 deletions graphique/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pyarrow as pa
import pyarrow.dataset as ds
from starlette.config import Config
from graphique.core import Nodes
from graphique.inputs import Expression
from graphique import GraphQL

Expand All @@ -33,7 +32,7 @@
if FILTERS is not None:
root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow())
elif COLUMNS:
root = Nodes.scan(root, columns=COLUMNS)
root = root.scanner(columns=COLUMNS)

if FEDERATED:
app = GraphQL.federated({FEDERATED: root}, debug=DEBUG)
Expand Down
2 changes: 1 addition & 1 deletion tests/federated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
fixtures = Path(__file__).parent / 'fixtures'
dataset = ds.dataset(fixtures / 'zipcodes.parquet')
roots = {
'zipcodes': dataset.scanner(),
'zipcodes': core.Nodes.scan(dataset, dataset.schema.names),
'states': core.Table.sort(dataset.to_table(), 'state', 'county', indices='indices'),
'zip_db': ds.dataset(fixtures / 'zip_db.parquet'),
}
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ def test_schema(dsclient):
assert set(schema['names']) >= {'zipcode', 'state', 'county'}
assert set(schema['types']) >= {'int32', 'string'}
assert len(schema['partitioning']) in (0, 6)
assert dsclient.execute('{ type }')['type'] in {'FileSystemDataset', 'Scanner'}
assert dsclient.execute('{ scan { type length } }')['scan']['type'] == 'Scanner'
data = dsclient.execute('{ scan { type length l: length } }')
assert data['scan']['type'] in {'Scanner', 'Table'}
data = dsclient.execute('{ scan(filter: {}) { type } }')
assert data == {'scan': {'type': 'FileSystemDataset'}}
data = dsclient.execute('{ scan(columns: {name: "zipcode"}) { type } }')
assert data == {'scan': {'type': 'Nodes'}}
result = dsclient._execute('{ length optional { tables { length } } }')
assert result.data == {'length': 41700, 'optional': None}
assert len(result.errors) == 1
Expand Down
4 changes: 3 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def test_camel(aliasclient):
assert data == {'filter': {'length': 1}}
data = aliasclient.execute('{ filter(camelId: {eq: 1}) { length } }')
assert data == {'filter': {'length': 1}}
data = aliasclient.execute('{ group(by: "camelId") { length } }')
assert data == {'group': {'length': 2}}


def test_snake(executor):
Expand Down Expand Up @@ -349,7 +351,7 @@ def test_conditions(executor):
assert data == {'scan': {'column': {'type': 'float'}}}
with pytest.raises(ValueError, match="no kernel"):
executor("""{ scan(columns: {alias: "bool",
ifElse: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) { type } }""")
ifElse: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) { slice { type } } }""")


def test_long(executor):
Expand Down

0 comments on commit 70fb2b4

Please sign in to comment.