From 98d6a2b8fd822d7bb42502add39c045b36fb3c9f Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 4 Jun 2024 14:29:22 +0200 Subject: [PATCH] fix: Raise unsupported cat array --- crates/polars-core/src/chunked_array/array/iterator.rs | 4 ++-- crates/polars-core/src/chunked_array/array/mod.rs | 8 ++++---- crates/polars-core/src/chunked_array/cast.rs | 9 +++++++-- crates/polars-core/src/chunked_array/iterator/mod.rs | 4 ++-- crates/polars-core/src/chunked_array/ops/explode.rs | 2 +- .../src/chunked_array/ops/explode_and_offsets.rs | 4 ++-- crates/polars-core/src/chunked_array/ops/mod.rs | 4 ++-- crates/polars-core/src/chunked_array/ops/reverse.rs | 2 +- crates/polars-core/src/chunked_array/ops/shift.rs | 2 +- crates/polars-ops/src/chunked_array/array/dispersion.rs | 4 ++-- crates/polars-ops/src/chunked_array/array/get.rs | 6 +++--- crates/polars-ops/src/chunked_array/array/namespace.rs | 8 ++++---- py-polars/tests/unit/test_errors.py | 5 +++++ py-polars/tests/unit/testing/test_assert_series_equal.py | 6 ------ 14 files changed, 36 insertions(+), 32 deletions(-) diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index cbb38f954e5e..d17d34e5582b 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -61,7 +61,7 @@ impl ArrayChunked { series_container, NonNull::new(ptr).unwrap(), self.downcast_iter().flat_map(|arr| arr.iter()), - inner_dtype, + inner_dtype.clone(), ) } @@ -72,7 +72,7 @@ impl ArrayChunked { if self.is_empty() { return Ok(Series::new_empty( self.name(), - &DataType::List(Box::new(self.inner_dtype())), + &DataType::List(Box::new(self.inner_dtype().clone())), ) .list() .unwrap() diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs index b92255a8f995..96fee06ff3b2 100644 --- a/crates/polars-core/src/chunked_array/array/mod.rs +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -6,9 +6,9 @@ use crate::prelude::*; impl ArrayChunked { /// Get the inner data type of the fixed size list. - pub fn inner_dtype(&self) -> DataType { + pub fn inner_dtype(&self) -> &DataType { match self.dtype() { - DataType::Array(dt, _size) => *dt.clone(), + DataType::Array(dt, _size) => dt.as_ref(), _ => unreachable!(), } } @@ -23,7 +23,7 @@ impl ArrayChunked { /// # Safety /// The caller must ensure that the logical type given fits the physical type of the array. pub unsafe fn to_logical(&mut self, inner_dtype: DataType) { - debug_assert_eq!(inner_dtype.to_physical(), self.inner_dtype()); + debug_assert_eq!(&inner_dtype.to_physical(), self.inner_dtype()); let width = self.width(); let fld = Arc::make_mut(&mut self.field); fld.coerce(DataType::Array(Box::new(inner_dtype), width)) @@ -34,7 +34,7 @@ impl ArrayChunked { let chunks: Vec<_> = self.downcast_iter().map(|c| c.values().clone()).collect(); // SAFETY: Data type of arrays matches because they are chunks from the same array. - unsafe { Series::from_chunks_and_dtype_unchecked(self.name(), chunks, &self.inner_dtype()) } + unsafe { Series::from_chunks_and_dtype_unchecked(self.name(), chunks, self.inner_dtype()) } } /// Ignore the list indices and apply `func` to the inner type as [`Series`]. diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 82ffa79f4af1..b329dc661b1d 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -412,6 +412,11 @@ impl ChunkCast for ListChunked { #[cfg(feature = "dtype-array")] Array(child_type, width) => { let physical_type = data_type.to_physical(); + + // TODO!: properly implement this recursively. + #[cfg(feature = "dtype-categorical")] + polars_ensure!(!matches!(&**child_type, Categorical(_, _)), InvalidOperation: "array of categorical is not yet supported"); + // cast to the physical type to avoid logical chunks. let chunks = cast_chunks(self.chunks(), &physical_type, true)?; // SAFETY: we just casted so the dtype matches. @@ -457,7 +462,7 @@ impl ChunkCast for ArrayChunked { ); match (self.inner_dtype(), &**child_type) { - (old, new) if old == *new => Ok(self.clone().into_series()), + (old, new) if old == new => Ok(self.clone().into_series()), #[cfg(feature = "dtype-categorical")] (dt, Categorical(None, _) | Enum(_, _)) if !matches!(dt, String) => { polars_bail!(InvalidOperation: "cannot cast Array inner type: '{:?}' to dtype: {:?}", dt, child_type) @@ -571,7 +576,7 @@ fn cast_fixed_size_list( let arr = ca.downcast_iter().next().unwrap(); // SAFETY: inner dtype is passed correctly let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) }; let new_inner = s.cast(child_type)?; diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index 27781b892876..1ffacdb4bbb1 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -282,7 +282,7 @@ impl<'a> IntoIterator for &'a ArrayChunked { Some(Series::from_chunks_and_dtype_unchecked( "", vec![arr], - &dtype, + dtype, )) }), ) @@ -296,7 +296,7 @@ impl<'a> IntoIterator for &'a ArrayChunked { .trust_my_length(self.len()) .map(move |arr| { arr.map(|arr| { - Series::from_chunks_and_dtype_unchecked("", vec![arr], &dtype) + Series::from_chunks_and_dtype_unchecked("", vec![arr], dtype) }) }), ) diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 735792c19572..ca55ac6836e8 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -291,7 +291,7 @@ impl ExplodeByOffsets for ArrayChunked { let cap = get_capacity(offsets); let inner_type = self.inner_dtype(); let mut builder = - get_fixed_size_list_builder(&inner_type, cap, self.width(), self.name()).unwrap(); + get_fixed_size_list_builder(inner_type, cap, self.width(), self.name()).unwrap(); let mut start = offsets[0] as usize; let mut last = start; diff --git a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs index 7c08d4de622a..f407c6245bd8 100644 --- a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs +++ b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs @@ -179,7 +179,7 @@ impl ChunkExplode for ArrayChunked { if arr.null_count() == 0 { let s = Series::try_from((self.name(), arr.values().clone())) .unwrap() - .cast(&ca.inner_dtype())?; + .cast(ca.inner_dtype())?; let width = self.width() as i64; let offsets = (0..self.len() + 1) .map(|i| { @@ -224,7 +224,7 @@ impl ChunkExplode for ArrayChunked { Ok(( // SAFETY: inner_dtype should be correct unsafe { - Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], &ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], ca.inner_dtype()) }, offsets, )) diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index c6f434e97675..30e0f275ab1e 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -524,13 +524,13 @@ impl ChunkExpandAtIndex for ArrayChunked { match opt_val { Some(val) => { let mut ca = ArrayChunked::full(self.name(), &val, length); - unsafe { ca.to_logical(self.inner_dtype()) }; + unsafe { ca.to_logical(self.inner_dtype().clone()) }; ca }, None => ArrayChunked::full_null_with_dtype( self.name(), length, - &self.inner_dtype(), + self.inner_dtype(), self.width(), ), } diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index fcc622fe10a3..9d3b0938f390 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -89,7 +89,7 @@ impl ChunkReverse for ArrayChunked { let values = arr.values().as_ref(); let mut builder = - get_fixed_size_list_builder(&ca.inner_dtype(), ca.len(), ca.width(), ca.name()) + get_fixed_size_list_builder(ca.inner_dtype(), ca.len(), ca.width(), ca.name()) .expect("not yet supported"); // SAFETY, we are within bounds diff --git a/crates/polars-core/src/chunked_array/ops/shift.rs b/crates/polars-core/src/chunked_array/ops/shift.rs index ecb64f7b79cd..2938c1ba5ecd 100644 --- a/crates/polars-core/src/chunked_array/ops/shift.rs +++ b/crates/polars-core/src/chunked_array/ops/shift.rs @@ -147,7 +147,7 @@ impl ChunkShiftFill> for ArrayChunked { let mut fill = match fill_value { Some(val) => Self::full(self.name(), val, fill_length), None => { - ArrayChunked::full_null_with_dtype(self.name(), fill_length, &self.inner_dtype(), 0) + ArrayChunked::full_null_with_dtype(self.name(), fill_length, self.inner_dtype(), 0) }, }; diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs index e7039ac5db2e..afec05d0f539 100644 --- a/crates/polars-ops/src/chunked_array/array/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -13,7 +13,7 @@ pub(super) fn median_with_nulls(ca: &ArrayChunked) -> PolarsResult { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) .with_name(ca.name()); - out.into_duration(tu).into_series() + out.into_duration(*tu).into_series() }, _ => { let out: Float64Chunked = ca @@ -39,7 +39,7 @@ pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { let out: Float64Chunked = ca diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs index 11d498c396f7..46bf7232e390 100644 --- a/crates/polars-ops/src/chunked_array/array/get.rs +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -13,7 +13,7 @@ fn array_get_literal(ca: &ArrayChunked, idx: i64, null_on_oob: bool) -> PolarsRe .collect::>>()?; Series::try_from((ca.name(), chunks)) .unwrap() - .cast(&ca.inner_dtype()) + .cast(ca.inner_dtype()) } /// Get the value by literal index in the array. @@ -31,14 +31,14 @@ pub fn array_get( if let Some(index) = index { array_get_literal(ca, index, null_on_oob) } else { - Ok(Series::full_null(ca.name(), ca.len(), &ca.inner_dtype())) + Ok(Series::full_null(ca.name(), ca.len(), ca.inner_dtype())) } }, len if len == ca.len() => { let out = binary_to_series_arr_get(ca, index, null_on_oob, |arr, idx, nob| { sub_fixed_size_list_get(arr, idx, nob) }); - out?.cast(&ca.inner_dtype()) + out?.cast(ca.inner_dtype()) }, len => polars_bail!( ComputeError: diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index 84381b777f71..1fa813be05a9 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -41,13 +41,13 @@ pub trait ArrayNameSpace: AsArray { let ca = self.as_array(); if has_inner_nulls(ca) { - return sum_with_nulls(ca, &ca.inner_dtype()); + return sum_with_nulls(ca, ca.inner_dtype()); }; match ca.inner_dtype() { DataType::Boolean => Ok(count_boolean_bits(ca).into_series()), - dt if dt.is_numeric() => Ok(sum_array_numerical(ca, &dt)), - dt => sum_with_nulls(ca, &dt), + dt if dt.is_numeric() => Ok(sum_array_numerical(ca, dt)), + dt => sum_with_nulls(ca, dt), } } @@ -151,7 +151,7 @@ pub trait ArrayNameSpace: AsArray { ArrayChunked::full_null_with_dtype( ca.name(), ca.len(), - &ca.inner_dtype(), + ca.inner_dtype(), ca.width(), ) } diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index e98d9637e88b..91a8e2fe3825 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -653,3 +653,8 @@ def test_fill_null_invalid_supertype() -> None: pl.InvalidOperationError, match="could not determine supertype of" ): df.select(pl.col("date").fill_null(1.0)) + + +def test_raise_array_of_cats() -> None: + with pytest.raises(pl.InvalidOperationError, match="is not yet supported"): + pl.Series([["a", "b"], ["a", "c"]], dtype=pl.Array(pl.Categorical, 2)) diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index d04766edfb95..bf6812b1f241 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -673,12 +673,6 @@ def test_assert_series_equal_nested_categorical_as_str_global() -> None: "s", [ pl.Series([["a", "b"], ["a"]], dtype=pl.List(pl.Categorical)), - pytest.param( - pl.Series([["a", "b"], ["a", "c"]], dtype=pl.Array(pl.Categorical, 2)), - marks=pytest.mark.xfail( - reason="Currently bugged: https://github.com/pola-rs/polars/issues/16706" - ), - ), pl.Series([{"a": "x"}, {"a": "y"}], dtype=pl.Struct({"a": pl.Categorical})), ], )