Skip to content

Commit

Permalink
fix(rust, python): respect time_zone in lazy date_range (pola-rs#8591)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored and c-peters committed Jul 14, 2023
1 parent e45958d commit 3cc9704
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 27 deletions.
33 changes: 30 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ impl FunctionExpr {
}
#[cfg(feature = "timezones")]
TzLocalize(tz) => return mapper.map_datetime_dtype_timezone(Some(tz)),
DateRange { .. } => {
let res = mapper.map_to_list_supertype()?;
return Ok(Field::new("date", res.dtype));
DateRange {
every: _,
closed: _,
tz,
} => {
// output dtype may change based on `tz`
return mapper.map_to_date_range_dtype(tz);
}
TimeRange { .. } => {
return Ok(Field::new("time", DataType::List(Box::new(DataType::Time))));
Expand Down Expand Up @@ -301,6 +305,29 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

/// Resolve the output field of a lazy `date_range` expression.
///
/// The inner dtype is derived from the supertype of the input fields (as computed by
/// `map_to_supertype`) combined with the requested `tz`:
///
/// - `Date` inputs stay `Date`, regardless of `tz`.
/// - `Datetime` inputs keep their time unit. With the `timezones` feature enabled the
///   output time zone is the one carried by the inputs, or `tz` when the inputs are
///   naive; without the feature the output is always time-zone naive.
/// - Any other dtype is rejected.
///
/// The returned field is always named `"date"` and has dtype `List(inner)`.
///
/// # Errors
///
/// Returns a `ComputeError` when the inputs are time-zone aware and `tz` names a
/// different time zone, or when the supertype is neither `Date` nor `Datetime`.
/// Errors from `map_to_supertype` are propagated unchanged.
pub(super) fn map_to_date_range_dtype(&self, tz: &Option<String>) -> PolarsResult<Field> {
    let supertype = self.map_to_supertype()?;
    let inner_dtype = match (&supertype.dtype, tz) {
        // Both the inputs and the caller specify a time zone: they must agree.
        #[cfg(feature = "timezones")]
        (DataType::Datetime(tu, Some(field_tz)), Some(tz)) => {
            if field_tz != tz {
                polars_bail!(
                    ComputeError:
                    "Given time_zone is different from that of timezone aware datetimes. \
                     Given: '{}', got: '{}'.", tz, field_tz
                )
            }
            DataType::Datetime(*tu, Some(tz.clone()))
        }
        // Only the inputs carry a time zone: keep it.
        #[cfg(feature = "timezones")]
        (DataType::Datetime(tu, Some(tz)), _) => DataType::Datetime(*tu, Some(tz.clone())),
        // Naive inputs with a requested time zone: the output adopts the requested one.
        #[cfg(feature = "timezones")]
        (DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz.clone())),
        // Fallback for every remaining Datetime combination: time-zone naive output.
        (DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
        (DataType::Date, _) => DataType::Date,
        (dtype, _) => {
            polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype)
        }
    };
    Ok(Field::new("date", DataType::List(Box::new(inner_dtype))))
}

/// Map the dtypes to the "supertype" of a list of lists.
pub(super) fn map_to_list_supertype(&self) -> PolarsResult<Field> {
self.try_map_dtypes(|dts| {
Expand Down
88 changes: 73 additions & 15 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub(super) fn temporal_range_dispatch(
name: &str,
every: Duration,
closed: ClosedWindow,
_tz: Option<TimeZone>, // todo: respect _tz: https://github.com/pola-rs/polars/issues/8512
tz: Option<TimeZone>,
) -> PolarsResult<Series> {
let start = &s[0];
let stop = &s[1];
Expand All @@ -93,21 +93,45 @@ pub(super) fn temporal_range_dispatch(
);
const TO_MS: i64 = SECONDS_IN_DAY * 1000;

let rng_start = start.to_physical_repr();
let rng_stop = stop.to_physical_repr();
let dtype = start.dtype();
let start_dtype = start.dtype();

let mut start = rng_start.cast(&DataType::Int64)?;
let mut stop = rng_stop.cast(&DataType::Int64)?;
// Note: `start` and `stop` have already been cast to their supertype,
// so only `start`'s dtype needs to be matched against.
let (mut start, mut stop) = match start_dtype {
#[cfg(feature = "timezones")]
DataType::Datetime(_, Some(_)) => (
start
.datetime()
.unwrap()
.replace_time_zone(None, None)?
.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
stop.datetime()
.unwrap()
.replace_time_zone(None, None)?
.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
),
_ => (
start.to_physical_repr().cast(&DataType::Int64)?,
stop.to_physical_repr().cast(&DataType::Int64)?,
),
};

let (tu, tz) = match dtype {
DataType::Date => {
let dtype = match (start_dtype, tz) {
(DataType::Date, _) => {
start = &start * TO_MS;
stop = &stop * TO_MS;
(TimeUnit::Milliseconds, None)
DataType::Date
}
DataType::Datetime(tu, tz) => (*tu, tz.as_ref()),
DataType::Time => (TimeUnit::Nanoseconds, None),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz)),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), None) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Time, _) => DataType::Time,
_ => unimplemented!(),
};
let start = start.i64().unwrap();
Expand All @@ -124,7 +148,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
Expand All @@ -135,7 +167,25 @@ pub(super) fn temporal_range_dispatch(
}
builder.finish().into_series()
}
DataType::Datetime(_, _) | DataType::Time => {
DataType::Datetime(tu, ref tz) => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
start.len() * 5,
DataType::Int64,
);
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
}
}
builder.finish().into_series()
}
DataType::Time => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
Expand All @@ -145,7 +195,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Nanoseconds,
None,
)?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
Expand All @@ -156,6 +214,6 @@ pub(super) fn temporal_range_dispatch(
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype.clone()));
let to_type = DataType::List(Box::new(dtype));
list.cast(&to_type)
}
74 changes: 65 additions & 9 deletions py-polars/tests/unit/functions/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

import polars as pl
from polars.datatypes import DTYPE_TEMPORAL_UNITS
from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning
from polars.testing import assert_frame_equal
from polars.utils.convert import get_zoneinfo as ZoneInfo

if TYPE_CHECKING:
from zoneinfo import ZoneInfo

from polars.type_aliases import TimeUnit
else:
from polars.utils.convert import get_zoneinfo as ZoneInfo


def test_arange() -> None:
Expand Down Expand Up @@ -188,6 +192,28 @@ def test_date_range_lazy_with_literals() -> None:
)


def test_date_range_lazy_time_zones_invalid() -> None:
    """A `time_zone` that conflicts with time-zone-aware bounds raises `ComputeError`.

    Both bounds are in Asia/Kathmandu while `time_zone="Pacific/Tarawa"` is requested.
    The block is also expected to emit a `TimeZoneAwareConstructorWarning`.
    """
    zone_name = "Asia/Kathmandu"
    start = datetime(2020, 1, 1, tzinfo=ZoneInfo(zone_name))
    stop = datetime(2020, 1, 2, tzinfo=ZoneInfo(zone_name))
    expected_message = (
        "Given time_zone is different from that of timezone aware datetimes. "
        "Given: 'Pacific/Tarawa', got: 'Asia/Kathmandu"
    )

    # Keep `raises` as the outer context and `warns` as the inner one.
    with pytest.raises(ComputeError, match=expected_message):
        with pytest.warns(TimeZoneAwareConstructorWarning, match="Series with UTC"):
            frame = pl.DataFrame({"start": [start], "stop": [stop]})
            conflicting_range = pl.date_range(
                start,
                stop,
                interval="678d",
                eager=False,
                time_zone="Pacific/Tarawa",
            )
            frame.with_columns(conflicting_range).lazy()


@pytest.mark.parametrize("low", ["start", pl.col("start")])
@pytest.mark.parametrize("high", ["stop", pl.col("stop")])
def test_date_range_lazy_with_expressions(
Expand Down Expand Up @@ -560,21 +586,51 @@ def test_deprecated_name_arg() -> None:
assert result_eager.name == name


def test_date_range_schema() -> None:
df = pl.DataFrame(
{"start": [datetime(2020, 1, 1)], "end": [datetime(2020, 1, 2)]}
).lazy()
@pytest.mark.parametrize(
("values_time_zone", "input_time_zone", "output_time_zone"),
[
("Asia/Kathmandu", "Asia/Kathmandu", "Asia/Kathmandu"),
("Asia/Kathmandu", None, "Asia/Kathmandu"),
(None, "Asia/Kathmandu", "Asia/Kathmandu"),
(None, None, None),
],
)
def test_date_range_schema(
values_time_zone: str | None,
input_time_zone: str | None,
output_time_zone: str | None,
) -> None:
df = (
pl.DataFrame({"start": [datetime(2020, 1, 1)], "end": [datetime(2020, 1, 2)]})
.with_columns(pl.col("*").dt.replace_time_zone(values_time_zone))
.lazy()
)
result = df.with_columns(
pl.date_range(pl.col("start"), pl.col("end")).alias("date_range")
pl.date_range(pl.col("start"), pl.col("end"), time_zone=input_time_zone).alias(
"date_range"
)
)
expected_schema = {
"start": pl.Datetime(time_unit="us", time_zone=None),
"end": pl.Datetime(time_unit="us", time_zone=None),
"date_range": pl.List(pl.Datetime(time_unit="us", time_zone=None)),
"start": pl.Datetime(time_unit="us", time_zone=values_time_zone),
"end": pl.Datetime(time_unit="us", time_zone=values_time_zone),
"date_range": pl.List(pl.Datetime(time_unit="us", time_zone=output_time_zone)),
}
assert result.schema == expected_schema
assert result.collect().schema == expected_schema

expected = pl.DataFrame(
{
"start": [datetime(2020, 1, 1)],
"end": [datetime(2020, 1, 2)],
"date_range": [[datetime(2020, 1, 1), datetime(2020, 1, 2)]],
}
).with_columns(
pl.col("start").dt.replace_time_zone(values_time_zone),
pl.col("end").dt.replace_time_zone(values_time_zone),
pl.col("date_range").explode().dt.replace_time_zone(output_time_zone).implode(),
)
assert_frame_equal(result.collect(), expected)


def test_date_range_no_alias_schema_9037() -> None:
df = pl.DataFrame(
Expand Down

0 comments on commit 3cc9704

Please sign in to comment.