Skip to content

Commit

Permalink
feat!: Forbid casting from Date to Time and vice versa (#14127)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jan 31, 2024
1 parent 60be4fb commit ba0a5a8
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 55 deletions.
10 changes: 5 additions & 5 deletions crates/polars-core/src/chunked_array/logical/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ impl LogicalType for DateChunked {

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match (self.dtype(), dtype) {
match dtype {
#[cfg(feature = "dtype-datetime")]
(Date, Datetime(tu, tz)) => {
Datetime(tu, tz) => {
let casted = self.0.cast(dtype)?;
let casted = casted.datetime().unwrap();
let conversion = match tu {
Expand All @@ -44,9 +44,9 @@ impl LogicalType for DateChunked {
.into_series())
},
#[cfg(feature = "dtype-time")]
(Date, Time) => Ok(Int64Chunked::full(self.name(), 0i64, self.len())
.into_time()
.into_series()),
Time => {
polars_bail!(ComputeError: "cannot cast `Date` to `Time`");
},
_ => self.0.cast(dtype),
}
}
Expand Down
11 changes: 10 additions & 1 deletion crates/polars-core/src/chunked_array/logical/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,24 @@ impl LogicalType for TimeChunked {
}

fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
use DataType::*;
match dtype {
DataType::Duration(tu) => {
Duration(tu) => {
let out = self.0.cast(&DataType::Duration(TimeUnit::Nanoseconds));
if !matches!(tu, TimeUnit::Nanoseconds) {
out?.cast(dtype)
} else {
out
}
},
#[cfg(feature = "dtype-date")]
Date => {
polars_bail!(ComputeError: "cannot cast `Time` to `Date`");
},
#[cfg(feature = "dtype-datetime")]
Datetime(_, _) => {
polars_bail!(ComputeError: "cannot cast `Time` to `Datetime`; consider using `dt.combine`");
},
_ => self.0.cast(dtype),
}
}
Expand Down
34 changes: 13 additions & 21 deletions crates/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ macro_rules! impl_dyn_series {
fn _dtype(&self) -> &DataType {
self.0.dtype()
}
fn _get_flags(&self) -> Settings{
fn _get_flags(&self) -> Settings {
self.0.get_flags()
}
fn _set_flags(&mut self, flags: Settings){
fn _set_flags(&mut self, flags: Settings) {
self.0.set_flags(flags)
}

Expand Down Expand Up @@ -78,17 +78,17 @@ macro_rules! impl_dyn_series {
Ok(())
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series {
self.0.agg_min(groups).$into_logical().into_series()
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series {
self.0.agg_max(groups).$into_logical().into_series()
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
// we cannot cast and dispatch as the inner type of the list would be incorrect
self.0
Expand All @@ -104,7 +104,7 @@ macro_rules! impl_dyn_series {
let lhs = self.cast(&dt)?;
let rhs = rhs.cast(&dt)?;
lhs.subtract(&rhs)
}
},
(DataType::Date, DataType::Duration(_)) => ((&self
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
.unwrap())
Expand Down Expand Up @@ -132,7 +132,7 @@ macro_rules! impl_dyn_series {
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
polars_bail!(opq = rem, self.0.dtype(), rhs.dtype());
}
#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
self.0.group_tuples(multithreaded, sorted)
}
Expand All @@ -143,7 +143,6 @@ macro_rules! impl_dyn_series {
}

impl SeriesTrait for SeriesWrap<$ca> {

fn rename(&mut self, name: &str) {
self.0.rename(name);
}
Expand Down Expand Up @@ -250,7 +249,7 @@ macro_rules! impl_dyn_series {

fn cast(&self, data_type: &DataType) -> PolarsResult<Series> {
match (self.dtype(), data_type) {
#[cfg(feature="dtype-date")]
#[cfg(feature = "dtype-date")]
(DataType::Date, DataType::String) => Ok(self
.0
.clone()
Expand All @@ -259,7 +258,7 @@ macro_rules! impl_dyn_series {
.unwrap()
.to_string("%Y-%m-%d")
.into_series()),
#[cfg(feature="dtype-time")]
#[cfg(feature = "dtype-time")]
(DataType::Time, DataType::String) => Ok(self
.0
.clone()
Expand All @@ -269,18 +268,11 @@ macro_rules! impl_dyn_series {
.to_string("%T")
.into_series()),
#[cfg(feature = "dtype-datetime")]
(DataType::Time, DataType::Datetime(_, _)) => {
polars_bail!(
ComputeError:
"cannot cast `Time` to `Datetime`; consider using 'dt.combine'"
);
}
#[cfg(feature = "dtype-datetime")]
(DataType::Date, DataType::Datetime(_, _)) => {
let mut out = self.0.cast(data_type)?;
out.set_sorted_flag(self.0.is_sorted_flag());
Ok(out)
}
},
_ => self.0.cast(data_type),
}
}
Expand Down Expand Up @@ -310,17 +302,17 @@ macro_rules! impl_dyn_series {
self.0.has_validity()
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
fn unique(&self) -> PolarsResult<Series> {
self.0.unique().map(|ca| ca.$into_logical().into_series())
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
fn n_unique(&self) -> PolarsResult<usize> {
self.0.n_unique()
}

#[cfg(feature = "algorithm_group_by")]
#[cfg(feature = "algorithm_group_by")]
fn arg_unique(&self) -> PolarsResult<IdxCa> {
self.0.arg_unique()
}
Expand Down
16 changes: 0 additions & 16 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,22 +1343,6 @@ def test_rolling_by_() -> None:
}


def test_date_to_time_cast_5111() -> None:
# check date -> time casts (fast-path: always 00:00:00)
df = pl.DataFrame(
{
"xyz": [
date(1969, 1, 1),
date(1990, 3, 8),
date(2000, 6, 16),
date(2010, 9, 24),
date(2022, 12, 31),
]
}
).with_columns(pl.col("xyz").cast(pl.Time))
assert df["xyz"].to_list() == [time(0), time(0), time(0), time(0), time(0)]


def test_sum_duration() -> None:
assert pl.DataFrame(
[
Expand Down
37 changes: 25 additions & 12 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal
from polars.testing.asserts.series import assert_series_equal
from polars.utils.convert import (
Expand All @@ -31,7 +30,7 @@ def test_string_date() -> None:
def test_invalid_string_date() -> None:
df = pl.DataFrame({"x1": ["2021-01-aa"]})

with pytest.raises(ComputeError):
with pytest.raises(pl.ComputeError):
df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)})


Expand Down Expand Up @@ -67,7 +66,7 @@ def test_string_datetime() -> None:

def test_invalid_string_datetime() -> None:
df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]})
with pytest.raises(ComputeError):
with pytest.raises(pl.ComputeError):
df.with_columns(
**{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))}
)
Expand Down Expand Up @@ -236,11 +235,11 @@ def test_strict_cast_int(
assert _cast_expr(*args) == expected_value # type: ignore[arg-type]
assert _cast_lit(*args) == expected_value # type: ignore[arg-type]
else:
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_series(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_expr(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_lit(*args) # type: ignore[arg-type]


Expand Down Expand Up @@ -375,11 +374,11 @@ def test_strict_cast_temporal(
assert out.item() == expected_value
assert out.dtype == to_dtype
else:
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_series_t(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_expr_t(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_lit_t(*args) # type: ignore[arg-type]


Expand Down Expand Up @@ -571,11 +570,11 @@ def test_strict_cast_string_and_binary(
assert out.item() == expected_value
assert out.dtype == to_dtype
else:
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_series_t(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_expr_t(*args) # type: ignore[arg-type]
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.ComputeError):
_cast_lit_t(*args) # type: ignore[arg-type]


Expand All @@ -599,3 +598,17 @@ def test_strict_cast_string_and_binary(
)
def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None:
assert pl.Series("a", ["1"], dtype=pl.Categorical).cast(dtype_out).name == "a"


def test_cast_date_to_time() -> None:
    """Casting a `Date` Series to `Time` must raise a `ComputeError`."""
    dates = pl.Series([date(1970, 1, 1), date(2000, 12, 31)])
    # The `match` pattern is the message text the cast is expected to report.
    with pytest.raises(pl.ComputeError, match="cannot cast `Date` to `Time`"):
        dates.cast(pl.Time)


def test_cast_time_to_date() -> None:
    """Casting a `Time` Series to `Date` must raise a `ComputeError`."""
    times = pl.Series([time(0, 0), time(20, 0)])
    # The `match` pattern is the message text the cast is expected to report.
    with pytest.raises(pl.ComputeError, match="cannot cast `Time` to `Date`"):
        times.cast(pl.Date)

0 comments on commit ba0a5a8

Please sign in to comment.