Skip to content

Commit

Permalink
feat(rust, python): allow nested categorical cast
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 8, 2023
1 parent aef8938 commit 433c907
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ node_modules/
polars/vendor
target/
venv/
.vim
82 changes: 53 additions & 29 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,38 +183,41 @@ fn cast_inner_list_type(list: &ListArray<i64>, child_type: &DataType) -> PolarsR
/// So this implementation casts the inner type
impl ChunkCast for ListChunked {
fn cast(&self, data_type: &DataType) -> PolarsResult<Series> {
use DataType::*;
match data_type {
DataType::List(child_type) => {
List(child_type) => {
let phys_child = child_type.to_physical();

if phys_child.is_primitive() {
let mut ca = if child_type.to_physical() != self.inner_dtype().to_physical() {
let chunks = self
.downcast_iter()
.map(|list| cast_inner_list_type(list, &phys_child))
.collect::<PolarsResult<_>>()?;
unsafe { ListChunked::from_chunks(self.name(), chunks) }
} else {
self.clone()
};
ca.set_inner_dtype(*child_type.clone());
Ok(ca.into_series())
} else {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let s = Series::try_from(("", arr.values().clone())).unwrap();
let new_inner = s.cast(child_type)?;
let new_values = new_inner.array_ref(0).clone();

let data_type =
ListArray::<i64>::default_datatype(new_values.data_type().clone());
let new_arr = ListArray::<i64>::new(
data_type,
arr.offsets().clone(),
new_values,
arr.validity().cloned(),
);
Series::try_from((self.name(), Box::new(new_arr) as ArrayRef))
match (self.inner_dtype(), &**child_type) {
#[cfg(feature = "dtype-categorical")]
(Utf8, Categorical(_)) => {
let (arr, inner_dtype) = cast_list(self, child_type)?;
Ok(unsafe {
Series::from_chunks_and_dtype_unchecked(
self.name(),
vec![arr],
&List(Box::new(inner_dtype)),
)
})
}
_ if phys_child.is_primitive() => {
let mut ca = if child_type.to_physical() != self.inner_dtype().to_physical()
{
let chunks = self
.downcast_iter()
.map(|list| cast_inner_list_type(list, &phys_child))
.collect::<PolarsResult<_>>()?;
unsafe { ListChunked::from_chunks(self.name(), chunks) }
} else {
self.clone()
};
ca.set_inner_dtype(*child_type.clone());
Ok(ca.into_series())
}
_ => {
let arr = cast_list(self, child_type)?.0;
Series::try_from((self.name(), arr))
}
}
}
_ => Err(PolarsError::ComputeError("Cannot cast list type".into())),
Expand All @@ -226,6 +229,27 @@ impl ChunkCast for ListChunked {
}
}

// returns inner data type
fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, DataType)> {
let ca = ca.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let s = Series::try_from(("", arr.values().clone())).unwrap();
let new_inner = s.cast(child_type)?;

let inner_dtype = new_inner.dtype().clone();

let new_values = new_inner.array_ref(0).clone();

let data_type = ListArray::<i64>::default_datatype(new_values.data_type().clone());
let new_arr = ListArray::<i64>::new(
data_type,
arr.offsets().clone(),
new_values,
arr.validity().cloned(),
);
Ok((Box::new(new_arr), inner_dtype))
}

#[cfg(test)]
mod test {
use crate::prelude::*;
Expand Down
20 changes: 17 additions & 3 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,23 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized {

let expr2 = expr.clone();
let func = move |s: Series| {
for name in expr_to_leaf_column_names(&expr) {
if !name.is_empty() {
return Err(PolarsError::ComputeError(r#"Named columns not allowed in 'arr.eval'. Consider using 'element' or 'col("")'."#.into()));
for e in expr.into_iter() {
match e {
#[cfg(feature = "dtype-categorical")]
Expr::Cast {
data_type: DataType::Categorical(_),
..
} => {
return Err(PolarsError::ComputeError(
"Casting to 'Categorical' not allowed in 'arr.eval'".into(),
))
}
Expr::Column(name) => {
if !name.is_empty() {
return Err(PolarsError::ComputeError(r#"Named columns not allowed in 'arr.eval'. Consider using 'element' or 'col("")'."#.into()));
}
}
_ => {}
}
}

Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,17 @@ def test_sort_categoricals_6014() -> None:
assert out.to_dict(False) == {"key": ["bbb", "aaa", "ccc"]}
out = df2.sort("key")
assert out.to_dict(False) == {"key": ["aaa", "bbb", "ccc"]}


def test_cast_inner_categorical() -> None:
dtype = pl.List(pl.Categorical)
out = pl.Series("foo", [["a"], ["a", "b"]]).cast(dtype)
assert out.dtype == dtype
assert out.to_list() == [["a"], ["a", "b"]]

with pytest.raises(
pl.ComputeError, match=r"Casting to 'Categorical' not allowed in 'arr.eval'"
):
pl.Series("foo", [["a", "b"], ["a", "b"]]).arr.eval(
pl.element().cast(pl.Categorical)
)

0 comments on commit 433c907

Please sign in to comment.