Skip to content

Commit

Permalink
Duration scalar.
Browse files Browse the repository at this point in the history
Uses ISO 8601 format for `timedelta` and `MonthDayNano`. The available libraries seem inactive or have dependencies. `isodate` is a likely candidate, if it has another release.
  • Loading branch information
coady committed Apr 28, 2024
1 parent 33af8d9 commit 8dc4e30
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 19 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Changed
* Pyarrow >=16 required
* `group` optimized for datasets
* `Duration` scalar

### Removed
* `Interval` type

## [1.5](https://pypi.org/project/graphique/1.5/) - 2024-01-24
### Changed
Expand Down
4 changes: 2 additions & 2 deletions graphique/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from strawberry.field import StrawberryField
from .core import Column as C
from .inputs import links
from .scalars import Interval, Long, scalar_map, type_map
from .scalars import Long, scalar_map, type_map

if TYPE_CHECKING: # pragma: no cover
from .interface import Dataset
Expand Down Expand Up @@ -120,7 +120,7 @@ def values(self) -> list[Optional[T]]:
return self.array.to_pylist()


@Column.register(timedelta, Interval)
@Column.register(timedelta, pa.MonthDayNano)
@strawberry.type(name='Column', description="column of elapsed times")
class NominalColumn(Generic[T], Column):
values = doc_field(Set.values)
Expand Down
64 changes: 52 additions & 12 deletions graphique/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
GraphQL scalars.
"""

import functools
import re
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Union, no_type_check
import pyarrow as pa
import strawberry

Expand All @@ -14,22 +17,59 @@ def parse_long(value) -> int:
raise TypeError(f"Long cannot represent value: {value}")


@strawberry.type(description="month day nano interval")
class Interval:
    # GraphQL object type exposing one field per interval component.
    months: int
    days: int
    # Quoted forward reference: `Long` is assigned further down in this module.
    nanoseconds: 'Long' # type: ignore
@no_type_check
def parse_duration(value: str):
    """Parse an ISO 8601 duration string into a `timedelta` or `MonthDayNano`.

    The text before `T` holds the date designators (`Y`, `M`, `W`, `D`, matched
    case-insensitively); the text after `T` holds the time designators (`H`,
    `M`, `S`, upper-case only). Each value may be signed and may carry a
    decimal fraction.

    Returns:
        `pa.MonthDayNano` when the total number of months is non-zero,
        otherwise a `timedelta` built from the accumulated days and seconds.

    Raises:
        ValueError: if the string does not start with `P`, or if any value is
            followed by a missing, unknown, or misplaced designator.
    """
    number = r'(-?\d+\.?\d*)'
    month_units = {'y': 12, 'm': 1}
    day_units = {'w': 7, 'd': 1}
    second_units = {'H': 3600, 'M': 60, 'S': 1}

    def to_number(text):
        # Stay in integers unless a fraction is present.
        return float(text) if '.' in text else int(text)

    # The two sections are split separately so that `M` means months before
    # `T` and minutes after it, and date designators are rejected after `T`.
    date_part, _, time_part = value.partition('T')
    date_fields = re.split(number, date_part.lower())
    time_fields = re.split(number, time_part)
    if date_fields.pop(0) != 'p':
        raise ValueError("Duration format must start with `P`")
    leading = time_fields.pop(0)
    if leading:
        raise ValueError(f"Invalid duration field: {leading.upper()}")

    months = days = seconds = 0
    # After the pop each list alternates value, designator. Exact dict lookups
    # reject empty or multi-letter designators, even when the value is zero.
    for num, key in zip(date_fields[::2], date_fields[1::2]):
        if key in month_units:
            months += to_number(num) * month_units[key]
        elif key in day_units:
            days += to_number(num) * day_units[key]
        else:
            raise ValueError(f"Invalid duration field: {key.upper()}")
    for num, key in zip(time_fields[::2], time_fields[1::2]):
        if key not in second_units:
            raise ValueError(f"Invalid duration field: {key.upper()}")
        seconds += to_number(num) * second_units[key]

    if months:
        return pa.MonthDayNano([months, days, int(seconds * 1_000_000_000)])
    return timedelta(days, seconds)


@functools.singledispatch
def duration_isoformat(td: timedelta) -> str:
    """Format a `timedelta` as an ISO 8601 duration of days and seconds.

    The day component is omitted when zero, and the fractional part (six
    digits of microseconds) is omitted when there are no microseconds.
    """
    pieces = ['P']
    if td.days:
        pieces.append(f'{td.days}D')
    pieces.append(f'T{td.seconds}')
    if td.microseconds:
        pieces.append(f'.{td.microseconds:06}')
    pieces.append('S')
    return ''.join(pieces)


@duration_isoformat.register
def _(mdn: pa.MonthDayNano) -> str:
    """Format a `MonthDayNano` interval as an ISO 8601 duration.

    Months and days are omitted when zero. The seconds component is always
    present and carries nine fractional digits when the nanoseconds are not a
    whole number of seconds.
    """
    months = f'{mdn.months}M' if mdn.months else ''
    days = f'{mdn.days}D' if mdn.days else ''
    # `divmod` floors toward negative infinity, so a negative nanosecond count
    # would produce a wrong value (-1.5s rendered as `-2.500000000`). Split the
    # magnitude and re-apply the sign instead.
    sign = '-' if mdn.nanoseconds < 0 else ''
    seconds, nanoseconds = divmod(abs(mdn.nanoseconds), 1_000_000_000)
    fraction = f'.{nanoseconds:09}' if nanoseconds else ''
    return f'P{months}{days}T{sign}{seconds}{fraction}S'


Long = strawberry.scalar(int, name='Long', description="64-bit int", parse_value=parse_long)
Duration = strawberry.scalar( # pragma: no branch
timedelta,
Duration = strawberry.scalar(
Union[timedelta, pa.MonthDayNano],
name='Duration',
description="duration float (in seconds)",
serialize=timedelta.total_seconds,
parse_value=lambda s: timedelta(seconds=s),
description="Duration (isoformat)",
specified_by_url="https://en.wikipedia.org/wiki/ISO_8601#Durations",
serialize=duration_isoformat,
parse_value=parse_duration,
)
scalar_map = {bytes: strawberry.scalars.Base64, dict: strawberry.scalars.JSON, timedelta: Duration}
scalar_map = {
bytes: strawberry.scalars.Base64,
dict: strawberry.scalars.JSON,
timedelta: Duration,
pa.MonthDayNano: Duration,
}

type_map = {
pa.lib.Type_BOOL: bool,
Expand All @@ -52,7 +92,7 @@ class Interval:
pa.lib.Type_TIME32: time,
pa.lib.Type_TIME64: time,
pa.lib.Type_DURATION: timedelta,
pa.lib.Type_INTERVAL_MONTH_DAY_NANO: Interval,
pa.lib.Type_INTERVAL_MONTH_DAY_NANO: pa.MonthDayNano,
pa.lib.Type_BINARY: bytes,
pa.lib.Type_FIXED_SIZE_BINARY: bytes,
pa.lib.Type_LARGE_BINARY: bytes,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@
import pyarrow.compute as pc
import pytest
from graphique.core import Agg, ListChunk, Column as C, Table as T
from graphique.scalars import parse_duration, duration_isoformat


def test_duration():
    """Round-trip ISO 8601 durations through the parser and the formatter."""
    round_trips = [
        ('P1Y1M1DT1H1M1.1S', 'P13M1DT3661.100000000S'),
        ('P1M1DT1H1M1.1S', 'P1M1DT3661.100000000S'),
        ('P1DT1H1M1.1S', 'P1DT3661.100000S'),
        ('PT1H1M1.1S', 'PT3661.100000S'),
        ('PT1M1.1S', 'PT61.100000S'),
        ('PT1.1S', 'PT1.100000S'),
        ('PT1S', 'PT1S'),
        ('P0D', 'PT0S'),
        ('PT0S', 'PT0S'),
        ('P-1DT-1H', 'P-2DT82800S'),
    ]
    for text, expected in round_trips:
        assert duration_isoformat(parse_duration(text)) == expected
    # Missing `P` prefix, then a time designator in the date section.
    for invalid in ('T1H', 'P1H'):
        with pytest.raises(ValueError):
            duration_isoformat(parse_duration(invalid))


def test_dictionary(table):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,15 @@ def test_datetime(executor):
def test_duration(executor):
data = executor(
"""{ scan(columns: {alias: "diff", checked: true, subtract: [{name: "timestamp"}, {name: "timestamp"}]})
{ column(name: "diff") { ... on DurationColumn { unique { values } } } } }"""
{ column(name: "diff") { ... on DurationColumn { unique { values } } } } }"""
)
assert data == {'scan': {'column': {'unique': {'values': [0.0, None]}}}}
assert data == {'scan': {'column': {'unique': {'values': ['PT0S', None]}}}}
data = executor('{ runs(split: [{name: "timestamp", gt: 0.0}]) { length } }')
assert data == {'runs': {'length': 1}}
data = executor("""{ scan(columns: {alias: "diff", temporal:
{monthDayNanoIntervalBetween: [{name: "timestamp"}, {name: "timestamp"}]}})
{ column(name: "diff") { ... on IntervalColumn { values { months days nanoseconds } } } } }""")
value = {'months': 0, 'days': 0, 'nanoseconds': 0}
assert data == {'scan': {'column': {'values': [value, None]}}}
{ column(name: "diff") { ... on DurationColumn { values } } } }""")
assert data == {'scan': {'column': {'values': ['PT0S', None]}}}


def test_list(executor):
Expand Down

0 comments on commit 8dc4e30

Please sign in to comment.