Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): fix result dtype in date_range(..., eager=True) if duration contains "1s1d" #9670

Merged
merged 2 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub enum TemporalFunction {
DateRange {
every: Duration,
closed: ClosedWindow,
time_unit: Option<TimeUnit>,
tz: Option<TimeZone>,
},
TimeRange {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,12 +640,18 @@ impl From<TemporalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "timezones")]
TzLocalize(tz) => map!(datetime::tz_localize, &tz),
Combine(tu) => map_as_slice!(temporal::combine, tu),
DateRange { every, closed, tz } => {
DateRange {
every,
closed,
time_unit,
tz,
} => {
map_as_slice!(
temporal::temporal_range_dispatch,
"date",
every,
closed,
time_unit,
tz.clone()
)
}
Expand All @@ -655,6 +661,7 @@ impl From<TemporalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
"time",
every,
closed,
None,
None
)
}
Expand Down
59 changes: 48 additions & 11 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ impl FunctionExpr {
#[cfg(feature = "timezones")]
TzLocalize(tz) => return mapper.map_datetime_dtype_timezone(Some(tz)),
DateRange {
every: _,
every,
closed: _,
time_unit,
tz,
} => {
// output dtype may change based on `tz`
return mapper.map_to_date_range_dtype(tz);
// output dtype may change based on `every`, `tz`, and `time_unit`
return mapper.map_to_date_range_dtype(every, time_unit, tz);
}
TimeRange { .. } => {
return Ok(Field::new("time", DataType::List(Box::new(DataType::Time))));
Expand Down Expand Up @@ -321,23 +322,59 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

pub(super) fn map_to_date_range_dtype(&self, tz: &Option<String>) -> PolarsResult<Field> {
let inner_dtype = match (&self.map_to_supertype()?.dtype, tz) {
#[cfg(feature = "temporal")]
pub(super) fn map_to_date_range_dtype(
&self,
every: &Duration,
time_unit: &Option<TimeUnit>,
tz: &Option<String>,
) -> PolarsResult<Field> {
let inner_dtype = match (&self.map_to_supertype()?.dtype, time_unit, tz, every) {
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(field_tz)), Some(tz)) => {
(DataType::Datetime(tu, Some(field_tz)), time_unit, Some(tz), _) => {
if field_tz != tz {
polars_bail!(ComputeError: format!("Given time_zone is different from that of timezone aware datetimes. \
Given: '{}', got: '{}'.", tz, field_tz))
}
if let Some(time_unit) = time_unit {
DataType::Datetime(*time_unit, Some(tz.to_string()))
} else {
DataType::Datetime(*tu, Some(tz.to_string()))
}
}
#[cfg(feature = "timezones")]
(DataType::Datetime(_, Some(tz)), Some(time_unit), _, _) => {
DataType::Datetime(*time_unit, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), None, _, _) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), _) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(_, _), Some(time_unit), Some(tz), _) => {
DataType::Datetime(*time_unit, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Date, _) => DataType::Date,
(dtype, _) => {
(DataType::Datetime(tu, _), None, Some(tz), _) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
(DataType::Datetime(_, _), Some(time_unit), _, _) => {
DataType::Datetime(*time_unit, None)
}
(DataType::Datetime(tu, _), None, _, _) => DataType::Datetime(*tu, None),
(DataType::Date, time_unit, time_zone, every) => {
let nsecs = every.nanoseconds();
if nsecs == 0 {
DataType::Date
} else if let Some(tu) = time_unit {
DataType::Datetime(*tu, time_zone.clone())
} else if nsecs % 1000 != 0 {
DataType::Datetime(TimeUnit::Nanoseconds, time_zone.clone())
} else {
DataType::Datetime(TimeUnit::Microseconds, time_zone.clone())
}
}
(dtype, _, _, _) => {
polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype)
}
};
Expand Down
66 changes: 47 additions & 19 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ pub(super) fn temporal_range_dispatch(
name: &str,
every: Duration,
closed: ClosedWindow,
tz: Option<TimeZone>,
time_unit: Option<TimeUnit>,
time_zone: Option<TimeZone>,
) -> PolarsResult<Series> {
let start = &s[0];
let stop = &s[1];
Expand All @@ -93,47 +94,74 @@ pub(super) fn temporal_range_dispatch(
);
const TO_MS: i64 = SECONDS_IN_DAY * 1000;

let start_dtype = start.dtype();

// Note: `start` and `stop` have already been cast to their supertype,
// so only `start`'s dtype needs to be matched against.
let (mut start, mut stop) = match start_dtype {
#[allow(unused_mut)] // `dtype` is mutated within a "feature = timezones" block.
let mut dtype = match (start.dtype(), time_unit) {
(DataType::Date, time_unit) => {
let nsecs = every.nanoseconds();
if nsecs == 0 {
DataType::Date
} else if let Some(tu) = time_unit {
DataType::Datetime(tu, None)
} else if nsecs % 1_000 != 0 {
DataType::Datetime(TimeUnit::Nanoseconds, None)
} else {
DataType::Datetime(TimeUnit::Microseconds, None)
}
}
(DataType::Time, _) => DataType::Time,
// overwrite nothing, keep as-is
(DataType::Datetime(_, _), None) => start.dtype().clone(),
// overwrite time unit, keep timezone
(DataType::Datetime(_, tz), Some(tu)) => DataType::Datetime(tu, tz.clone()),
_ => unreachable!(),
};

let (mut start, mut stop) = match dtype {
#[cfg(feature = "timezones")]
DataType::Datetime(_, Some(_)) => (
start
.cast(&dtype)?
.datetime()
.unwrap()
.replace_time_zone(None, None)?
.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
stop.datetime()
stop.cast(&dtype)?
.datetime()
.unwrap()
.replace_time_zone(None, None)?
.into_series()
.to_physical_repr()
.cast(&DataType::Int64)?,
),
_ => (
start.to_physical_repr().cast(&DataType::Int64)?,
stop.to_physical_repr().cast(&DataType::Int64)?,
start
.cast(&dtype)?
.to_physical_repr()
.cast(&DataType::Int64)?,
stop.cast(&dtype)?
.to_physical_repr()
.cast(&DataType::Int64)?,
),
};

let dtype = match (start_dtype, tz) {
(DataType::Date, _) => {
start = &start * TO_MS;
stop = &stop * TO_MS;
DataType::Date
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz)),
if dtype == DataType::Date {
start = &start * TO_MS;
stop = &stop * TO_MS;
}

// overwrite time zone, if specified
match (&dtype, &time_zone) {
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), None) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Time, _) => DataType::Time,
_ => unimplemented!(),
(DataType::Datetime(tu, _), Some(tz)) => {
dtype = DataType::Datetime(*tu, Some(tz.clone()));
}
_ => {}
};

let start = start.i64().unwrap();
let stop = stop.i64().unwrap();

Expand Down
8 changes: 7 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/functions/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,19 @@ pub fn date_range(
end: Expr,
every: Duration,
closed: ClosedWindow,
time_unit: Option<TimeUnit>,
tz: Option<TimeZone>,
) -> Expr {
let input = vec![start, end];

Expr::Function {
input,
function: FunctionExpr::TemporalExpr(TemporalFunction::DateRange { every, closed, tz }),
function: FunctionExpr::TemporalExpr(TemporalFunction::DateRange {
every,
closed,
time_unit,
tz,
}),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
cast_to_supertypes: true,
Expand Down
81 changes: 19 additions & 62 deletions py-polars/polars/functions/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import contextlib
import warnings
from datetime import datetime, time, timedelta
from datetime import time, timedelta
from typing import TYPE_CHECKING, overload

import polars._reexport as pl
from polars import functions as F
from polars.datatypes import Date, Int64
from polars.datatypes import Int64
from polars.utils._parse_expr_input import parse_as_expression
from polars.utils._wrap import wrap_expr, wrap_s
from polars.utils.convert import (
_datetime_to_pl_timestamp,
_time_to_pl_time,
_timedelta_to_pl_duration,
)
Expand All @@ -22,7 +21,7 @@
import polars.polars as plr

if TYPE_CHECKING:
from datetime import date
from datetime import date, datetime
from typing import Literal

from polars import Expr, Series
Expand Down Expand Up @@ -390,9 +389,9 @@ def date_range(
closed : {'both', 'left', 'right', 'none'}
Define whether the temporal window interval is closed or not.
time_unit : {None, 'ns', 'us', 'ms'}
Set the time unit.
Set the time unit. Only takes effect if output is of ``Datetime`` type.
time_zone:
Optional timezone
Optional timezone. Only takes effect if output is of ``Datetime`` type.
eager
Evaluate immediately and return a ``Series``. If set to ``False`` (default),
return an expression instead.
Expand Down Expand Up @@ -528,72 +527,30 @@ def date_range(
elif " " in interval:
interval = interval.replace(" ", "")

if (
not eager
or isinstance(start, (str, pl.Expr))
or isinstance(end, (str, pl.Expr))
):
start = parse_as_expression(start)
end = parse_as_expression(end)
expr = wrap_expr(plr.date_range_lazy(start, end, interval, closed, time_zone))
if name is not None:
expr = expr.alias(name)
return expr

start, start_is_date = _ensure_datetime(start)
end, end_is_date = _ensure_datetime(end)

if start.tzinfo is not None or time_zone is not None:
if start.tzinfo != end.tzinfo:
raise ValueError(
"Cannot mix different timezone aware datetimes."
f" Got: '{start.tzinfo}' and '{end.tzinfo}'."
)

if time_zone is not None and start.tzinfo is not None:
if str(start.tzinfo) != time_zone:
raise ValueError(
"Given time_zone is different from that of timezone aware datetimes."
f" Given: '{time_zone}', got: '{start.tzinfo}'."
)
if time_zone is None and start.tzinfo is not None:
time_zone = str(start.tzinfo)

time_unit_: TimeUnit
time_unit_: TimeUnit | None
if time_unit is not None:
time_unit_ = time_unit
elif "ns" in interval:
time_unit_ = "ns"
else:
time_unit_ = "us"
time_unit_ = None

start_pl = _datetime_to_pl_timestamp(start, time_unit_)
end_pl = _datetime_to_pl_timestamp(end, time_unit_)
dt_range = wrap_s(
plr.date_range_eager(start_pl, end_pl, interval, closed, time_unit_, time_zone)
start_pl = parse_as_expression(start)
end_pl = parse_as_expression(end)
dt_range = wrap_expr(
plr.date_range_lazy(start_pl, end_pl, interval, closed, time_unit_, time_zone)
)
if (
start_is_date
and end_is_date
and not _interval_granularity(interval).endswith(("h", "m", "s"))
):
dt_range = dt_range.cast(Date)

if name is not None:
dt_range = dt_range.alias(name)
return dt_range


def _ensure_datetime(value: date | datetime) -> tuple[datetime, bool]:
is_date_type = False
if not isinstance(value, datetime):
value = datetime(value.year, value.month, value.day)
is_date_type = True
return value, is_date_type


def _interval_granularity(interval: str) -> str:
return interval[-2:].lstrip("0123456789")
if (
not eager
or isinstance(start_pl, (str, pl.Expr))
or isinstance(end_pl, (str, pl.Expr))
):
return dt_range
res = F.select(dt_range).to_series().explode().set_sorted()
return res


@overload
Expand Down
Loading