diff --git a/graphique/interface.py b/graphique/interface.py
index 8cf9dac..8013a89 100644
--- a/graphique/interface.py
+++ b/graphique/interface.py
@@ -85,21 +85,6 @@ def select(self, info: Info) -> Source:
             return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
         return self.source.select(names)
 
-    def scanner(self, info: Info, **options) -> ds.Scanner:
-        """Return scanner with only the columns necessary to proceed."""
-        options.setdefault('columns', list(self.references(info)))
-        source = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source
-        if isinstance(source, ds.Dataset):
-            return source.scanner(**options)
-        if isinstance(source, Nodes):
-            if 'filter' in options:  # pragma: no branch
-                source = source.filter(options['filter'])
-            if 'columns' in options:  # pragma: no branch
-                source = source.project(options['columns'])
-            return source.scanner()
-        options['schema'] = source.projected_schema
-        return ds.Scanner.from_batches(source.to_batches(), **options)
-
     def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
         """Return table with only the rows and columns necessary to proceed."""
         source = self.select(info)
@@ -228,8 +213,12 @@ def column(
         This is typically only needed for aliased or casted columns.
         If the column is in the schema, `columns` can be used instead.
         """
-        expr = Expression(name=name, cast=cast, safe=safe).to_arrow()  # type: ignore
-        return Column.cast(*self.scanner(info, columns={'': expr}).to_table())
+        if isinstance(self.source, pa.Table) and len(name) == 1:
+            column = self.source.column(*name)
+            return Column.cast(column.cast(cast, safe) if cast else column)
+        column = Projection(alias='_', name=name, cast=cast, safe=safe)  # type: ignore
+        source = self.scan(info, Expression(), [column]).source
+        return Column.cast(*(source if isinstance(source, pa.Table) else source.to_table()))
 
     @doc_field(
         offset="number of rows to skip; negative value skips from the end",
@@ -357,10 +346,12 @@ def sort(
             table = self.to_table(info)
         else:
             expr, by = T.rank_keys(self.source, length, *by, dense=False)
-            scanner = self.scanner(info, filter=expr)
+            if expr is not None:
+                self = type(self)(self.source.filter(expr))
+            source = self.select(info)
             if not by:
-                return type(self)(self.add_metric(info, scanner.head(length), mode='head'))
-            table = T.map_batch(scanner, T.sort, *by, **kwargs)
+                return type(self)(self.add_metric(info, source.head(length), mode='head'))
+            table = T.map_batch(source, T.sort, *by, **kwargs)
             self.add_metric(info, table, mode='batch')
         return type(self)(T.sort(table, *by, **kwargs))  # type: ignore
 
@@ -492,24 +483,26 @@ def aggregate(
 
     aggregate.deprecation_reason = ListFunction.deprecation
 
-    def project(self, info: Info, columns: list[Projection]) -> dict:
-        """Return projected columns, including all references from below fields."""
-        projection = {name: pc.field(name) for name in self.references(info, level=1)}
-        projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
-        if '' in projection:
-            raise ValueError(f"projected columns need a name or alias: {projection['']}")
-        return projection
-
     @doc_field(filter="selected rows", columns="projected columns")
     def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self:  # type: ignore
         """Select rows and project columns without memory usage."""
         expr = filter.to_arrow()
-        if expr is not None and not columns and isinstance(self.source, ds.Dataset):
-            return type(self)(self.source.filter(expr))
-        scanner = self.scanner(info, filter=expr, columns=self.project(info, columns))
+        projection = {name: pc.field(name) for name in self.references(info, level=1)}
+        projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
+        if '' in projection:
+            raise ValueError(f"projected columns need a name or alias: {projection['']}")
         if isinstance(self.source, ds.Scanner):
-            scanner = self.add_metric(info, scanner.to_table(), mode='batch')
-        return type(self)(scanner)
+            options = dict(schema=self.source.projected_schema, filter=expr, columns=projection)
+            scanner = ds.Scanner.from_batches(self.source.to_batches(), **options)
+            return type(self)(self.add_metric(info, scanner.to_table(), mode='batch'))
+        source = self.source if expr is None else self.source.filter(expr)
+        if isinstance(source, ds.Dataset):
+            return type(self)(Nodes.scan(source, projection) if columns else source)
+        if isinstance(source, pa.Table):
+            if not columns:
+                return type(self)(source.select(list(projection)))
+            source = Nodes('table_source', source)
+        return type(self)(source.project(projection))
 
     @doc_field(
         right="name of right table; must be on root Query type",
diff --git a/graphique/service.py b/graphique/service.py
index a6baee0..94f4141 100644
--- a/graphique/service.py
+++ b/graphique/service.py
@@ -12,7 +12,6 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 from starlette.config import Config
-from graphique.core import Nodes
 from graphique.inputs import Expression
 from graphique import GraphQL
 
@@ -33,7 +32,7 @@
 if FILTERS is not None:
     root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow())
 elif COLUMNS:
-    root = Nodes.scan(root, columns=COLUMNS)
+    root = root.scanner(columns=COLUMNS)
 
 if FEDERATED:
     app = GraphQL.federated({FEDERATED: root}, debug=DEBUG)
diff --git a/tests/federated.py b/tests/federated.py
index 05e1744..8fd43ee 100644
--- a/tests/federated.py
+++ b/tests/federated.py
@@ -5,7 +5,7 @@
 fixtures = Path(__file__).parent / 'fixtures'
 dataset = ds.dataset(fixtures / 'zipcodes.parquet')
 roots = {
-    'zipcodes': dataset.scanner(),
+    'zipcodes': core.Nodes.scan(dataset, dataset.schema.names),
     'states': core.Table.sort(dataset.to_table(), 'state', 'county', indices='indices'),
     'zip_db': ds.dataset(fixtures / 'zip_db.parquet'),
 }
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index da4fe1c..163104f 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -151,10 +151,10 @@ def test_schema(dsclient):
     assert set(schema['names']) >= {'zipcode', 'state', 'county'}
     assert set(schema['types']) >= {'int32', 'string'}
     assert len(schema['partitioning']) in (0, 6)
-    assert dsclient.execute('{ type }')['type'] in {'FileSystemDataset', 'Scanner'}
-    assert dsclient.execute('{ scan { type length } }')['scan']['type'] == 'Scanner'
-    data = dsclient.execute('{ scan { type length l: length } }')
-    assert data['scan']['type'] in {'Scanner', 'Table'}
+    data = dsclient.execute('{ scan(filter: {}) { type } }')
+    assert data == {'scan': {'type': 'FileSystemDataset'}}
+    data = dsclient.execute('{ scan(columns: {name: "zipcode"}) { type } }')
+    assert data == {'scan': {'type': 'Nodes'}}
     result = dsclient._execute('{ length optional { tables { length } } }')
     assert result.data == {'length': 41700, 'optional': None}
     assert len(result.errors) == 1
diff --git a/tests/test_models.py b/tests/test_models.py
index 9a17fd4..964125b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -10,6 +10,8 @@ def test_camel(aliasclient):
     assert data == {'filter': {'length': 1}}
     data = aliasclient.execute('{ filter(camelId: {eq: 1}) { length } }')
     assert data == {'filter': {'length': 1}}
+    data = aliasclient.execute('{ group(by: "camelId") { length } }')
+    assert data == {'group': {'length': 2}}
 
 
 def test_snake(executor):
@@ -349,7 +351,7 @@ def test_conditions(executor):
     assert data == {'scan': {'column': {'type': 'float'}}}
    with pytest.raises(ValueError, match="no kernel"):
         executor("""{ scan(columns: {alias: "bool",
-            ifElse: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) { type } }""")
+            ifElse: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) { slice { type } } }""")
 
 
 def test_long(executor):
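
Note on the rewritten `scan` resolver: it dispatches on the source type. A `ds.Scanner` is materialized batch-wise, a `ds.Dataset` stays lazy via `Nodes.scan`, and a `pa.Table` is wrapped in a `table_source` node before projecting. Below is a minimal sketch of the lazy path, assuming `Nodes` wraps pyarrow's Acero `Declaration` API (as `graphique.core` does); the fixture path and column names are illustrative, not part of this change:

    import pyarrow.compute as pc
    import pyarrow.dataset as ds
    from pyarrow import acero

    # Hypothetical fixture; any parquet dataset with these columns works.
    dataset = ds.dataset('tests/fixtures/zipcodes.parquet')
    projection = {'zipcode': pc.field('zipcode'), 'state': pc.field('state')}

    # Declarations are lazy plans: no batches are read until to_table()/to_reader().
    # A scan node with ScanNodeOptions does not apply filters itself,
    # so the filter is a separate node in the plan.
    plan = acero.Declaration.from_sequence([
        acero.Declaration('scan', acero.ScanNodeOptions(dataset)),
        acero.Declaration('filter', acero.FilterNodeOptions(pc.field('state') == 'CA')),
        acero.Declaration('project', acero.ProjectNodeOptions(
            list(projection.values()), list(projection.keys()))),
    ])
    table = plan.to_table()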