Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Restrict casting for temporal data types #14142

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions crates/polars-core/src/chunked_array/logical/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl LogicalType for DateChunked {
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match dtype {
Date => Ok(self.clone().into_series()),
#[cfg(feature = "dtype-datetime")]
Datetime(tu, tz) => {
let casted = self.0.cast(dtype)?;
Expand All @@ -43,11 +44,14 @@ impl LogicalType for DateChunked {
.into_datetime(*tu, tz.clone())
.into_series())
},
#[cfg(feature = "dtype-time")]
Time => {
polars_bail!(ComputeError: "cannot cast `Date` to `Time`");
dt if dt.is_numeric() => self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
_ => self.0.cast(dtype),
}
}
}
51 changes: 30 additions & 21 deletions crates/polars-core/src/chunked_array/logical/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ impl LogicalType for DatetimeChunked {

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match (self.dtype(), dtype) {
(Datetime(from_unit, _), Datetime(to_unit, tz)) => {
use TimeUnit::*;
let out = match dtype {
Datetime(to_unit, tz) => {
let from_unit = self.time_unit();
let (multiplier, divisor) = match (from_unit, to_unit) {
// scaling from lower precision to higher precision
(TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => (Some(1_000_000i64), None),
(TimeUnit::Milliseconds, TimeUnit::Microseconds) => (Some(1_000i64), None),
(TimeUnit::Microseconds, TimeUnit::Nanoseconds) => (Some(1_000i64), None),
(Milliseconds, Nanoseconds) => (Some(1_000_000i64), None),
(Milliseconds, Microseconds) => (Some(1_000i64), None),
(Microseconds, Nanoseconds) => (Some(1_000i64), None),
// scaling from higher precision to lower precision
(TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => (None, Some(1_000_000i64)),
(TimeUnit::Nanoseconds, TimeUnit::Microseconds) => (None, Some(1_000i64)),
(TimeUnit::Microseconds, TimeUnit::Milliseconds) => (None, Some(1_000i64)),
(Nanoseconds, Milliseconds) => (None, Some(1_000_000i64)),
(Nanoseconds, Microseconds) => (None, Some(1_000i64)),
(Microseconds, Milliseconds) => (None, Some(1_000i64)),
_ => return self.0.cast(dtype),
};
let result = match multiplier {
Expand All @@ -61,7 +63,7 @@ impl LogicalType for DatetimeChunked {
result
},
#[cfg(feature = "dtype-date")]
(Datetime(tu, _), Date) => {
Date => {
let cast_to_date = |tu_in_day: i64| {
let mut dt = self
.0
Expand All @@ -73,18 +75,18 @@ impl LogicalType for DatetimeChunked {
dt.set_sorted_flag(self.is_sorted_flag());
Ok(dt)
};
match tu {
TimeUnit::Nanoseconds => cast_to_date(NS_IN_DAY),
TimeUnit::Microseconds => cast_to_date(US_IN_DAY),
TimeUnit::Milliseconds => cast_to_date(MS_IN_DAY),
match self.time_unit() {
Nanoseconds => cast_to_date(NS_IN_DAY),
Microseconds => cast_to_date(US_IN_DAY),
Milliseconds => cast_to_date(MS_IN_DAY),
}
},
#[cfg(feature = "dtype-time")]
(Datetime(tu, _), Time) => {
let (scaled_mod, multiplier) = match tu {
TimeUnit::Nanoseconds => (NS_IN_DAY, 1i64),
TimeUnit::Microseconds => (US_IN_DAY, 1_000i64),
TimeUnit::Milliseconds => (MS_IN_DAY, 1_000_000i64),
Time => {
let (scaled_mod, multiplier) = match self.time_unit() {
Nanoseconds => (NS_IN_DAY, 1i64),
Microseconds => (US_IN_DAY, 1_000i64),
Milliseconds => (MS_IN_DAY, 1_000_000i64),
};
return Ok(self
.0
Expand All @@ -95,9 +97,16 @@ impl LogicalType for DatetimeChunked {
.into_time()
.into_series());
},
_ => return self.0.cast(dtype),
}
.map(|mut s| {
dt if dt.is_numeric() => return self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
};
out.map(|mut s| {
// TODO!; implement the divisions/multipliers above
// in a checked manner so that we raise on overflow
s.set_sorted_flag(self.is_sorted_flag());
Expand Down
57 changes: 27 additions & 30 deletions crates/polars-core/src/chunked_array/logical/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,35 @@ impl LogicalType for DurationChunked {

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match (self.dtype(), dtype) {
(Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Nanoseconds)) => {
Ok((self.0.as_ref() * 1_000_000i64)
.into_duration(TimeUnit::Nanoseconds)
.into_series())
use TimeUnit::*;
match dtype {
Duration(tu) => {
let to_unit = *tu;
let out = match (self.time_unit(), to_unit) {
(Milliseconds, Microseconds) => self.0.as_ref() * 1_000i64,
(Milliseconds, Nanoseconds) => self.0.as_ref() * 1_000_000i64,
(Microseconds, Milliseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)
},
(Microseconds, Nanoseconds) => self.0.as_ref() * 1_000i64,
(Nanoseconds, Milliseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64)
},
(Nanoseconds, Microseconds) => {
self.0.as_ref().wrapping_trunc_div_scalar(1_000i64)
},
_ => return Ok(self.clone().into_series()),
};
Ok(out.into_duration(to_unit).into_series())
},
(Duration(TimeUnit::Milliseconds), Duration(TimeUnit::Microseconds)) => {
Ok((self.0.as_ref() * 1_000i64)
.into_duration(TimeUnit::Microseconds)
.into_series())
dt if dt.is_numeric() => self.0.cast(dtype),
dt => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dt
)
},
(Duration(TimeUnit::Microseconds), Duration(TimeUnit::Milliseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64))
.into_duration(TimeUnit::Milliseconds)
.into_series())
},
(Duration(TimeUnit::Microseconds), Duration(TimeUnit::Nanoseconds)) => {
Ok((self.0.as_ref() * 1_000i64)
.into_duration(TimeUnit::Nanoseconds)
.into_series())
},
(Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Milliseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000_000i64))
.into_duration(TimeUnit::Milliseconds)
.into_series())
},
(Duration(TimeUnit::Nanoseconds), Duration(TimeUnit::Microseconds)) => {
Ok((self.0.as_ref().wrapping_trunc_div_scalar(1_000i64))
.into_duration(TimeUnit::Microseconds)
.into_series())
},
_ => self.0.cast(dtype),
}
}
}
21 changes: 15 additions & 6 deletions crates/polars-core/src/chunked_array/logical/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ impl LogicalType for TimeChunked {
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match dtype {
Time => Ok(self.clone().into_series()),
#[cfg(feature = "dtype-duration")]
Duration(tu) => {
let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds));
if !matches!(tu, TimeUnit::Nanoseconds) {
Expand All @@ -39,15 +41,22 @@ impl LogicalType for TimeChunked {
out
}
},
#[cfg(feature = "dtype-date")]
Date => {
polars_bail!(ComputeError: "cannot cast `Time` to `Date`");
},
#[cfg(feature = "dtype-datetime")]
Datetime(_, _) => {
polars_bail!(ComputeError: "cannot cast `Time` to `Datetime`; consider using `dt.combine`");
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported; consider using `dt.combine`",
self.dtype(), dtype
)
},
dt if dt.is_numeric() => self.0.cast(dtype),
_ => {
polars_bail!(
InvalidOperation:
"casting from {:?} to {:?} not supported",
self.dtype(), dtype
)
},
_ => self.0.cast(dtype),
}
}
}
15 changes: 10 additions & 5 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,22 @@ def test_int_to_python_datetime() -> None:
datetime(1970, 1, 1, 0, 0, 0, 200000),
),
]

assert df.select(pl.col(col).dt.timestamp() for col in ("c", "d", "e")).rows() == [
(100000000000, 100000000, 100000),
(200000000000, 200000000, 200000),
]

assert df.select(
[pl.col(col).dt.timestamp() for col in ("c", "d", "e")]
+ [
getattr(pl.col("b").cast(pl.Duration).dt, f"total_{unit}")().alias(
[
getattr(pl.col("a").cast(pl.Duration).dt, f"total_{unit}")().alias(
f"u[{unit}]"
)
for unit in ("milliseconds", "microseconds", "nanoseconds")
]
).rows() == [
(100000000000, 100000000, 100000, 100000, 100000000, 100000000000),
(200000000000, 200000000, 200000, 200000, 200000000, 200000000000),
(100000, 100000000, 100000000000),
(200000, 200000000, 200000000000),
]


Expand Down
32 changes: 28 additions & 4 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,15 +607,15 @@ def test_cast_categorical_name_retention(

def test_cast_date_to_time() -> None:
s = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
msg = "cannot cast `Date` to `Time`"
with pytest.raises(pl.ComputeError, match=msg):
msg = "casting from Date to Time not supported"
with pytest.raises(pl.InvalidOperationError, match=msg):
s.cast(pl.Time)


def test_cast_time_to_date() -> None:
s = pl.Series([time(0, 0), time(20, 00)])
msg = "cannot cast `Time` to `Date`"
with pytest.raises(pl.ComputeError, match=msg):
msg = "casting from Time to Date not supported"
with pytest.raises(pl.InvalidOperationError, match=msg):
s.cast(pl.Date)


Expand Down Expand Up @@ -648,3 +648,27 @@ def test_cast_decimal_to_decimal_high_precision() -> None:

assert result.dtype == target_dtype
assert result.to_list() == values


def test_err_on_time_datetime_cast() -> None:
    """Casting a Time series to Datetime must fail with InvalidOperationError.

    The expected message names both dtypes and suggests `dt.combine`.
    """
    times = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    expected_msg = "casting from Time to Datetime\\(Microseconds, None\\) not supported; consider using `dt.combine`"
    with pytest.raises(pl.InvalidOperationError, match=expected_msg):
        times.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
    """Casting to a Datetime dtype with the time zone 'qwerty' must raise ComputeError."""
    naive = pl.Series([datetime(2021, 1, 1)])
    expected_msg = r"unable to parse time zone: 'qwerty'"
    # The target dtype is built inside the context manager, as in the original,
    # so an error raised while constructing it is also captured by the check.
    with pytest.raises(pl.ComputeError, match=expected_msg):
        naive.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
    """Casting a list-of-integers series to List(Categorical) must raise InvalidOperationError."""
    nested = pl.Series([[-1, 1]])
    expected_msg = r"cannot cast List inner type: 'Int64' to Categorical"
    with pytest.raises(pl.InvalidOperationError, match=expected_msg):
        nested.cast(pl.List(pl.Categorical))
21 changes: 0 additions & 21 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,27 +497,6 @@ def test_cast_err_column_value_highlighting(
test_df.with_columns(pl.all().cast(type))


def test_err_on_time_datetime_cast() -> None:
    """Casting a Time series to Datetime must fail with ComputeError."""
    times = pl.Series([time(10, 0, 0), time(11, 30, 59)])
    expected_msg = r"cannot cast `Time` to `Datetime`"
    with pytest.raises(pl.ComputeError, match=expected_msg):
        times.cast(pl.Datetime)


def test_err_on_invalid_time_zone_cast() -> None:
    """Casting to a Datetime dtype with the time zone 'qwerty' must raise ComputeError."""
    naive = pl.Series([datetime(2021, 1, 1)])
    expected_msg = r"unable to parse time zone: 'qwerty'"
    # The target dtype is built inside the context manager, as in the original,
    # so an error raised while constructing it is also captured by the check.
    with pytest.raises(pl.ComputeError, match=expected_msg):
        naive.cast(pl.Datetime("us", "qwerty"))


def test_invalid_inner_type_cast_list() -> None:
    """Casting a list-of-integers series to List(Categorical) must raise InvalidOperationError."""
    nested = pl.Series([[-1, 1]])
    expected_msg = r"cannot cast List inner type: 'Int64' to Categorical"
    with pytest.raises(pl.InvalidOperationError, match=expected_msg):
        nested.cast(pl.List(pl.Categorical))


def test_lit_agg_err() -> None:
with pytest.raises(pl.ComputeError, match=r"cannot aggregate a literal"):
pl.DataFrame({"y": [1]}).with_columns(pl.lit(1).sum().over("y"))
Expand Down