Commit

First and last aggregates.
Dictionaries not natively supported.
coady committed Sep 4, 2023
1 parent 39122fe commit 0adedea
Showing 5 changed files with 27 additions and 14 deletions.
graphique/core.py: 7 additions & 0 deletions
@@ -279,6 +279,13 @@ def run_offsets(self, predicate: Callable = pc.not_equal, *args) -> pa.IntegerAr
         mask = predicate(Column.diff(self), *args) if args else Column.diff(self, predicate)
         return pc.indices_nonzero(pa.chunked_array(ends + mask.chunks + ends))

+    def first_last(self, **options):
+        if not pa.types.is_dictionary(self.type):
+            return pc.first_last(self, **options).as_py()
+        self = self.combine_chunks()
+        indices = pc.first_last(self.indices, **options).as_py().items()
+        return {key: None if index is None else self.dictionary[index] for key, index in indices}
+
     def min_max(self, **options):
         if pa.types.is_dictionary(self.type):
             self = self.unique().dictionary_decode()
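
The new first_last helper works around pc.first_last not being implemented for dictionary-encoded arrays: it aggregates the integer indices and maps the resulting positions back through the dictionary. Unlike min_max, which can decode the unique values because extremes do not depend on row order, first and last are positional, so the indices themselves must be aggregated. A minimal sketch of the same idea against a single DictionaryArray rather than a chunked column; the sample data is illustrative:

```python
import pyarrow as pa
import pyarrow.compute as pc

# A dictionary-encoded column: pc.first_last raises NotImplementedError for this type.
array = pa.array(['a', 'b', None, 'a']).dictionary_encode()

# Aggregate the indices, then look the positions up in the dictionary.
positions = pc.first_last(array.indices, skip_nulls=True).as_py()  # {'first': 0, 'last': 0}
result = {key: None if i is None else array.dictionary[i].as_py() for key, i in positions.items()}
print(result)  # {'first': 'a', 'last': 'a'}
```
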
graphique/inputs.py: 4 additions & 3 deletions
@@ -18,7 +18,7 @@
 from strawberry.field import StrawberryField
 from strawberry.scalars import JSON
 from typing_extensions import Self
-from .core import Agg, ListChunk
+from .core import Agg
 from .scalars import Long

 T = TypeVar('T')
@@ -212,6 +212,9 @@ class ScalarAggregates(Input):
     approximate_median: List[ScalarAggregate] = default_field([], func=pc.approximate_median)
     count: List[CountAggregate] = default_field([], func=pc.count)
     count_distinct: List[CountAggregate] = default_field([], func=pc.count_distinct)
+    first: List[ScalarAggregate] = default_field([], func=pc.first)
+    first_last: List[ScalarAggregate] = default_field([], func=pc.first_last)
+    last: List[ScalarAggregate] = default_field([], func=pc.last)
     max: List[ScalarAggregate] = default_field([], func=pc.max)
     mean: List[ScalarAggregate] = default_field([], func=pc.mean)
     min: List[ScalarAggregate] = default_field([], func=pc.min)
@@ -231,8 +234,6 @@ class HashAggregates(ScalarAggregates):
     distinct: List[CountAggregate] = default_field(
         [], description="distinct values within each scalar"
     )
-    first: List[Field] = default_field([], func=ListChunk.first)
-    last: List[Field] = default_field([], func=ListChunk.last)
     list: List[Field] = default_field([], description="all values within each scalar")
     one: List[Field] = default_field([], description="arbitrary value within each scalar")

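
first and last move out of the ListChunk-based HashAggregates and are registered, along with first_last, as ScalarAggregates bound directly to pyarrow.compute functions. A hedged sketch of the underlying kernels, assuming a pyarrow version that ships them (they were added around pyarrow 9); the sample table is illustrative:

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({'key': ['x', 'x', 'y'], 'value': [1, 2, 3]})

# Scalar aggregates, as wired up via default_field(func=...).
assert pc.first(table['value']).as_py() == 1
assert pc.last(table['value']).as_py() == 3
assert pc.first_last(table['value']).as_py() == {'first': 1, 'last': 3}

# The same kernels are also available as grouped aggregations.
print(table.group_by('key').aggregate([('value', 'first'), ('value', 'last')]))
```
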
graphique/models.py: 12 additions & 11 deletions
@@ -147,23 +147,24 @@ def drop_null(self) -> List[T]:
 @Column.register(date, datetime, time, bytes)
 @strawberry.type(name='Column', description="column of ordinal values")
 class OrdinalColumn(NominalColumn[T]):
-    def __init__(self, array):
-        super().__init__(array)
-        self.min_max = functools.lru_cache(maxsize=None)(functools.partial(C.min_max, array))
+    @compute_field
+    def first(self, skip_nulls: bool = True, min_count: int = 0) -> Optional[T]:
+        return C.first_last(self.array, skip_nulls=skip_nulls, min_count=min_count)['first']

-    @doc_field
+    @compute_field
+    def last(self, skip_nulls: bool = True, min_count: int = 0) -> Optional[T]:
+        return C.first_last(self.array, skip_nulls=skip_nulls, min_count=min_count)['last']
+
+    @compute_field
     def min(self, skip_nulls: bool = True, min_count: int = 0) -> Optional[T]:
         """minimum value"""
-        return self.min_max(skip_nulls=skip_nulls, min_count=min_count)['min']
+        return C.min_max(self.array, skip_nulls=skip_nulls, min_count=min_count)['min']

-    @doc_field
+    @compute_field
     def max(self, skip_nulls: bool = True, min_count: int = 0) -> Optional[T]:
         """maximum value"""
-        return self.min_max(skip_nulls=skip_nulls, min_count=min_count)['max']
+        return C.min_max(self.array, skip_nulls=skip_nulls, min_count=min_count)['max']

-    @doc_field
+    @compute_field
     def index(self, value: T, start: Long = 0, end: Optional[Long] = None) -> Long:
         """Find the index of the first occurrence of a given value."""
         return C.index(self.array, value, start, end)

     @compute_field
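
OrdinalColumn no longer caches a functools.partial of C.min_max per column; first, last, min, and max are now compute_field resolvers that call the core helpers on each request. A minimal sketch of the trade-off using plain pyarrow (the array contents are illustrative, and cached_min_max stands in for the removed attribute):

```python
import functools
import pyarrow as pa
import pyarrow.compute as pc

array = pa.chunked_array([[3, 1, 2]])

# The removed pattern: bind the column once and memoize results per option set.
cached_min_max = functools.lru_cache(maxsize=None)(functools.partial(pc.min_max, array))
assert cached_min_max(skip_nulls=True).as_py() == {'min': 1, 'max': 3}
assert cached_min_max(skip_nulls=True) is cached_min_max(skip_nulls=True)  # memoized

# The new resolvers simply recompute on demand.
assert pc.min_max(array, skip_nulls=True).as_py() == {'min': 1, 'max': 3}
assert pc.first_last(array, skip_nulls=True).as_py() == {'first': 3, 'last': 2}
```
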
tests/test_core.py: 2 additions & 0 deletions
@@ -189,6 +189,8 @@ def test_not_implemented():
         pc.sort_indices(pa.table({'': dictionary}), [('', 'ascending')])
     with pytest.raises(NotImplementedError):
         dictionary.index('')
+    with pytest.raises(NotImplementedError):
+        pc.first_last(dictionary)
     with pytest.raises(NotImplementedError):
         pc.min_max(dictionary)
     with pytest.raises(NotImplementedError):
tests/test_models.py: 2 additions & 0 deletions
@@ -55,6 +55,7 @@ def execute(query):
     data = execute(f'{{ {name} {{ dropNull }} }}')
     assert data == {name: {'dropNull': ['1970-01-01']}}
     assert execute(f'{{ {name} {{ min max }} }}')
+    assert execute(f'{{ {name} {{ first last }} }}')

     data = execute('{ timestamp { values } }')
     assert data == {'timestamp': {'values': ['1970-01-01T00:00:00', None]}}
@@ -82,6 +83,7 @@ def execute(query):
         'string': {'type': 'dictionary<values=string, indices=int32, ordered=0>'}
     }
     assert execute('{ string { min max } }')
+    assert execute('{ string { first last } }') == {'string': {'first': '', 'last': ''}}


 def test_boolean(executor):
