From f465e78bf3d76993947fb40de13629dbcca26f94 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 4 Jun 2024 10:35:54 +0200 Subject: [PATCH] feat!: Restrict casting for temporal data types (#14142) --- .../src/chunked_array/logical/date.rs | 12 ++-- .../src/chunked_array/logical/datetime.rs | 51 ++++++++++------- .../src/chunked_array/logical/duration.rs | 57 +++++++++---------- .../src/chunked_array/logical/time.rs | 21 +++++-- .../tests/unit/datatypes/test_temporal.py | 15 +++-- py-polars/tests/unit/operations/test_cast.py | 32 +++++++++-- py-polars/tests/unit/test_errors.py | 21 ------- 7 files changed, 118 insertions(+), 91 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/date.rs b/crates/polars-core/src/chunked_array/logical/date.rs index 55df4f3d428cb..0b8a4b11b888f 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -30,6 +30,7 @@ impl LogicalType for DateChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; match dtype { + Date => Ok(self.clone().into_series()), #[cfg(feature = "dtype-datetime")] Datetime(tu, tz) => { let casted = self.0.cast(dtype)?; @@ -43,11 +44,14 @@ impl LogicalType for DateChunked { .into_datetime(*tu, tz.clone()) .into_series()) }, - #[cfg(feature = "dtype-time")] - Time => { - polars_bail!(ComputeError: "cannot cast `Date` to `Time`"); + dt if dt.is_numeric() => self.0.cast(dtype), + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) }, - _ => self.0.cast(dtype), } } } diff --git a/crates/polars-core/src/chunked_array/logical/datetime.rs b/crates/polars-core/src/chunked_array/logical/datetime.rs index eef9e7e859cd1..c3ba7702106db 100644 --- a/crates/polars-core/src/chunked_array/logical/datetime.rs +++ b/crates/polars-core/src/chunked_array/logical/datetime.rs @@ -30,17 +30,19 @@ impl LogicalType for DatetimeChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; - match (self.dtype(), dtype) { - (Datetime(from_unit, _), Datetime(to_unit, tz)) => { + use TimeUnit::*; + let out = match dtype { + Datetime(to_unit, tz) => { + let from_unit = self.time_unit(); let (multiplier, divisor) = match (from_unit, to_unit) { // scaling from lower precision to higher precision - (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => (Some(1_000_000i64), None), - (TimeUnit::Milliseconds, TimeUnit::Microseconds) => (Some(1_000i64), None), - (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => (Some(1_000i64), None), + (Milliseconds, Nanoseconds) => (Some(1_000_000i64), None), + (Milliseconds, Microseconds) => (Some(1_000i64), None), + (Microseconds, Nanoseconds) => (Some(1_000i64), None), // scaling from higher precision to lower precision - (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => (None, Some(1_000_000i64)), - (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => (None, Some(1_000i64)), - (TimeUnit::Microseconds, TimeUnit::Milliseconds) => (None, Some(1_000i64)), + (Nanoseconds, Milliseconds) => (None, Some(1_000_000i64)), + (Nanoseconds, Microseconds) => (None, Some(1_000i64)), + (Microseconds, Milliseconds) => (None, Some(1_000i64)), _ => return self.0.cast(dtype), }; let result = match multiplier { @@ -61,7 +63,7 @@ impl LogicalType for DatetimeChunked { result }, #[cfg(feature = "dtype-date")] - (Datetime(tu, _), Date) => { + Date => { let cast_to_date = |tu_in_day: i64| { let mut dt = self .0 @@ -73,18 +75,18 @@ impl LogicalType for DatetimeChunked { dt.set_sorted_flag(self.is_sorted_flag()); Ok(dt) }; - match tu { - TimeUnit::Nanoseconds => cast_to_date(NS_IN_DAY), - TimeUnit::Microseconds => cast_to_date(US_IN_DAY), - TimeUnit::Milliseconds => cast_to_date(MS_IN_DAY), + match self.time_unit() { + Nanoseconds => cast_to_date(NS_IN_DAY), + Microseconds => cast_to_date(US_IN_DAY), + Milliseconds => cast_to_date(MS_IN_DAY), } }, #[cfg(feature = "dtype-time")] - (Datetime(tu, _), Time) => { - let (scaled_mod, multiplier) = match tu { - TimeUnit::Nanoseconds => (NS_IN_DAY, 1i64), - TimeUnit::Microseconds => (US_IN_DAY, 1_000i64), - TimeUnit::Milliseconds => (MS_IN_DAY, 1_000_000i64), + Time => { + let (scaled_mod, multiplier) = match self.time_unit() { + Nanoseconds => (NS_IN_DAY, 1i64), + Microseconds => (US_IN_DAY, 1_000i64), + Milliseconds => (MS_IN_DAY, 1_000_000i64), }; return Ok(self .0 @@ -95,9 +97,16 @@ impl LogicalType for DatetimeChunked { .into_time() .into_series()); }, - _ => return self.0.cast(dtype), - } - .map(|mut s| { + dt if dt.is_numeric() => return self.0.cast(dtype), + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) + }, + }; + out.map(|mut s| { // TODO!; implement the divisions/multipliers above // in a checked manner so that we raise on overflow s.set_sorted_flag(self.is_sorted_flag()); diff --git a/crates/polars-core/src/chunked_array/logical/duration.rs b/crates/polars-core/src/chunked_array/logical/duration.rs index 63546969df790..873fb805d0fe9 100644 --- a/crates/polars-core/src/chunked_array/logical/duration.rs +++ b/crates/polars-core/src/chunked_array/logical/duration.rs @@ -29,38 +29,35 @@ impl LogicalType for DurationChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; - match (self.dtype(), dtype) { - (Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Nanoseconds)) => { - Ok((self.0.as_ref() * 1_000_000i64) - .into_duration(TimeUnit::Nanoseconds) - .into_series()) + use TimeUnit::*; + match dtype { + Duration(tu) => { + let to_unit = *tu; + let out = match (self.time_unit(), to_unit) { + (Milliseconds, Microseconds) => self.0.as_ref() * 1_000i64, + (Milliseconds, Nanoseconds) => self.0.as_ref() * 1_000_000i64, + (Microseconds, Milliseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000i64) + }, + (Microseconds, Nanoseconds) => self.0.as_ref() * 1_000i64, + (Nanoseconds, Milliseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64) + }, + (Nanoseconds, Microseconds) => { + self.0.as_ref().wrapping_trunc_div_scalar(1_000i64) + }, + _ => return Ok(self.clone().into_series()), + }; + Ok(out.into_duration(to_unit).into_series()) }, - (Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Microseconds)) => { - Ok((self.0.as_ref() * 1_000i64) - .into_duration(TimeUnit::Microseconds) - .into_series()) + dt if dt.is_numeric() => self.0.cast(dtype), + dt => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dt + ) }, - (Duration(TimeUnit::Microseconds), Duration(TimeUnit::Milliseconds)) => { - Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)) - .into_duration(TimeUnit::Milliseconds) - .into_series()) - }, - (Duration(TimeUnit::Microseconds), Duration(TimeUnit::Nanoseconds)) => { - Ok((self.0.as_ref() * 1_000i64) - .into_duration(TimeUnit::Nanoseconds) - .into_series()) - }, - (Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Milliseconds)) => { - Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64)) - .into_duration(TimeUnit::Milliseconds) - .into_series()) - }, - (Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Microseconds)) => { - Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)) - .into_duration(TimeUnit::Microseconds) - .into_series()) - }, - _ => self.0.cast(dtype), } } } diff --git a/crates/polars-core/src/chunked_array/logical/time.rs b/crates/polars-core/src/chunked_array/logical/time.rs index 3c546ef64ab58..c3e6e1f74df7c 100644 --- a/crates/polars-core/src/chunked_array/logical/time.rs +++ b/crates/polars-core/src/chunked_array/logical/time.rs @@ -31,6 +31,8 @@ impl LogicalType for TimeChunked { fn cast(&self, dtype: &DataType) -> PolarsResult { use DataType::*; match dtype { + Time => Ok(self.clone().into_series()), + #[cfg(feature = "dtype-duration")] Duration(tu) => { let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds)); if !matches!(tu, TimeUnit::Nanoseconds) { @@ -39,15 +41,22 @@ impl LogicalType for TimeChunked { out } }, - #[cfg(feature = "dtype-date")] - Date => { - polars_bail!(ComputeError: "cannot cast `Time` to `Date`"); - }, #[cfg(feature = "dtype-datetime")] Datetime(_, _) => { - polars_bail!(ComputeError: "cannot cast `Time` to `Datetime`; consider using `dt.combine`"); + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported; consider using `dt.combine`", + self.dtype(), dtype + ) + }, + dt if dt.is_numeric() => self.0.cast(dtype), + _ => { + polars_bail!( + InvalidOperation: + "casting from {:?} to {:?} not supported", + self.dtype(), dtype + ) }, - _ => self.0.cast(dtype), } } } diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 2753a0d16650f..2ca08d37d7cba 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -239,17 +239,22 @@ def test_int_to_python_datetime() -> None: datetime(1970, 1, 1, 0, 0, 0, 200000), ), ] + + assert df.select(pl.col(col).dt.timestamp() for col in ("c", "d", "e")).rows() == [ + (100000000000, 100000000, 100000), + (200000000000, 200000000, 200000), + ] + assert df.select( - [pl.col(col).dt.timestamp() for col in ("c", "d", "e")] - + [ - getattr(pl.col("b").cast(pl.Duration).dt, f"total_{unit}")().alias( + [ + getattr(pl.col("a").cast(pl.Duration).dt, f"total_{unit}")().alias( f"u[{unit}]" ) for unit in ("milliseconds", "microseconds", "nanoseconds") ] ).rows() == [ - (100000000000, 100000000, 100000, 100000, 100000000, 100000000000), - (200000000000, 200000000, 200000, 200000, 200000000, 200000000000), + (100000, 100000000, 100000000000), + (200000, 200000000, 200000000000), ] diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 3ecdf33aae6a0..84336573c2e69 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -607,15 +607,15 @@ def test_cast_categorical_name_retention( def test_cast_date_to_time() -> None: s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)]) - msg = "cannot cast `Date` to `Time`" - with pytest.raises(pl.ComputeError, match=msg): + msg = "casting from Date to Time not supported" + with pytest.raises(pl.InvalidOperationError, match=msg): s.cast(pl.Time) def test_cast_time_to_date() -> None: s = pl.Series([time(0, 0), time(20, 00)]) - msg = "cannot cast `Time` to `Date`" - with pytest.raises(pl.ComputeError, match=msg): + msg = "casting from Time to Date not supported" + with pytest.raises(pl.InvalidOperationError, match=msg): s.cast(pl.Date) @@ -648,3 +648,27 @@ def test_cast_decimal_to_decimal_high_precision() -> None: assert result.dtype == target_dtype assert result.to_list() == values + + +def test_err_on_time_datetime_cast() -> None: + s = pl.Series([time(10, 0, 0), time(11, 30, 59)]) + with pytest.raises( + pl.InvalidOperationError, + match="casting from Time to Datetime\\(Microseconds, None\\) not supported; consider using `dt.combine`", + ): + s.cast(pl.Datetime) + + +def test_err_on_invalid_time_zone_cast() -> None: + s = pl.Series([datetime(2021, 1, 1)]) + with pytest.raises(pl.ComputeError, match=r"unable to parse time zone: 'qwerty'"): + s.cast(pl.Datetime("us", "qwerty")) + + +def test_invalid_inner_type_cast_list() -> None: + s = pl.Series([[-1, 1]]) + with pytest.raises( + pl.InvalidOperationError, + match=r"cannot cast List inner type: 'Int64' to Categorical", + ): + s.cast(pl.List(pl.Categorical)) diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 6190d93422e53..e98d9637e88bd 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -497,27 +497,6 @@ def test_cast_err_column_value_highlighting( test_df.with_columns(pl.all().cast(type)) -def test_err_on_time_datetime_cast() -> None: - s = pl.Series([time(10, 0, 0), time(11, 30, 59)]) - with pytest.raises(pl.ComputeError, match=r"cannot cast `Time` to `Datetime`"): - s.cast(pl.Datetime) - - -def test_err_on_invalid_time_zone_cast() -> None: - s = pl.Series([datetime(2021, 1, 1)]) - with pytest.raises(pl.ComputeError, match=r"unable to parse time zone: 'qwerty'"): - s.cast(pl.Datetime("us", "qwerty")) - - -def test_invalid_inner_type_cast_list() -> None: - s = pl.Series([[-1, 1]]) - with pytest.raises( - pl.InvalidOperationError, - match=r"cannot cast List inner type: 'Int64' to Categorical", - ): - s.cast(pl.List(pl.Categorical)) - - def test_lit_agg_err() -> None: with pytest.raises(pl.ComputeError, match=r"cannot aggregate a literal"): pl.DataFrame({"y": [1]}).with_columns(pl.lit(1).sum().over("y"))