Skip to content

Commit

Permalink
fix: append decimal with different scale (#13977)
Browse files Browse the repository at this point in the history
  • Loading branch information
flisky committed Jan 26, 2024
1 parent 0b6be0b commit d0f2d27
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
2 changes: 2 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,8 @@ pub(crate) fn can_extend_dtype(left: &DataType, right: &DataType) -> PolarsResul
Ok(must_cast)
},
(DataType::Null, DataType::Null) => Ok(false),
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(_, s1), DataType::Decimal(_, s2)) => Ok(s1 != s2),
// Other way around we don't allow because we keep left dtype as is.
// We don't go to supertype, and we certainly don't want to cast self to null type.
(_, DataType::Null) => Ok(true),
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {

fn extend(&mut self, other: &Series) -> PolarsResult<()> {
polars_ensure!(self.0.dtype() == other.dtype(), extend);
self.0.extend(other.as_ref().as_ref());
let other = other.decimal()?;
self.0.extend(&other.0);
Ok(())
}

Expand Down
58 changes: 41 additions & 17 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,23 +904,17 @@ where
T: 'static + PolarsDataType,
{
fn as_ref(&self) -> &ChunkedArray<T> {
match T::get_dtype() {
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(None, None) => panic!("impl error"),
_ => {
if &T::get_dtype() == self.dtype() ||
// Needed because we want to get ref of List no matter what the inner type is.
(matches!(T::get_dtype(), DataType::List(_)) && matches!(self.dtype(), DataType::List(_)))
{
unsafe { &*(self as *const dyn SeriesTrait as *const ChunkedArray<T>) }
} else {
panic!(
"implementation error, cannot get ref {:?} from {:?}",
T::get_dtype(),
self.dtype()
);
}
},
if &T::get_dtype() == self.dtype() ||
// Needed because we want to get ref of List no matter what the inner type is.
(matches!(T::get_dtype(), DataType::List(_)) && matches!(self.dtype(), DataType::List(_)))
{
unsafe { &*(self as *const dyn SeriesTrait as *const ChunkedArray<T>) }
} else {
panic!(
"implementation error, cannot get ref {:?} from {:?}",
T::get_dtype(),
self.dtype()
);
}
}
}
Expand Down Expand Up @@ -999,6 +993,36 @@ mod test {
assert!(s1.append(&s2).is_err())
}

#[test]
#[cfg(feature = "dtype-decimal")]
fn series_append_decimal() {
let s1 = Series::new("a", &[1.1, 2.3])
.cast(&DataType::Decimal(None, Some(2)))
.unwrap();
let s2 = Series::new("b", &[3])
.cast(&DataType::Decimal(None, Some(0)))
.unwrap();

{
let mut s1 = s1.clone();
s1.append(&s2).unwrap();
assert_eq!(s1.len(), 3);
#[cfg(feature = "python")]
assert_eq!(s1.get(2).unwrap(), AnyValue::Float64(3.0));
#[cfg(not(feature = "python"))]
assert_eq!(s1.get(2).unwrap(), AnyValue::Decimal(300, 2));
}

{
let mut s2 = s2.clone();
s2.extend(&s1).unwrap();
#[cfg(feature = "python")]
assert_eq!(s2.get(2).unwrap(), AnyValue::Float64(2.29)); // 2.3 == 2.2999999999999998
#[cfg(not(feature = "python"))]
assert_eq!(s2.get(2).unwrap(), AnyValue::Decimal(2, 0));
}
}

#[test]
fn series_slice_works() {
let series = Series::new("a", &[1i64, 2, 3, 4, 5]);
Expand Down

0 comments on commit d0f2d27

Please sign in to comment.