Skip to content

Commit

Permalink
fix(rust, python): respect time_zone in lazy date_range
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 25, 2023
1 parent de59459 commit 1145bd7
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 12 deletions.
33 changes: 32 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,38 @@ impl FunctionExpr {
}
#[cfg(feature = "timezones")]
TzLocalize(tz) => return mapper.map_datetime_dtype_timezone(Some(tz)),
DateRange { .. } => return mapper.map_to_supertype(),
DateRange {
every: _,
closed: _,
tz,
} => {
let mut ret = mapper.map_to_supertype()?;
let ret_dtype = match (&ret.dtype, tz) {
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(field_tz)), Some(tz)) => {
if field_tz != tz {
polars_bail!(ComputeError: format!("Given time_zone is different from that of timezone aware datetimes. \
Given: '{}', got: '{}'.", tz, field_tz))
}
DataType::Datetime(*tu, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), _) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => {
DataType::Datetime(*tu, Some(tz.to_string()))
}
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Date, _) => DataType::Date,
(dtype, _) => {
polars_bail!(ComputeError: "expected Date or Datetime, got {}", dtype)
}
};
ret.coerce(ret_dtype);
return mapper.map_to_supertype();
}
TimeRange { .. } => DataType::Time,
Combine(tu) => match mapper.with_same_dtype().unwrap().dtype {
DataType::Datetime(_, tz) => DataType::Datetime(*tu, tz),
Expand Down
60 changes: 49 additions & 11 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub(super) fn temporal_range_dispatch(
name: &str,
every: Duration,
closed: ClosedWindow,
_tz: Option<TimeZone>, // todo: respect _tz: https://github.com/pola-rs/polars/issues/8512
tz: Option<TimeZone>,
) -> PolarsResult<Series> {
let start = &s[0];
let stop = &s[1];
Expand All @@ -95,19 +95,23 @@ pub(super) fn temporal_range_dispatch(

let rng_start = start.to_physical_repr();
let rng_stop = stop.to_physical_repr();
let dtype = start.dtype();
let start_dtype = start.dtype();

let mut start = rng_start.cast(&DataType::Int64)?;
let mut stop = rng_stop.cast(&DataType::Int64)?;

let (tu, tz) = match dtype {
DataType::Date => {
let dtype = match (start_dtype, tz) {
(DataType::Date, _) => {
start = &start * TO_MS;
stop = &stop * TO_MS;
(TimeUnit::Milliseconds, None)
DataType::Date
}
DataType::Datetime(tu, tz) => (*tu, tz.as_ref()),
DataType::Time => (TimeUnit::Nanoseconds, None),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, _), Some(tz)) => DataType::Datetime(*tu, Some(tz)),
#[cfg(feature = "timezones")]
(DataType::Datetime(tu, Some(tz)), None) => DataType::Datetime(*tu, Some(tz.to_string())),
(DataType::Datetime(tu, _), _) => DataType::Datetime(*tu, None),
(DataType::Time, _) => DataType::Time,
_ => unimplemented!(),
};
let start = start.i64().unwrap();
Expand All @@ -124,7 +128,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Milliseconds,
None,
)?;
let rng = rng.cast(&DataType::Date).unwrap();
let rng = rng.to_physical_repr();
let rng = rng.i32().unwrap();
Expand All @@ -135,7 +147,25 @@ pub(super) fn temporal_range_dispatch(
}
builder.finish().into_series()
}
DataType::Datetime(_, _) | DataType::Time => {
DataType::Datetime(tu, ref tz) => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
start.len() * 5,
DataType::Int64,
);
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz.as_ref())?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
}
}
builder.finish().into_series()
}
DataType::Time => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
name,
start.len(),
Expand All @@ -145,7 +175,15 @@ pub(super) fn temporal_range_dispatch(
for (start, stop) in start.into_iter().zip(stop.into_iter()) {
match (start, stop) {
(Some(start), Some(stop)) => {
let rng = date_range_impl("", start, stop, every, closed, tu, tz)?;
let rng = date_range_impl(
"",
start,
stop,
every,
closed,
TimeUnit::Nanoseconds,
None,
)?;
builder.append_slice(rng.cont_slice().unwrap())
}
_ => builder.append_null(),
Expand All @@ -156,6 +194,6 @@ pub(super) fn temporal_range_dispatch(
_ => unimplemented!(),
};

let to_type = DataType::List(Box::new(dtype.clone()));
let to_type = DataType::List(Box::new(dtype));
list.cast(&to_type)
}

0 comments on commit 1145bd7

Please sign in to comment.