Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): respect time_zone in lazy date_range #8591

Merged
merged 17 commits into from
Jul 2, 2023
Merged
33 changes: 32 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,38 @@ impl FunctionExpr {
}
#[cfg(feature = "timezones")]
TzLocalize(tz) => return mapper.map_datetime_dtype_timezone(Some(tz)),
DateRange { .. } => return mapper.map_to_supertype(),
DateRange {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output dtype may change depending on tz, so some extra computation is needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, can you add that as a comment?

every: _,
closed: _,
tz,
} => {
let mut ret = mapper.map_to_supertype()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we hoist this into a function? That keeps the schema branch a bit leaner.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup! by doing that, I realised that I hadn't quite done it correctly, and that it was wrong for (lazy) time_range too #9036

+1 for smaller functions and leaner logic, thanks!

let ret_dtype = match (&ret.dtype, tz) {
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(field_tz)), Some(tz)) => {
if field_tz != tz {
polars_bail!(ComputeError: format!("Given time_zone is different from that of timezone aware datetimes. \
Given: '{}', got: '{}'.", tz, field_tz))
}
DataType::Datetime(*tu, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), _) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Date, _) => DataType::Date,
(dtype, _) => {
polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype)
}
};
ret.coerce(ret_dtype);
return mapper.map_to_supertype();
}
TimeRange { .. } => DataType::Time,
Combine(tu) => match mapper.with_same_dtype().unwrap().dtype {
DataType::Datetime(_, tz) => DataType::Datetime(*tu, tz),
Expand Down
88 changes: 73 additions & 15 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub(super) fn temporal_range_dispatch(
name: &str,
every: Duration,
closed: ClosedWindow,
_tz: Option<TimeZone>, // todo: respect _tz: https://github.com/pola-rs/polars/issues/8512
tz: Option<TimeZone>,
) -> PolarsResult<Series> {
let start = &s[0];
let stop = &s[1];
Expand All @@ -93,21 +93,45 @@ pub(super) fn temporal_range_dispatch(
);
const TO_MS: i64 = SECONDS_IN_DAY * 1000;

let rng_start = start.to_physical_repr();
let rng_stop = stop.to_physical_repr();
let dtype = start.dtype();
// Note: `start` and `stop` have already been cast to their supertype,
// so only `start`'s dtype needs to be checked.
let start_dtype = start.dtype();

let mut start = rng_start.cast(&DataType::Int64)?;
let mut stop = rng_stop.cast(&DataType::Int64)?;
let (mut start, mut stop) = match start_dtype {
#[cfg(feature = "timezones")]
DataType::Datetime(_, Some(_)) => (
start
.datetime()
.unwrap()
.replace_time_zone(None, None)?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If start and stop have a time zone, I'm removing it, as date_range_impl deals with that.

.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
stop.datetime()
.unwrap()
.replace_time_zone(None, None)?
.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
),
_ => (
start.to_physical_repr().cast(&DataType::Int64)?,
stop.to_physical_repr().cast(&DataType::Int64)?,
),
};

let (tu, tz) = match dtype {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output dtype isn't necessarily start.dtype() (if tz is not None), so instead of just recording tu and tz, I'm assigning to dtype, which contains both tu and tz.

Then,

  • for Date, I just use Milliseconds as tu and None for tz
  • for Time, same, but with Nanoseconds as tu
  • for Datetime, use its tu and tz

DataType::Date => {
let dtype = match (start_dtype, tz) {
(DataType::Date, _) => {
start = &start * TO_MS;
stop = &stop * TO_MS;
(TimeUnit::Milliseconds, None)
DataType::Date
}
DataType::Datetime(tu, tz) => (*tu, tz.as_ref()),
DataType::Time => (TimeUnit::Nanoseconds, None),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz)),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), None) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Time, _) => DataType::Time,
_ => unimplemented!(),
};
let start = start.i64().unwrap();
Expand All @@ -124,7 +148,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
Expand All @@ -135,7 +167,25 @@ pub(super) fn temporal_range_dispatch(
}
builder.finish().into_series()
}
DataType::Datetime(_, _) | DataType::Time => {
DataType::Datetime(tu, ref tz) => {
Comment on lines -138 to +170
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separating the Datetime and Time arms to make use of tu and tz from Datetime

let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
start.len() * 5,
DataType::Int64,
);
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
}
}
builder.finish().into_series()
}
DataType::Time => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
Expand All @@ -145,7 +195,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Nanoseconds,
None,
)?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
Expand All @@ -156,6 +214,6 @@ pub(super) fn temporal_range_dispatch(
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype.clone()));
let to_type = DataType::List(Box::new(dtype));
list.cast(&to_type)
}
75 changes: 75 additions & 0 deletions py-polars/tests/unit/functions/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import polars as pl
from polars.datatypes import DTYPE_TEMPORAL_UNITS
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal
from polars.utils.convert import get_zoneinfo as ZoneInfo

Expand Down Expand Up @@ -188,6 +189,80 @@ def test_date_range_lazy_with_literals() -> None:
)


@pytest.mark.parametrize(
    ("start_tzinfo", "time_zone", "output_tzinfo"),
    [
        (ZoneInfo("Asia/Kathmandu"), "Asia/Kathmandu", ZoneInfo("Asia/Kathmandu")),
        (ZoneInfo("Asia/Kathmandu"), None, ZoneInfo("Asia/Kathmandu")),
        (None, "Asia/Kathmandu", ZoneInfo("Asia/Kathmandu")),
        (None, None, None),
    ],
)
def test_date_range_lazy_time_zones(
    start_tzinfo: ZoneInfo | None, time_zone: str | None, output_tzinfo: ZoneInfo | None
) -> None:
    """Check the time zone of an expression-based (``eager=False``) date_range.

    The ``start``/``stop`` columns are localized to ``start_tzinfo`` (or left
    naive when it is ``None``), ``time_zone`` is passed to ``pl.date_range``,
    and both the resulting schema and the collected rows are compared against
    ``output_tzinfo``.
    """
    # Time zone names (or None) matching the tzinfo objects from the parametrization.
    input_tz_name = None if start_tzinfo is None else start_tzinfo.key
    result_tz_name = None if output_tzinfo is None else output_tzinfo.key

    naive_frame = pl.DataFrame(
        {
            "start": [datetime(2015, 6, 30)],
            "stop": [datetime(2022, 12, 31)],
        }
    )
    localized_frame = naive_frame.with_columns(
        pl.col("start").dt.replace_time_zone(input_tz_name),
        pl.col("stop").dt.replace_time_zone(input_tz_name),
    )
    range_expr = pl.date_range(
        "start", "stop", interval="678d", eager=False, time_zone=time_zone
    ).alias("dts")
    ldf = localized_frame.with_columns(range_expr).lazy()

    expected_schema = {
        "start": pl.Datetime(time_unit="us", time_zone=input_tz_name),
        "stop": pl.Datetime(time_unit="us", time_zone=input_tz_name),
        "dts": pl.List(pl.Datetime(time_unit="us", time_zone=result_tz_name)),
    }
    assert ldf.schema == expected_schema

    # Range points between start and stop at the 678-day interval.
    expected_points = [
        datetime(year, month, day, tzinfo=output_tzinfo)
        for year, month, day in (
            (2015, 6, 30),
            (2017, 5, 8),
            (2019, 3, 17),
            (2021, 1, 23),
            (2022, 12, 2),
        )
    ]
    expected_row = (
        datetime(2015, 6, 30, tzinfo=start_tzinfo),
        datetime(2022, 12, 31, tzinfo=start_tzinfo),
        expected_points,
    )
    assert ldf.collect().rows() == [expected_row]


def test_date_range_lazy_time_zones_invalid() -> None:
    """Expect a ComputeError when ``time_zone`` differs from the inputs' time zone.

    ``start`` and ``stop`` are time-zone-aware datetimes in Asia/Kathmandu,
    while ``pl.date_range`` is asked for Pacific/Tarawa.
    """
    kathmandu = ZoneInfo("Asia/Kathmandu")
    start = datetime(2020, 1, 1, tzinfo=kathmandu)
    stop = datetime(2020, 1, 2, tzinfo=kathmandu)

    # Passed to pytest.raises(match=...), which treats it as a regex search pattern.
    expected_message = "Given time_zone is different from that of timezone aware datetimes. Given: 'Pacific/Tarawa', got: 'Asia/Kathmandu"

    with pytest.raises(ComputeError, match=expected_message):
        frame = pl.DataFrame({"start": [start], "stop": [stop]})
        mismatched_range = pl.date_range(
            start,
            stop,
            interval="678d",
            eager=False,
            time_zone="Pacific/Tarawa",
        )
        frame.with_columns(mismatched_range).lazy()


@pytest.mark.parametrize("low", ["start", pl.col("start")])
@pytest.mark.parametrize("high", ["stop", pl.col("stop")])
def test_date_range_lazy_with_expressions(
Expand Down