From 431dfe9c6dc7b86256670350b6cd96842642a18a Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 27 Feb 2024 12:46:30 +0100 Subject: [PATCH] perf(python): 2-3x speedup in creating literals/Series of type `Date` (#14716) --- py-polars/polars/functions/lit.py | 50 ++++++++++----------- py-polars/polars/utils/convert.py | 10 ++--- py-polars/tests/unit/utils/test_utils.py | 57 ++++++++++++++++-------- 3 files changed, 68 insertions(+), 49 deletions(-) diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index bb29f2c9b069..fc03ee0fa005 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -10,6 +10,7 @@ from polars.dependencies import numpy as np from polars.utils._wrap import wrap_expr from polars.utils.convert import ( + _date_to_pl_date, _datetime_to_pl_timestamp, _time_to_pl_time, _timedelta_to_pl_timedelta, @@ -35,7 +36,8 @@ def lit( value Value that should be used as a `literal`. dtype - Optionally define a dtype. + The data type of the resulting expression. + If set to `None` (default), the data type is inferred from the `value` input. allow_object If type is unknown use an 'object' type. By default, we will raise a `ValueException` @@ -43,7 +45,7 @@ def lit( Notes ----- - Expected datatypes + Expected datatypes: - `pl.lit([])` -> empty Series Float32 - `pl.lit([1, 2, 3])` -> Series Int64 @@ -80,27 +82,22 @@ def lit( else: time_unit = "us" - time_zone = ( - value.tzinfo - if getattr(dtype, "time_zone", None) is None - else getattr(dtype, "time_zone", None) - ) - if ( - value.tzinfo is not None - and getattr(dtype, "time_zone", None) is not None - and dtype.time_zone != str(value.tzinfo) # type: ignore[union-attr] - ): - msg = f"time zone of dtype ({dtype.time_zone!r}) differs from time zone of value ({value.tzinfo!r})" # type: ignore[union-attr] - raise TypeError(msg) - e = lit( - _datetime_to_pl_timestamp(value.replace(tzinfo=timezone.utc), time_unit) - ).cast(Datetime(time_unit)) + time_zone: str | None = getattr(dtype, "time_zone", None) + if (tzinfo := value.tzinfo) is not None: + tzinfo_str = str(tzinfo) + if time_zone is not None and time_zone != tzinfo_str: + msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})" + raise TypeError(msg) + time_zone = tzinfo_str + + dt_utc = value.replace(tzinfo=timezone.utc) + dt_int = _datetime_to_pl_timestamp(dt_utc, time_unit) + expr = lit(dt_int).cast(Datetime(time_unit)) if time_zone is not None: - return e.dt.replace_time_zone( - str(time_zone), ambiguous="earliest" if value.fold == 0 else "latest" + expr = expr.dt.replace_time_zone( + time_zone, ambiguous="earliest" if value.fold == 0 else "latest" ) - else: - return e + return expr elif isinstance(value, timedelta): if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: @@ -108,15 +105,16 @@ def lit( else: time_unit = "us" - return lit(_timedelta_to_pl_timedelta(value, time_unit)).cast( - Duration(time_unit) - ) + td_int = _timedelta_to_pl_timedelta(value, time_unit) + return lit(td_int).cast(Duration(time_unit)) elif isinstance(value, time): - return lit(_time_to_pl_time(value)).cast(Time) + time_int = _time_to_pl_time(value) + return lit(time_int).cast(Time) elif isinstance(value, date): - return lit(datetime(value.year, value.month, value.day)).cast(Date) + date_int = _date_to_pl_date(value) + return lit(date_int).cast(Date) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py index 2e963cf04a34..7008f82235ff 100644 --- a/py-polars/polars/utils/convert.py +++ b/py-polars/polars/utils/convert.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from datetime import datetime, time, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from decimal import Context from functools import lru_cache from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, overload @@ -10,7 +10,7 @@ if TYPE_CHECKING: from collections.abc import Reversible - from datetime import date, tzinfo + from datetime import tzinfo from decimal import Decimal from polars.type_aliases import TimeUnit @@ -51,6 +51,7 @@ def get_zoneinfo(key: str) -> ZoneInfo: # noqa: D103 US_PER_SECOND = 1_000_000 MS_PER_SECOND = 1_000 +EPOCH_DATE = date(1970, 1, 1) EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) @@ -108,14 +109,13 @@ def _time_to_pl_time(t: time) -> int: def _date_to_pl_date(d: date) -> int: - dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc) - return int(dt.timestamp()) // SECONDS_PER_DAY + return (d - EPOCH_DATE).days def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: """Convert a python datetime to a timestamp in given time unit.""" if dt.tzinfo is None: - # Make sure to use UTC rather than system time zone. + # Make sure to use UTC rather than system time zone dt = dt.replace(tzinfo=timezone.utc) microseconds = dt.microsecond seconds = _timestamp_in_seconds(dt) diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py index fc84cc3d59f8..1b5e79a40586 100644 --- a/py-polars/tests/unit/utils/test_utils.py +++ b/py-polars/tests/unit/utils/test_utils.py @@ -32,13 +32,18 @@ @pytest.mark.parametrize( ("dt", "time_unit", "expected"), [ - (datetime(2121, 1, 1), "ns", 4765132800000000000), - (datetime(2121, 1, 1), "us", 4765132800000000), - (datetime(2121, 1, 1), "ms", 4765132800000), + (datetime(2121, 1, 1), "ns", 4_765_132_800_000_000_000), + (datetime(2121, 1, 1), "us", 4_765_132_800_000_000), + (datetime(2121, 1, 1), "ms", 4_765_132_800_000), + (datetime(2121, 1, 1), None, 4_765_132_800_000_000), + (datetime.min, "ns", -62_135_596_800_000_000_000), + (datetime.max, "ns", 253_402_300_799_999_999_000), + (datetime.min, "ms", -62_135_596_800_000), + (datetime.max, "ms", 253_402_300_799_999), ], ) def test_datetime_to_pl_timestamp( - dt: datetime, time_unit: TimeUnit, expected: int + dt: datetime, time_unit: TimeUnit | None, expected: int ) -> None: out = _datetime_to_pl_timestamp(dt, time_unit) assert out == expected @@ -47,31 +52,47 @@ def test_datetime_to_pl_timestamp( @pytest.mark.parametrize( ("t", "expected"), [ - (time(0, 0, 0), 0), (time(0, 0, 1), 1_000_000_000), (time(20, 52, 10), 75_130_000_000_000), (time(20, 52, 10, 200), 75_130_000_200_000), + (time.min, 0), + (time.max, 86_399_999_999_000), ], ) def test_time_to_pl_time(t: time, expected: int) -> None: assert _time_to_pl_time(t) == expected -def test_date_to_pl_date() -> None: - d = date(1999, 9, 9) - out = _date_to_pl_date(d) - assert out == 10843 +@pytest.mark.parametrize( + ("d", "expected"), + [ + (date(1999, 9, 9), 10_843), + (date(1969, 12, 31), -1), + (date.min, -719_162), + (date.max, 2_932_896), + ], +) +def test_date_to_pl_date(d: date, expected: int) -> None: + assert _date_to_pl_date(d) == expected -def test_timedelta_to_pl_timedelta() -> None: - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ns") - assert out == 86_400_000_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "us") - assert out == 86_400_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ms") - assert out == 86_400_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), time_unit=None) - assert out == 86_400_000_000 +@pytest.mark.parametrize( + ("td", "time_unit", "expected"), + [ + (timedelta(days=1), "ns", 86_400_000_000_000), + (timedelta(days=1), "us", 86_400_000_000), + (timedelta(days=1), "ms", 86_400_000), + (timedelta(days=1), None, 86_400_000_000), + (timedelta.min, "ns", -86_399_999_913_600_000_000_000), + (timedelta.max, "ns", 86_399_999_999_999_999_999_000), + (timedelta.min, "ms", -86_399_999_913_600_000), + (timedelta.max, "ms", 86_399_999_999_999_999), + ], +) +def test_timedelta_to_pl_timedelta( + td: timedelta, time_unit: TimeUnit | None, expected: int +) -> None: + assert _timedelta_to_pl_timedelta(td, time_unit) == expected @pytest.mark.parametrize(