Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Enable generating data with time zones in parametric testing #16298

Merged: 7 commits were merged on May 19, 2024.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def sequence_to_pyseries(

if (dtype == Datetime) and (value.tzinfo is not None or time_zone is not None):
values_tz = str(value.tzinfo) if value.tzinfo is not None else None
dtype_tz = dtype.time_zone # type: ignore[union-attr]
dtype_tz = time_zone
if values_tz is not None and (dtype_tz is not None and dtype_tz != "UTC"):
msg = (
"time-zone-aware datetimes are converted to UTC"
Expand Down
18 changes: 15 additions & 3 deletions py-polars/polars/testing/parametric/strategies/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def series( # noqa: D417
unique: bool = False,
allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
allow_time_zones: bool = True,
**kwargs: Any,
) -> Series:
"""
Hypothesis strategy for producing polars Series.
Hypothesis strategy for producing Polars Series.

Parameters
----------
Expand Down Expand Up @@ -77,6 +78,8 @@ def series( # noqa: D417
when automatically generating Series data, allow only these dtypes.
excluded_dtypes : {list,set}, optional
when automatically generating Series data, exclude these dtypes.
allow_time_zones
Allow generating `Datetime` Series with a time zone.
**kwargs
Additional keyword arguments that are passed to the underlying data generation
strategies.
Expand Down Expand Up @@ -162,13 +165,16 @@ def series( # noqa: D417
if strategy is None:
if dtype is None:
dtype_strat = dtypes(
allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
allow_time_zones=allow_time_zones,
)
else:
dtype_strat = _instantiate_dtype(
dtype,
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
allow_time_zones=allow_time_zones,
)
dtype = draw(dtype_strat)

Expand Down Expand Up @@ -223,6 +229,7 @@ def dataframes(
allow_chunks: bool = True,
allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
allow_time_zones: bool = True,
**kwargs: Any,
) -> SearchStrategy[DataFrame]: ...

Expand All @@ -242,6 +249,7 @@ def dataframes(
allow_chunks: bool = True,
allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
allow_time_zones: bool = True,
**kwargs: Any,
) -> SearchStrategy[LazyFrame]: ...

Expand All @@ -263,10 +271,11 @@ def dataframes( # noqa: D417
allow_chunks: bool = True,
allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None,
allow_time_zones: bool = True,
**kwargs: Any,
) -> DataFrame | LazyFrame:
"""
Hypothesis strategy for producing polars DataFrames or LazyFrames.
Hypothesis strategy for producing Polars DataFrames or LazyFrames.

Parameters
----------
Expand Down Expand Up @@ -302,6 +311,8 @@ def dataframes( # noqa: D417
when automatically generating data, allow only these dtypes.
excluded_dtypes : {list,set}, optional
when automatically generating data, exclude these dtypes.
allow_time_zones
Allow generating `Datetime` columns with a time zone.
**kwargs
Additional keyword arguments that are passed to the underlying data generation
strategies.
Expand Down Expand Up @@ -436,6 +447,7 @@ def dataframes( # noqa: D417
unique=c.unique,
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
allow_time_zones=allow_time_zones,
**kwargs,
)
)
Expand Down
49 changes: 36 additions & 13 deletions py-polars/polars/testing/parametric/strategies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import decimal
from datetime import timedelta
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence

import hypothesis.strategies as st
Expand All @@ -24,6 +24,7 @@
U32_MAX,
U64_MAX,
)
from polars._utils.convert import string_to_zoneinfo
from polars.datatypes import (
Array,
Binary,
Expand Down Expand Up @@ -59,7 +60,7 @@
)

if TYPE_CHECKING:
from datetime import date, datetime, time
from datetime import date, time

from hypothesis.strategies import SearchStrategy

Expand Down Expand Up @@ -138,27 +139,47 @@ def dates() -> SearchStrategy[date]:
return st.dates()


def datetimes(time_unit: TimeUnit = "us") -> SearchStrategy[datetime]:
def datetimes(
time_unit: TimeUnit = "us", time_zone: str | None = None
) -> SearchStrategy[datetime]:
"""
Create a strategy for generating `datetime` objects in the time unit's range.

Parameters
----------
time_unit
Time unit for which the datetime objects are valid.
time_zone
Time zone for which the datetime objects are valid.
"""
if time_unit in ("us", "ms"):
# datetime.min/max fall within the range
return st.datetimes()
min_value = datetime.min
max_value = datetime.max
elif time_unit == "ns":
return st.datetimes(
min_value=EPOCH + timedelta(microseconds=I64_MIN // 1000 + 1),
max_value=EPOCH + timedelta(microseconds=I64_MAX // 1000),
)
min_value = EPOCH + timedelta(microseconds=I64_MIN // 1000 + 1)
max_value = EPOCH + timedelta(microseconds=I64_MAX // 1000)
else:
msg = f"invalid time unit: {time_unit}"
msg = f"invalid time unit: {time_unit!r}"
raise InvalidArgument(msg)

if time_zone is None:
return st.datetimes(min_value, max_value)

time_zone_info = string_to_zoneinfo(time_zone)

# Make sure time zone offsets do not cause out-of-bound datetimes
if time_unit == "ns":
min_value += timedelta(days=1)
max_value -= timedelta(days=1)

# Return naive datetimes expressed in UTC, drawn only from local times that
# actually exist in the given time zone (imaginary times are excluded)
return st.datetimes(
min_value=min_value,
max_value=max_value,
timezones=st.just(time_zone_info),
allow_imaginary=False,
).map(lambda dt: dt.astimezone(timezone.utc).replace(tzinfo=None))


def durations(time_unit: TimeUnit = "us") -> SearchStrategy[timedelta]:
"""
Expand Down Expand Up @@ -188,7 +209,7 @@ def durations(time_unit: TimeUnit = "us") -> SearchStrategy[timedelta]:
max_value=timedelta(microseconds=I64_MAX),
)
else:
msg = f"invalid time unit: {time_unit}"
msg = f"invalid time unit: {time_unit!r}"
raise InvalidArgument(msg)


Expand Down Expand Up @@ -365,8 +386,10 @@ def data(
elif dtype == Float64:
strategy = floats(64, allow_infinity=kwargs.pop("allow_infinity", True))
elif dtype == Datetime:
# TODO: Handle time zones
strategy = datetimes(time_unit=getattr(dtype, "time_unit", None) or "us")
strategy = datetimes(
time_unit=getattr(dtype, "time_unit", None) or "us",
time_zone=getattr(dtype, "time_zone", None),
)
elif dtype == Duration:
strategy = durations(time_unit=getattr(dtype, "time_unit", None) or "us")
elif dtype == Categorical:
Expand Down
Loading