Skip to content

Commit

Permalink
Transform partitioned datasets.
Browse files Browse the repository at this point in the history
Utility for facet counts. Linter updates.
  • Loading branch information
coady committed Aug 16, 2024
1 parent 7143a76 commit c85f48c
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 8 deletions.
6 changes: 6 additions & 0 deletions graphique/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,12 @@ def scan(cls, dataset: ds.Dataset, columns=None) -> Self:
self = self.project(map(pc.field, columns))
return self

@classmethod
def facets(cls, dataset: ds.Dataset, columns: Mapping, counts: str = '') -> Self:
    """Return an aggregate node over the projected dataset.

    Args:
        dataset: source dataset to scan.
        columns: projection mapping passed through to ``scan``.
        counts: optional output field name; when set, a ``hash_count_all``
            aggregation is added under that name.
    """
    if counts:
        aggregations: list = [([], 'hash_count_all', None, counts)]
    else:
        aggregations = []
    return cls.scan(dataset, columns).aggregate(aggregations, columns)

def apply(self, name: str, *args, **options) -> Self:
    """Return a new declaration node of type *name* with this node as its input."""
    return type(self)(name, *args, inputs=[self], **options)

Expand Down
2 changes: 1 addition & 1 deletion graphique/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def selections(*fields) -> dict:
def doc_field(func: Optional[Callable] = None, **kwargs: str) -> StrawberryField:
"""Return strawberry field with argument and docstring descriptions."""
if func is None:
return functools.partial(doc_field, **kwargs)
return functools.partial(doc_field, **kwargs) # type: ignore
for name in kwargs:
argument = strawberry.argument(description=kwargs[name])
func.__annotations__[name] = Annotated[func.__annotations__[name], argument]
Expand Down
15 changes: 9 additions & 6 deletions graphique/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
progress on the second pass.
"""

import operator
import shutil
from pathlib import Path
from typing import Annotated
from typing import Annotated, Callable, Optional
import numpy as np
import pyarrow.compute as pc
import pyarrow.dataset as ds
Expand Down Expand Up @@ -42,15 +43,15 @@ def write_batches(
pbar.update(len(batch))


def write_fragments(dataset: ds.Dataset, base_dir: str, func: Optional[Callable] = None, **options):
    """Rewrite partition files fragment by fragment to consolidate them.

    Args:
        dataset: partitioned source dataset.
        base_dir: destination root; trailing partition path components are preserved.
        func: optional callable applied to each partition's data before writing,
            e.g. ``operator.methodcaller('sort_by', sorting)`` to sort while rewriting.
        **options: forwarded to ``ds.write_dataset``; ``format`` is forced to parquet.
    """
    options['format'] = 'parquet'
    # Map each partition directory to its partition expression so each
    # partition can be re-read (filtered) and rewritten as a single unit.
    exprs = {Path(frag.path).parent: frag.partition_expression for frag in dataset.get_fragments()}
    offset = len(dataset.partitioning.schema)
    for path in tqdm(exprs, desc="Fragments"):
        # Recreate the last `offset` path components (the partition keys) under the new base.
        part_dir = Path(base_dir, *path.parts[-offset:])
        part = dataset.filter(exprs[path])
        ds.write_dataset(func(part) if func else part, part_dir, **options)


def partition(
Expand All @@ -65,8 +66,10 @@ def partition(
write_batches(ds.dataset(src, partitioning='hive'), str(temp), *partitioning)
dataset = ds.dataset(temp, partitioning='hive')
options = dict(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore')
if fragments or sort:
write_fragments(dataset, dest, tuple(map(sort_key, sort)))
if sorting := list(map(sort_key, sort)):
write_fragments(dataset, dest, operator.methodcaller('sort_by', sorting))
elif fragments:
write_fragments(dataset, dest)
else:
with tqdm(desc="Partitions"):
ds.write_dataset(dataset, dest, partitioning=partitioning, **options)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ dependencies = {file = "requirements.in"}

[tool.ruff]
line-length = 100
extend-include = ["*.ipynb"]

[tool.ruff.format]
quote-style = "preserve"
Expand Down
3 changes: 3 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def test_declaration(table):
assert Declaration.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA']
(column,) = Declaration.scan(dataset, columns={'_': pc.field('state')}).to_table()
assert column.unique().to_pylist() == ['CA']
table = Declaration.facets(dataset, ['county', 'city'], counts='counts').to_table()
assert len(table) == 1241
assert pc.sum(table['counts']).as_py() == 2647


def test_group(table):
Expand Down

0 comments on commit c85f48c

Please sign in to comment.