From 8dc4e3035b06f42135e55eb1a8d5056a74b15777 Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Sun, 28 Apr 2024 10:15:15 -0700 Subject: [PATCH] Duration scalar. Uses ISO 8601 format for `timedelta` and `MonthDayNano`. The available libraries seem inactive or have dependencies. `isodate` is a likely candidate, if it has another release. --- CHANGELOG.md | 4 +++ graphique/models.py | 4 +-- graphique/scalars.py | 64 +++++++++++++++++++++++++++++++++++--------- tests/test_core.py | 18 +++++++++++++ tests/test_models.py | 9 +++---- 5 files changed, 80 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c2bb5a..7a8aa5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Changed * Pyarrow >=16 required * `group` optimized for datasets +* `Duration` scalar + +### Removed +* `Interval` type ## [1.5](https://pypi.org/project/graphique/1.5/) - 2024-01-24 ### Changed diff --git a/graphique/models.py b/graphique/models.py index dad5add..b321d1a 100644 --- a/graphique/models.py +++ b/graphique/models.py @@ -17,7 +17,7 @@ from strawberry.field import StrawberryField from .core import Column as C from .inputs import links -from .scalars import Interval, Long, scalar_map, type_map +from .scalars import Long, scalar_map, type_map if TYPE_CHECKING: # pragma: no cover from .interface import Dataset @@ -120,7 +120,7 @@ def values(self) -> list[Optional[T]]: return self.array.to_pylist() -@Column.register(timedelta, Interval) +@Column.register(timedelta, pa.MonthDayNano) @strawberry.type(name='Column', description="column of elapsed times") class NominalColumn(Generic[T], Column): values = doc_field(Set.values) diff --git a/graphique/scalars.py b/graphique/scalars.py index 24f221e..c11e258 100644 --- a/graphique/scalars.py +++ b/graphique/scalars.py @@ -2,8 +2,11 @@ GraphQL scalars. """ +import functools +import re from datetime import date, datetime, time, timedelta from decimal import Decimal +from typing import Union, no_type_check import pyarrow as pa import strawberry @@ -14,22 +17,59 @@ def parse_long(value) -> int: raise TypeError(f"Long cannot represent value: {value}") -@strawberry.type(description="month day nano interval") -class Interval: - months: int - days: int - nanoseconds: 'Long' # type: ignore +@no_type_check +def parse_duration(value: str): + months = days = seconds = 0 + d_val, _, t_val = value.partition('T') + parts = re.split(r'(-?\d+\.?\d*)', d_val.lower() + t_val) + if parts.pop(0) != 'p': + raise ValueError("Duration format must start with `P`") + for num, key in zip(parts[::2], parts[1::2]): + if n := float(num) if '.' in num else int(num): + if key in 'ym': + months += n * 12 if key == 'y' else n + elif key in 'wd': + days += n * 7 if key == 'w' else n + elif key in 'HMS': + seconds += n * {'H': 3600, 'M': 60, 'S': 1}[key] + else: + raise ValueError(f"Invalid duration field: {key.upper()}") + if months: + return pa.MonthDayNano([months, days, int(seconds * 1_000_000_000)]) + return timedelta(days, seconds) + + +@functools.singledispatch +def duration_isoformat(td: timedelta) -> str: + days = f'{td.days}D' if td.days else '' + fraction = f'.{td.microseconds:06}' if td.microseconds else '' + return f'P{days}T{td.seconds}{fraction}S' + + +@duration_isoformat.register +def _(mdn: pa.MonthDayNano) -> str: + months = f'{mdn.months}M' if mdn.months else '' + days = f'{mdn.days}D' if mdn.days else '' + seconds, nanoseconds = divmod(mdn.nanoseconds, 1_000_000_000) + fraction = f'.{nanoseconds:09}' if nanoseconds else '' + return f'P{months}{days}T{seconds}{fraction}S' Long = strawberry.scalar(int, name='Long', description="64-bit int", parse_value=parse_long) -Duration = strawberry.scalar( # pragma: no branch - timedelta, +Duration = strawberry.scalar( + Union[timedelta, pa.MonthDayNano], name='Duration', - description="duration float (in seconds)", - serialize=timedelta.total_seconds, - parse_value=lambda s: timedelta(seconds=s), + description="Duration (isoformat)", + specified_by_url="https://en.wikipedia.org/wiki/ISO_8601#Durations", + serialize=duration_isoformat, + parse_value=parse_duration, ) -scalar_map = {bytes: strawberry.scalars.Base64, dict: strawberry.scalars.JSON, timedelta: Duration} +scalar_map = { + bytes: strawberry.scalars.Base64, + dict: strawberry.scalars.JSON, + timedelta: Duration, + pa.MonthDayNano: Duration, +} type_map = { pa.lib.Type_BOOL: bool, @@ -52,7 +92,7 @@ class Interval: pa.lib.Type_TIME32: time, pa.lib.Type_TIME64: time, pa.lib.Type_DURATION: timedelta, - pa.lib.Type_INTERVAL_MONTH_DAY_NANO: Interval, + pa.lib.Type_INTERVAL_MONTH_DAY_NANO: pa.MonthDayNano, pa.lib.Type_BINARY: bytes, pa.lib.Type_FIXED_SIZE_BINARY: bytes, pa.lib.Type_LARGE_BINARY: bytes, diff --git a/tests/test_core.py b/tests/test_core.py index c8aed88..4606c2c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,6 +2,24 @@ import pyarrow.compute as pc import pytest from graphique.core import Agg, ListChunk, Column as C, Table as T +from graphique.scalars import parse_duration, duration_isoformat + + +def test_duration(): + assert duration_isoformat(parse_duration('P1Y1M1DT1H1M1.1S')) == 'P13M1DT3661.100000000S' + assert duration_isoformat(parse_duration('P1M1DT1H1M1.1S')) == 'P1M1DT3661.100000000S' + assert duration_isoformat(parse_duration('P1DT1H1M1.1S')) == 'P1DT3661.100000S' + assert duration_isoformat(parse_duration('PT1H1M1.1S')) == 'PT3661.100000S' + assert duration_isoformat(parse_duration('PT1M1.1S')) == 'PT61.100000S' + assert duration_isoformat(parse_duration('PT1.1S')) == 'PT1.100000S' + assert duration_isoformat(parse_duration('PT1S')) == 'PT1S' + assert duration_isoformat(parse_duration('P0D')) == 'PT0S' + assert duration_isoformat(parse_duration('PT0S')) == 'PT0S' + assert duration_isoformat(parse_duration('P-1DT-1H')) == 'P-2DT82800S' + with pytest.raises(ValueError): + duration_isoformat(parse_duration('T1H')) + with pytest.raises(ValueError): + duration_isoformat(parse_duration('P1H')) def test_dictionary(table): diff --git a/tests/test_models.py b/tests/test_models.py index 21b9454..331219c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -222,16 +222,15 @@ def test_datetime(executor): def test_duration(executor): data = executor( """{ scan(columns: {alias: "diff", checked: true, subtract: [{name: "timestamp"}, {name: "timestamp"}]}) - { column(name: "diff") { ... on DurationColumn { unique { values } } } } }""" + { column(name: "diff") { ... on DurationColumn { unique { values } } } } }""" ) - assert data == {'scan': {'column': {'unique': {'values': [0.0, None]}}}} + assert data == {'scan': {'column': {'unique': {'values': ['PT0S', None]}}}} data = executor('{ runs(split: [{name: "timestamp", gt: 0.0}]) { length } }') assert data == {'runs': {'length': 1}} data = executor("""{ scan(columns: {alias: "diff", temporal: {monthDayNanoIntervalBetween: [{name: "timestamp"}, {name: "timestamp"}]}}) - { column(name: "diff") { ... on IntervalColumn { values { months days nanoseconds } } } } }""") - value = {'months': 0, 'days': 0, 'nanoseconds': 0} - assert data == {'scan': {'column': {'values': [value, None]}}} + { column(name: "diff") { ... on DurationColumn { values } } } }""") + assert data == {'scan': {'column': {'values': ['PT0S', None]}}} def test_list(executor):