Skip to content

Commit

Permalink
fix(python): if given, respect dtype timeunit when instantiating `pl.…
Browse files Browse the repository at this point in the history
…lit` value (#6991)
  • Loading branch information
alexander-beedie authored Feb 18, 2023
1 parent 49f71a8 commit 5098966
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
5 changes: 3 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,16 +1127,17 @@ def lit(
"""
tu: TimeUnit

if isinstance(value, datetime):
tu = "us"
tu = "us" if dtype is None else getattr(dtype, "tu", "us")
e = lit(_datetime_to_pl_timestamp(value, tu)).cast(Datetime(tu))
if value.tzinfo is not None:
return e.dt.replace_time_zone(str(value.tzinfo))
else:
return e

elif isinstance(value, timedelta):
tu = "us"
tu = "us" if dtype is None else getattr(dtype, "tu", "us")
return lit(_timedelta_to_pl_timedelta(value, tu)).cast(Duration(tu))

elif isinstance(value, time):
Expand Down
41 changes: 40 additions & 1 deletion py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import random
import typing
from typing import cast
from datetime import datetime, timedelta
from typing import Any, cast

import numpy as np
import pytest
Expand All @@ -15,6 +16,7 @@
INTEGER_DTYPES,
NUMERIC_DTYPES,
TEMPORAL_DTYPES,
PolarsDataType,
)
from polars.testing import assert_frame_equal, assert_series_equal

Expand Down Expand Up @@ -520,3 +522,40 @@ def test_map_dict() -> None:
"country_code": ["FR", None, "ES", "DE"],
"remapped": ["France", "Not specified", "2", "Germany"],
}


def test_lit_dtypes() -> None:
def lit_series(value: Any, dtype: PolarsDataType) -> pl.Series:
return pl.select(pl.lit(value, dtype=dtype)).to_series()

d = datetime(2049, 10, 5, 1, 2, 3, 987654)
d_ms = datetime(2049, 10, 5, 1, 2, 3, 987000)

td = timedelta(days=942, hours=6, microseconds=123456)
td_ms = timedelta(days=942, seconds=21600, microseconds=123000)

df = pl.DataFrame(
{
"dtm_ms": lit_series(d, pl.Datetime("ms")),
"dtm_us": lit_series(d, pl.Datetime("us")),
"dtm_ns": lit_series(d, pl.Datetime("ns")),
"dur_ms": lit_series(td, pl.Duration("ms")),
"dur_us": lit_series(td, pl.Duration("us")),
"dur_ns": lit_series(td, pl.Duration("ns")),
"f32": lit_series(0, pl.Float32),
"u16": lit_series(0, pl.UInt16),
"i16": lit_series(0, pl.Int16),
}
)
assert df.dtypes == [
pl.Datetime("ms"),
pl.Datetime("us"),
pl.Datetime("ns"),
pl.Duration("ms"),
pl.Duration("us"),
pl.Duration("ns"),
pl.Float32,
pl.UInt16,
pl.Int16,
]
assert df.row(0) == (d_ms, d, d, td_ms, td, td, 0, 0, 0)

0 comments on commit 5098966

Please sign in to comment.