Skip to content

Commit

Permalink
feat(rust, python): let cast_time_zone work on tz-naive and deprecate…
Browse files Browse the repository at this point in the history
… tz-localize (pola-rs#6649)

Co-authored-by: MarcoGorelli <>
  • Loading branch information
MarcoGorelli authored and vincent committed Feb 9, 2023
1 parent 8f75892 commit 92a4522
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 38 deletions.
19 changes: 16 additions & 3 deletions polars/polars-core/src/chunked_array/temporal/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,22 @@ impl DatetimeChunked {
let out = unsafe { ChunkedArray::from_chunks(self.name(), chunks) };
Ok(out.into_datetime(self.time_unit(), None))
}
(_, _) => Err(PolarsError::ComputeError(
"Cannot cast Naive Datetime. First set a timezone".into(),
)),
(None, Some(to)) => {
let chunks = self
.downcast_iter()
.map(|arr| {
Ok(cast_timezone(
arr,
self.time_unit().to_arrow(),
to.to_string(),
"UTC".to_string(),
)?)
})
.collect::<PolarsResult<_>>()?;
let out = unsafe { ChunkedArray::from_chunks(self.name(), chunks) };
Ok(out.into_datetime(self.time_unit(), Some(to.to_string())))
}
(None, None) => Ok(self.clone()),
}
}

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def sequence_to_pyseries(
if dtype == Datetime and value.tzinfo is not None:
py_series = PySeries.new_from_anyvalues(name, values)
tz = str(value.tzinfo)
return pli.wrap_s(py_series).dt.tz_localize(tz)._s
return pli.wrap_s(py_series).dt.cast_time_zone(tz)._s

# TODO: use anyvalues here (no need to require pyarrow for this).
arrow_dtype = dtype_to_arrow_type(dtype)
Expand Down
11 changes: 11 additions & 0 deletions py-polars/polars/internals/expr/datetime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import time, timedelta
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -1247,6 +1248,10 @@ def tz_localize(self, tz: str) -> pli.Expr:
This method takes a naive Datetime Series and makes this time zone aware.
It does not move the time to another time zone.
.. deprecated:: 0.16.3
`with_column` will be removed in favor of the more generic `with_columns`
in version 0.18.0.
Parameters
----------
tz
Expand Down Expand Up @@ -1281,6 +1286,12 @@ def tz_localize(self, tz: str) -> pli.Expr:
│ 2020-05-01 00:00:00 ┆ 2020-05-01 00:00:00 CEST │
└─────────────────────┴────────────────────────────────┘
"""
warnings.warn(
"`tz_localize` has been deprecated in favor of `cast_time_zone`."
" This method will be removed in version 0.18.0",
category=DeprecationWarning,
stacklevel=2,
)
return pli.wrap_expr(self._pyexpr.dt_tz_localize(tz))

def days(self) -> pli.Expr:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ def lit(
tu = "us"
e = lit(_datetime_to_pl_timestamp(value, tu)).cast(Datetime(tu))
if value.tzinfo is not None:
return e.dt.tz_localize(str(value.tzinfo))
return e.dt.cast_time_zone(str(value.tzinfo))
else:
return e

Expand Down
18 changes: 17 additions & 1 deletion py-polars/polars/internals/series/datetime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -1045,12 +1046,27 @@ def tz_localize(self, tz: str) -> pli.Series:
This method takes a naive Datetime Series and makes this time zone aware.
It does not move the time to another time zone.
.. deprecated:: 0.16.3
`with_column` will be removed in favor of the more generic `with_columns`
in version 0.18.0.
Parameters
----------
tz
Time zone for the `Datetime` Series.
"""
warnings.warn(
"`tz_localize` has been deprecated in favor of `cast_time_zone`."
" This method will be removed in version 0.18.0",
category=DeprecationWarning,
stacklevel=2,
)
return (
pli.wrap_s(self._s)
.to_frame()
.select(pli.col(self._s.name()).dt.tz_localize(tz))
.to_series()
)

def days(self) -> pli.Series:
"""
Expand Down
56 changes: 24 additions & 32 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,35 +2060,23 @@ def test_cast_timezone_invalid_timezone() -> None:


@pytest.mark.parametrize(
("from_tz", "to_tz", "tzinfo"),
("to_tz", "tzinfo"),
[
(
"America/Barbados",
"+01:00",
timezone(timedelta(seconds=3600)),
),
("+01:00", "America/Barbados", zoneinfo.ZoneInfo(key="America/Barbados")),
(
"America/Barbados",
"Europe/Helsinki",
zoneinfo.ZoneInfo(key="Europe/Helsinki"),
),
(
"+02:00",
"+01:00",
timezone(timedelta(seconds=3600)),
),
("+01:00", timezone(timedelta(seconds=3600))),
("America/Barbados", zoneinfo.ZoneInfo(key="America/Barbados")),
(None, None),
],
)
@pytest.mark.parametrize("from_tz", ["Asia/Seoul", "-01:00", None])
@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
def test_cast_timezone_fixed_offsets_and_area_location(
def test_cast_timezone_from_to(
from_tz: str,
to_tz: str,
tzinfo: timezone | zoneinfo.ZoneInfo,
tu: TimeUnit,
) -> None:
ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime(tu))
result = ts.dt.tz_localize(from_tz).dt.cast_time_zone(to_tz).item()
result = ts.dt.cast_time_zone(from_tz).dt.cast_time_zone(to_tz).item()
expected = datetime(2020, 1, 1, 0, 0, tzinfo=tzinfo)
assert result == expected

Expand Down Expand Up @@ -2209,7 +2197,7 @@ def test_logical_nested_take() -> None:
}


def test_tz_localize() -> None:
def test_cast_time_zone_from_naive() -> None:
df = pl.DataFrame(
{
"date": pl.Series(["2022-01-01", "2022-01-02"]).str.strptime(
Expand All @@ -2219,7 +2207,7 @@ def test_tz_localize() -> None:
)

assert df.select(
pl.col("date").cast(pl.Datetime).dt.tz_localize("America/New_York")
pl.col("date").cast(pl.Datetime).dt.cast_time_zone("America/New_York")
).to_dict(False) == {
"date": [
datetime(
Expand All @@ -2235,14 +2223,18 @@ def test_tz_localize() -> None:
@pytest.mark.parametrize("time_zone", ["UTC", "Africa/Abidjan"])
def test_tz_localize_from_utc(time_zone: str) -> None:
ts_utc = (
pl.Series(["2018-10-28"]).str.strptime(pl.Datetime).dt.tz_localize(time_zone)
pl.Series(["2018-10-28"]).str.strptime(pl.Datetime).dt.cast_time_zone(time_zone)
)
with pytest.raises(
ComputeError,
match=(
"^Cannot localize a tz-aware datetime. Consider using "
"'dt.with_time_zone' or 'dt.cast_time_zone'$"
),
err_msg = (
"^Cannot localize a tz-aware datetime. Consider using "
"'dt.with_time_zone' or 'dt.cast_time_zone'$"
)
deprecation_msg = (
"`tz_localize` has been deprecated in favor of `cast_time_zone`."
" This method will be removed in version 0.18.0"
)
with pytest.raises(ComputeError, match=err_msg), pytest.warns(
DeprecationWarning, match=deprecation_msg
):
ts_utc.dt.tz_localize("America/Maceio")

Expand All @@ -2259,7 +2251,7 @@ def test_tz_aware_truncate() -> None:
{
"dt": pl.date_range(
low=datetime(2022, 11, 1), high=datetime(2022, 11, 4), interval="12h"
).dt.tz_localize("America/New_York")
).dt.cast_time_zone("America/New_York")
}
)
assert test.with_columns(pl.col("dt").dt.truncate("1d").alias("trunced")).to_dict(
Expand Down Expand Up @@ -2323,7 +2315,7 @@ def test_tz_aware_truncate() -> None:
)
}
).lazy()
lf = lf.with_columns(pl.col("naive").dt.tz_localize("UTC").alias("UTC"))
lf = lf.with_columns(pl.col("naive").dt.cast_time_zone("UTC").alias("UTC"))
lf = lf.with_columns(pl.col("UTC").dt.with_time_zone("US/Central").alias("CST"))
lf = lf.with_columns(pl.col("CST").dt.truncate("1d").alias("CST truncated"))
assert lf.collect().to_dict(False) == {
Expand Down Expand Up @@ -2401,7 +2393,7 @@ def test_tz_aware_strftime() -> None:
{
"dt": pl.date_range(
low=datetime(2022, 11, 1), high=datetime(2022, 11, 4), interval="24h"
).dt.tz_localize("America/New_York")
).dt.cast_time_zone("America/New_York")
}
)
assert df.with_columns(pl.col("dt").dt.strftime("%c").alias("fmt")).to_dict(
Expand Down Expand Up @@ -2437,7 +2429,7 @@ def test_tz_aware_filter_lit() -> None:

assert (
pl.DataFrame({"date": pl.date_range(start, stop, "1h")})
.with_columns(pl.col("date").dt.tz_localize("America/New_York").alias("nyc"))
.with_columns(pl.col("date").dt.cast_time_zone("America/New_York").alias("nyc"))
.filter(pl.col("nyc") < dt)
).to_dict(False) == {
"date": [
Expand Down

0 comments on commit 92a4522

Please sign in to comment.