From 059ba2abd7b2d47563eaadee13bfd2360da58b2b Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Wed, 4 Jan 2023 10:15:44 +0100
Subject: [PATCH] fix(rust, python): fix diff overflow

---
 polars/polars-core/src/series/ops/diff.rs       | 18 +++++++++++++++---
 .../polars-plan/src/dsl/function_expr/schema.rs |  3 +++
 py-polars/tests/unit/test_functions.py          | 11 +++++++++++
 3 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/polars/polars-core/src/series/ops/diff.rs b/polars/polars-core/src/series/ops/diff.rs
index 561a2c73ef7d..49fcd747b246 100644
--- a/polars/polars-core/src/series/ops/diff.rs
+++ b/polars/polars-core/src/series/ops/diff.rs
@@ -4,11 +4,23 @@ use crate::series::ops::NullBehavior;
 impl Series {
     #[cfg_attr(docsrs, doc(cfg(feature = "diff")))]
     pub fn diff(&self, n: usize, null_behavior: NullBehavior) -> Series {
+        use DataType::*;
+        // Unsigned subtraction cannot represent a negative difference, so
+        // widen unsigned integers to a signed type before subtracting.
+        let s = match self.dtype() {
+            UInt8 => self.cast(&Int16).unwrap(),
+            UInt16 => self.cast(&Int32).unwrap(),
+            UInt32 | UInt64 => self.cast(&Int64).unwrap(),
+            _ => self.clone(),
+        };
+
         match null_behavior {
-            NullBehavior::Ignore => self - &self.shift(n as i64),
+            NullBehavior::Ignore => &s - &s.shift(n as i64),
             NullBehavior::Drop => {
-                let len = self.len() - n;
-                &self.slice(n as i64, len) - &self.slice(0, len)
+                // Saturate so that `n > len` yields an empty result instead of underflowing.
+                let len = s.len().saturating_sub(n);
+                // Both operands must come from the widened series `s`.
+                &s.slice(n as i64, len) - &s.slice(0, len)
             }
         }
     }
diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
index c2deaf4d6602..785961adf4a5 100644
--- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
+++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
@@ -226,6 +226,9 @@ impl FunctionExpr {
                 DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
                 #[cfg(feature = "dtype-time")]
                 DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
+                DataType::UInt64 | DataType::UInt32 => DataType::Int64,
+                DataType::UInt16 => DataType::Int32,
+                DataType::UInt8 => DataType::Int16,
                 dt => dt.clone(),
             }),
             #[cfg(feature = "interpolate")]
diff --git a/py-polars/tests/unit/test_functions.py b/py-polars/tests/unit/test_functions.py
index 30916582726f..14a5e88a3c17 100644
--- a/py-polars/tests/unit/test_functions.py
+++ b/py-polars/tests/unit/test_functions.py
@@ -272,3 +272,14 @@ def test_ones_zeros() -> None:
     zeros = pl.zeros(3, dtype=pl.UInt8)
     assert zeros.dtype == pl.UInt8
     assert zeros.to_list() == [0, 0, 0]
+
+
+def test_overflow_diff() -> None:
+    df = pl.DataFrame(
+        {
+            "a": [20, 10, 30],
+        }
+    )
+    assert df.select(pl.col("a").cast(pl.UInt64).diff()).to_dict(False) == {
+        "a": [None, -10, 20]
+    }