Skip to content

Commit

Permalink
perf(python): 2-3x speedup in creating literals/Series of type Date (
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Feb 27, 2024
1 parent 925f61a commit 431dfe9
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 49 deletions.
50 changes: 24 additions & 26 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from polars.dependencies import numpy as np
from polars.utils._wrap import wrap_expr
from polars.utils.convert import (
_date_to_pl_date,
_datetime_to_pl_timestamp,
_time_to_pl_time,
_timedelta_to_pl_timedelta,
Expand All @@ -35,15 +36,16 @@ def lit(
value
Value that should be used as a `literal`.
dtype
Optionally define a dtype.
The data type of the resulting expression.
If set to `None` (default), the data type is inferred from the `value` input.
allow_object
If type is unknown use an 'object' type.
By default, we will raise a `ValueException`
if the type is unknown.
Notes
-----
Expected datatypes
Expected datatypes:
- `pl.lit([])` -> empty Series Float32
- `pl.lit([1, 2, 3])` -> Series Int64
Expand Down Expand Up @@ -80,43 +82,39 @@ def lit(
else:
time_unit = "us"

time_zone = (
value.tzinfo
if getattr(dtype, "time_zone", None) is None
else getattr(dtype, "time_zone", None)
)
if (
value.tzinfo is not None
and getattr(dtype, "time_zone", None) is not None
and dtype.time_zone != str(value.tzinfo) # type: ignore[union-attr]
):
msg = f"time zone of dtype ({dtype.time_zone!r}) differs from time zone of value ({value.tzinfo!r})" # type: ignore[union-attr]
raise TypeError(msg)
e = lit(
_datetime_to_pl_timestamp(value.replace(tzinfo=timezone.utc), time_unit)
).cast(Datetime(time_unit))
time_zone: str | None = getattr(dtype, "time_zone", None)
if (tzinfo := value.tzinfo) is not None:
tzinfo_str = str(tzinfo)
if time_zone is not None and time_zone != tzinfo_str:
msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})"
raise TypeError(msg)
time_zone = tzinfo_str

dt_utc = value.replace(tzinfo=timezone.utc)
dt_int = _datetime_to_pl_timestamp(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if time_zone is not None:
return e.dt.replace_time_zone(
str(time_zone), ambiguous="earliest" if value.fold == 0 else "latest"
expr = expr.dt.replace_time_zone(
time_zone, ambiguous="earliest" if value.fold == 0 else "latest"
)
else:
return e
return expr

elif isinstance(value, timedelta):
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
time_unit = tu # type: ignore[assignment]
else:
time_unit = "us"

return lit(_timedelta_to_pl_timedelta(value, time_unit)).cast(
Duration(time_unit)
)
td_int = _timedelta_to_pl_timedelta(value, time_unit)
return lit(td_int).cast(Duration(time_unit))

elif isinstance(value, time):
return lit(_time_to_pl_time(value)).cast(Time)
time_int = _time_to_pl_time(value)
return lit(time_int).cast(Time)

elif isinstance(value, date):
return lit(datetime(value.year, value.month, value.day)).cast(Date)
date_int = _date_to_pl_date(value)
return lit(date_int).cast(Date)

elif isinstance(value, pl.Series):
value = value._s
Expand Down
10 changes: 5 additions & 5 deletions py-polars/polars/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from datetime import datetime, time, timedelta, timezone
from datetime import date, datetime, time, timedelta, timezone
from decimal import Context
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, overload
Expand All @@ -10,7 +10,7 @@

if TYPE_CHECKING:
from collections.abc import Reversible
from datetime import date, tzinfo
from datetime import tzinfo
from decimal import Decimal

from polars.type_aliases import TimeUnit
Expand Down Expand Up @@ -51,6 +51,7 @@ def get_zoneinfo(key: str) -> ZoneInfo: # noqa: D103
US_PER_SECOND = 1_000_000
MS_PER_SECOND = 1_000

EPOCH_DATE = date(1970, 1, 1)
EPOCH = datetime(1970, 1, 1).replace(tzinfo=None)
EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc)

Expand Down Expand Up @@ -108,14 +109,13 @@ def _time_to_pl_time(t: time) -> int:


def _date_to_pl_date(d: date) -> int:
dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc)
return int(dt.timestamp()) // SECONDS_PER_DAY
return (d - EPOCH_DATE).days


def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int:
"""Convert a python datetime to a timestamp in given time unit."""
if dt.tzinfo is None:
# Make sure to use UTC rather than system time zone.
# Make sure to use UTC rather than system time zone
dt = dt.replace(tzinfo=timezone.utc)
microseconds = dt.microsecond
seconds = _timestamp_in_seconds(dt)
Expand Down
57 changes: 39 additions & 18 deletions py-polars/tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,18 @@
@pytest.mark.parametrize(
("dt", "time_unit", "expected"),
[
(datetime(2121, 1, 1), "ns", 4765132800000000000),
(datetime(2121, 1, 1), "us", 4765132800000000),
(datetime(2121, 1, 1), "ms", 4765132800000),
(datetime(2121, 1, 1), "ns", 4_765_132_800_000_000_000),
(datetime(2121, 1, 1), "us", 4_765_132_800_000_000),
(datetime(2121, 1, 1), "ms", 4_765_132_800_000),
(datetime(2121, 1, 1), None, 4_765_132_800_000_000),
(datetime.min, "ns", -62_135_596_800_000_000_000),
(datetime.max, "ns", 253_402_300_799_999_999_000),
(datetime.min, "ms", -62_135_596_800_000),
(datetime.max, "ms", 253_402_300_799_999),
],
)
def test_datetime_to_pl_timestamp(
dt: datetime, time_unit: TimeUnit, expected: int
dt: datetime, time_unit: TimeUnit | None, expected: int
) -> None:
out = _datetime_to_pl_timestamp(dt, time_unit)
assert out == expected
Expand All @@ -47,31 +52,47 @@ def test_datetime_to_pl_timestamp(
@pytest.mark.parametrize(
("t", "expected"),
[
(time(0, 0, 0), 0),
(time(0, 0, 1), 1_000_000_000),
(time(20, 52, 10), 75_130_000_000_000),
(time(20, 52, 10, 200), 75_130_000_200_000),
(time.min, 0),
(time.max, 86_399_999_999_000),
],
)
def test_time_to_pl_time(t: time, expected: int) -> None:
assert _time_to_pl_time(t) == expected


def test_date_to_pl_date() -> None:
d = date(1999, 9, 9)
out = _date_to_pl_date(d)
assert out == 10843
@pytest.mark.parametrize(
("d", "expected"),
[
(date(1999, 9, 9), 10_843),
(date(1969, 12, 31), -1),
(date.min, -719_162),
(date.max, 2_932_896),
],
)
def test_date_to_pl_date(d: date, expected: int) -> None:
assert _date_to_pl_date(d) == expected


def test_timedelta_to_pl_timedelta() -> None:
out = _timedelta_to_pl_timedelta(timedelta(days=1), "ns")
assert out == 86_400_000_000_000
out = _timedelta_to_pl_timedelta(timedelta(days=1), "us")
assert out == 86_400_000_000
out = _timedelta_to_pl_timedelta(timedelta(days=1), "ms")
assert out == 86_400_000
out = _timedelta_to_pl_timedelta(timedelta(days=1), time_unit=None)
assert out == 86_400_000_000
@pytest.mark.parametrize(
("td", "time_unit", "expected"),
[
(timedelta(days=1), "ns", 86_400_000_000_000),
(timedelta(days=1), "us", 86_400_000_000),
(timedelta(days=1), "ms", 86_400_000),
(timedelta(days=1), None, 86_400_000_000),
(timedelta.min, "ns", -86_399_999_913_600_000_000_000),
(timedelta.max, "ns", 86_399_999_999_999_999_999_000),
(timedelta.min, "ms", -86_399_999_913_600_000),
(timedelta.max, "ms", 86_399_999_999_999_999),
],
)
def test_timedelta_to_pl_timedelta(
td: timedelta, time_unit: TimeUnit | None, expected: int
) -> None:
assert _timedelta_to_pl_timedelta(td, time_unit) == expected


@pytest.mark.parametrize(
Expand Down

0 comments on commit 431dfe9

Please sign in to comment.