Skip to content

Commit

Permalink
fix(rust, python): fix diff overflow (#6033)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 4, 2023
1 parent 88cc5cf commit 265a409
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
14 changes: 11 additions & 3 deletions polars/polars-core/src/series/ops/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ use crate::series::ops::NullBehavior;
impl Series {
#[cfg_attr(docsrs, doc(cfg(feature = "diff")))]
pub fn diff(&self, n: usize, null_behavior: NullBehavior) -> Series {
use DataType::*;
let s = match self.dtype() {
UInt8 => self.cast(&Int16).unwrap(),
UInt16 => self.cast(&Int32).unwrap(),
UInt32 | UInt64 => self.cast(&Int64).unwrap(),
_ => self.clone(),
};

match null_behavior {
NullBehavior::Ignore => self - &self.shift(n as i64),
NullBehavior::Ignore => &s - &s.shift(n as i64),
NullBehavior::Drop => {
let len = self.len() - n;
&self.slice(n as i64, len) - &self.slice(0, len)
let len = s.len() - n;
&self.slice(n as i64, len) - &s.slice(0, len)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ impl FunctionExpr {
DataType::Date => DataType::Duration(TimeUnit::Milliseconds),
#[cfg(feature = "dtype-time")]
DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
DataType::UInt64 | DataType::UInt32 => DataType::Int64,
DataType::UInt16 => DataType::Int32,
DataType::UInt8 => DataType::Int8,
dt => dt.clone(),
}),
#[cfg(feature = "interpolate")]
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,14 @@ def test_ones_zeros() -> None:
zeros = pl.zeros(3, dtype=pl.UInt8)
assert zeros.dtype == pl.UInt8
assert zeros.to_list() == [0, 0, 0]


def test_overflow_diff() -> None:
df = pl.DataFrame(
{
"a": [20, 10, 30],
}
)
assert df.select(pl.col("a").cast(pl.UInt64).diff()).to_dict(False) == {
"a": [None, -10, 20]
}

0 comments on commit 265a409

Please sign in to comment.