Skip to content

Commit

Permalink
feat!: Restrict casting for temporal data types (pola-rs#14142)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored and Wouittone committed Jun 22, 2024
1 parent e5aa894 commit f465e78
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 91 deletions.
12 changes: 8 additions & 4 deletions crates/polars-core/src/chunked_array/logical/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl LogicalType for DateChunked {
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match dtype {
Date => Ok(self.clone().into_series()),
#[cfg(feature = "dtype-datetime")]
Datetime(tu, tz) => {
let casted = self.0.cast(dtype)?;
Expand All @@ -43,11 +44,14 @@ impl LogicalType for DateChunked {
.into_datetime(*tu, tz.clone())
.into_series())
},
#[cfg(feature = "dtype-time")]
Time => {
polars_bail!(ComputeError: "cannot cast `Date` to `Time`");
dt if dt.is_numeric() => self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
_ => self.0.cast(dtype),
}
}
}
51 changes: 30 additions & 21 deletions crates/polars-core/src/chunked_array/logical/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ impl LogicalType for DatetimeChunked {

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match (self.dtype(), dtype) {
(Datetime(from_unit, _), Datetime(to_unit, tz)) => {
use TimeUnit::*;
let out = match dtype {
Datetime(to_unit, tz) => {
let from_unit = self.time_unit();
let (multiplier, divisor) = match (from_unit, to_unit) {
// scaling from lower precision to higher precision
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => (Some(1_000_000i64), None),
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => (Some(1_000i64), None),
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => (Some(1_000i64), None),
(Milliseconds, Nanoseconds) => (Some(1_000_000i64), None),
(Milliseconds, Microseconds) => (Some(1_000i64), None),
(Microseconds, Nanoseconds) => (Some(1_000i64), None),
// scaling from higher precision to lower precision
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => (None, Some(1_000_000i64)),
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => (None, Some(1_000i64)),
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => (None, Some(1_000i64)),
(Nanoseconds, Milliseconds) => (None, Some(1_000_000i64)),
(Nanoseconds, Microseconds) => (None, Some(1_000i64)),
(Microseconds, Milliseconds) => (None, Some(1_000i64)),
_ => return self.0.cast(dtype),
};
let result = match multiplier {
Expand All @@ -61,7 +63,7 @@ impl LogicalType for DatetimeChunked {
result
},
#[cfg(feature = "dtype-date")]
(Datetime(tu, _), Date) => {
Date => {
let cast_to_date = |tu_in_day: i64| {
let mut dt = self
.0
Expand All @@ -73,18 +75,18 @@ impl LogicalType for DatetimeChunked {
dt.set_sorted_flag(self.is_sorted_flag());
Ok(dt)
};
match tu {
TimeUnit::Nanoseconds => cast_to_date(NS_IN_DAY),
TimeUnit::Microseconds => cast_to_date(US_IN_DAY),
TimeUnit::Milliseconds => cast_to_date(MS_IN_DAY),
match self.time_unit() {
Nanoseconds => cast_to_date(NS_IN_DAY),
Microseconds => cast_to_date(US_IN_DAY),
Milliseconds => cast_to_date(MS_IN_DAY),
}
},
#[cfg(feature = "dtype-time")]
(Datetime(tu, _), Time) => {
let (scaled_mod, multiplier) = match tu {
TimeUnit::Nanoseconds => (NS_IN_DAY, 1i64),
TimeUnit::Microseconds => (US_IN_DAY, 1_000i64),
TimeUnit::Milliseconds => (MS_IN_DAY, 1_000_000i64),
Time => {
let (scaled_mod, multiplier) = match self.time_unit() {
Nanoseconds => (NS_IN_DAY, 1i64),
Microseconds => (US_IN_DAY, 1_000i64),
Milliseconds => (MS_IN_DAY, 1_000_000i64),
};
return Ok(self
.0
Expand All @@ -95,9 +97,16 @@ impl LogicalType for DatetimeChunked {
.into_time()
.into_series());
},
_ => return self.0.cast(dtype),
}
.map(|mut s| {
dt if dt.is_numeric() => return self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
};
out.map(|mut s| {
// TODO!; implement the divisions/multipliers above
// in a checked manner so that we raise on overflow
s.set_sorted_flag(self.is_sorted_flag());
Expand Down
57 changes: 27 additions & 30 deletions crates/polars-core/src/chunked_array/logical/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,35 @@ impl LogicalType for DurationChunked {

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match (self.dtype(), dtype) {
(Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Nanoseconds)) => {
Ok((self.0.as_ref() * 1_000_000i64)
.into_duration(TimeUnit::Nanoseconds)
.into_series())
use TimeUnit::*;
match dtype {
Duration(tu) => {
let to_unit = *tu;
let out = match (self.time_unit(), to_unit) {
(Milliseconds, Microseconds) => self.0.as_ref() * 1_000i64,
(Milliseconds, Nanoseconds) => self.0.as_ref() * 1_000_000i64,
(Microseconds, Milliseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)
},
(Microseconds, Nanoseconds) => self.0.as_ref() * 1_000i64,
(Nanoseconds, Milliseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64)
},
(Nanoseconds, Microseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)
},
_ => return Ok(self.clone().into_series()),
};
Ok(out.into_duration(to_unit).into_series())
},
(Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Microseconds)) => {
Ok((self.0.as_ref() * 1_000i64)
.into_duration(TimeUnit::Microseconds)
.into_series())
dt if dt.is_numeric() => self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
(Duration(TimeUnit::Microseconds), Duration(TimeUnit::Milliseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64))
.into_duration(TimeUnit::Milliseconds)
.into_series())
},
(Duration(TimeUnit::Microseconds), Duration(TimeUnit::Nanoseconds)) => {
Ok((self.0.as_ref() * 1_000i64)
.into_duration(TimeUnit::Nanoseconds)
.into_series())
},
(Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Milliseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64))
.into_duration(TimeUnit::Milliseconds)
.into_series())
},
(Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Microseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64))
.into_duration(TimeUnit::Microseconds)
.into_series())
},
_ => self.0.cast(dtype),
}
}
}
21 changes: 15 additions & 6 deletions crates/polars-core/src/chunked_array/logical/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ impl LogicalType for TimeChunked {
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match dtype {
Time => Ok(self.clone().into_series()),
#[cfg(feature = "dtype-duration")]
Duration(tu) => {
let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds));
if !matches!(tu, TimeUnit::Nanoseconds) {
Expand All @@ -39,15 +41,22 @@ impl LogicalType for TimeChunked {
out
}
},
#[cfg(feature = "dtype-date")]
Date => {
polars_bail!(ComputeError: "cannot cast `Time` to `Date`");
},
#[cfg(feature = "dtype-datetime")]
Datetime(_, _) => {
polars_bail!(ComputeError: "cannot cast `Time` to `Datetime`; consider using `dt.combine`");
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported; consider using `dt.combine`",
self.dtype(), dtype
)
},
dt if dt.is_numeric() => self.0.cast(dtype),
_ => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dtype
)
},
_ => self.0.cast(dtype),
}
}
}
15 changes: 10 additions & 5 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,22 @@ def test_int_to_python_datetime() -> None:
datetime(1970, 1, 1, 0, 0, 0, 200000),
),
]

assert df.select(pl.col(col).dt.timestamp() for col in ("c", "d", "e")).rows() == [
(100000000000, 100000000, 100000),
(200000000000, 200000000, 200000),
]

assert df.select(
[pl.col(col).dt.timestamp() for col in ("c", "d", "e")]
+ [
getattr(pl.col("b").cast(pl.Duration).dt, f"total_{unit}")().alias(
[
getattr(pl.col("a").cast(pl.Duration).dt, f"total_{unit}")().alias(
f"u[{unit}]"
)
for unit in ("milliseconds", "microseconds", "nanoseconds")
]
).rows() == [
(100000000000, 100000000, 100000, 100000, 100000000, 100000000000),
(200000000000, 200000000, 200000, 200000, 200000000, 200000000000),
(100000, 100000000, 100000000000),
(200000, 200000000, 200000000000),
]


Expand Down
32 changes: 28 additions & 4 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,15 +607,15 @@ def test_cast_categorical_name_retention(

def test_cast_date_to_time() -> None:
s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
msg = "cannot cast `Date` to `Time`"
with pytest.raises(pl.ComputeError, match=msg):
msg = "casting from Date to Time not supported"
with pytest.raises(pl.InvalidOperationError, match=msg):
s.cast(pl.Time)


def test_cast_time_to_date() -> None:
s = pl.Series([time(0, 0), time(20, 00)])
msg = "cannot cast `Time` to `Date`"
with pytest.raises(pl.ComputeError, match=msg):
msg = "casting from Time to Date not supported"
with pytest.raises(pl.InvalidOperationError, match=msg):
s.cast(pl.Date)


Expand Down Expand Up @@ -648,3 +648,27 @@ def test_cast_decimal_to_decimal_high_precision() -> None:

assert result.dtype == target_dtype
assert result.to_list() == values


def test_err_on_time_datetime_cast() -> None:
s = pl.Series([time(10, 0, 0), time(11, 30, 59)])
with pytest.raises(
pl.InvalidOperationError,
match="casting from Time to Datetime\\(Microseconds, None\\) not supported; consider using `dt.combine`",
):
s.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
s = pl.Series([datetime(2021, 1, 1)])
with pytest.raises(pl.ComputeError, match=r"unable to parse time zone: 'qwerty'"):
s.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
s = pl.Series([[-1, 1]])
with pytest.raises(
pl.InvalidOperationError,
match=r"cannot cast List inner type: 'Int64' to Categorical",
):
s.cast(pl.List(pl.Categorical))
21 changes: 0 additions & 21 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,27 +497,6 @@ def test_cast_err_column_value_highlighting(
test_df.with_columns(pl.all().cast(type))


def test_err_on_time_datetime_cast() -> None:
s = pl.Series([time(10, 0, 0), time(11, 30, 59)])
with pytest.raises(pl.ComputeError, match=r"cannot cast `Time` to `Datetime`"):
s.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
s = pl.Series([datetime(2021, 1, 1)])
with pytest.raises(pl.ComputeError, match=r"unable to parse time zone: 'qwerty'"):
s.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
s = pl.Series([[-1, 1]])
with pytest.raises(
pl.InvalidOperationError,
match=r"cannot cast List inner type: 'Int64' to Categorical",
):
s.cast(pl.List(pl.Categorical))


def test_lit_agg_err() -> None:
with pytest.raises(pl.ComputeError, match=r"cannot aggregate a literal"):
pl.DataFrame({"y": [1]}).with_columns(pl.lit(1).sum().over("y"))
Expand Down

0 comments on commit f465e78

Please sign in to comment.