Skip to content

Commit

Permalink
feat(rust, python): accept expressions in arr.slice (pola-rs#5191)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and zundertj committed Jan 7, 2023
1 parent ac758fd commit a65c2ef
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 35 deletions.
54 changes: 53 additions & 1 deletion polars/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,40 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
}

impl<'a> AnyValue<'a> {
/// Return the [`DataType`] that corresponds to the variant of this value.
///
/// `Null` is reported as `DataType::Unknown`, and a `Categorical` value is
/// reported as `DataType::Categorical(None)`.
///
/// # Panics
///
/// Reaches `unimplemented!()` for any variant that has no arm in the match below.
pub fn dtype(&self) -> DataType {
use AnyValue::*;
match self.as_borrowed() {
Null => DataType::Unknown,
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
Int32(_) => DataType::Int32,
Int64(_) => DataType::Int64,
UInt8(_) => DataType::UInt8,
UInt16(_) => DataType::UInt16,
UInt32(_) => DataType::UInt32,
UInt64(_) => DataType::UInt64,
Float32(_) => DataType::Float32,
Float64(_) => DataType::Float64,
#[cfg(feature = "dtype-date")]
Date(_) => DataType::Date,
#[cfg(feature = "dtype-datetime")]
// the time unit and (cloned) time zone are carried over into the dtype
Datetime(_, tu, tz) => DataType::Datetime(tu, tz.clone()),
#[cfg(feature = "dtype-time")]
Time(_) => DataType::Time,
#[cfg(feature = "dtype-duration")]
Duration(_, tu) => DataType::Duration(tu),
Boolean(_) => DataType::Boolean,
Utf8(_) => DataType::Utf8,
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => DataType::Categorical(None),
// the inner dtype of the list is the dtype of the series it holds
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-struct")]
Struct(_, field) => DataType::Struct(field.to_vec()),
#[cfg(feature = "dtype-binary")]
Binary(_) => DataType::Binary,
_ => unimplemented!(),
}
}
/// Extract a numerical value from the AnyValue
#[doc(hidden)]
#[cfg(feature = "private")]
Expand Down Expand Up @@ -594,9 +628,27 @@ impl<'a> AnyValue<'a> {
NumCast::from(0)
}
}
_ => unimplemented!(),
dt => panic!("dtype {:?} not implemented", dt),
}
}

/// Fallible counterpart of `extract`: returns the extracted number or an error.
///
/// # Errors
///
/// Returns a `PolarsError::ComputeError` that names this value's dtype when
/// `extract` yields `None`.
pub fn try_extract<T: NumCast>(&self) -> PolarsResult<T> {
    match self.extract() {
        Some(number) => Ok(number),
        None => {
            let msg = format!(
                "could not extract number from AnyValue of dtype: '{:?}'",
                self.dtype()
            );
            Err(PolarsError::ComputeError(msg.into()))
        }
    }
}
}

impl From<AnyValue<'_>> for DataType {
fn from(value: AnyValue<'_>) -> Self {
value.dtype()
}
}

impl<'a> Hash for AnyValue<'a> {
Expand Down
99 changes: 95 additions & 4 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
use polars_arrow::utils::CustomIterTools;

use super::*;

/// The list operations that can be carried by `FunctionExpr::ListExpr`.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ListFunction {
/// Concatenate lists.
Concat,
/// List membership check; this variant only exists when the `is_in` feature is enabled.
#[cfg(feature = "is_in")]
Contains,
/// Slice every sublist by an offset and a length.
Slice,
}

impl Display for ListFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use self::*;
match self {
ListFunction::Concat => write!(f, "concat"),
}
use ListFunction::*;

let name = match self {
Concat => "concat",
Contains => "contains",
Slice => "slice",
};
write!(f, "{}", name)
}
}

Expand All @@ -26,6 +35,88 @@ pub(super) fn contains(args: &mut [Series]) -> PolarsResult<Series> {
})
}

pub(super) fn slice(args: &mut [Series]) -> PolarsResult<Series> {
let s = &args[0];
let list_ca = s.list()?;
let offset_s = &args[1];
let length_s = &args[2];

let mut out: ListChunked = match (offset_s.len(), length_s.len()) {
(1, 1) => {
let offset = offset_s.get(0).try_extract::<i64>()?;
let slice_len = length_s.get(0).try_extract::<usize>()?;
return Ok(list_ca.lst_slice(offset, slice_len).into_series());
}
(1, length_slice_len) => {
if length_slice_len != list_ca.len() {
return Err(PolarsError::ComputeError("the length of the slice 'length' argument does not match that of the list column".into()));
}
let offset = offset_s.get(0).try_extract::<i64>()?;
// cast to i64 as it is more likely that it is that dtype
// instead of usize/u64 (we never need that max length)
let length_ca = length_s.cast(&DataType::Int64)?;
let length_ca = length_ca.i64().unwrap();

list_ca
.amortized_iter()
.zip(length_ca.into_iter())
.map(|(opt_s, opt_length)| match (opt_s, opt_length) {
(Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)),
_ => None,
})
.collect_trusted()
}
(offset_len, 1) => {
if offset_len != list_ca.len() {
return Err(PolarsError::ComputeError("the length of the slice 'offset' argument does not match that of the list column".into()));
}
let length_slice = length_s.get(0).try_extract::<usize>()?;
let offset_ca = offset_s.cast(&DataType::Int64)?;
let offset_ca = offset_ca.i64().unwrap();
list_ca
.amortized_iter()
.zip(offset_ca)
.map(|(opt_s, opt_offset)| match (opt_s, opt_offset) {
(Some(s), Some(offset)) => {
Some(s.as_ref().slice(offset, length_slice as usize))
}
_ => None,
})
.collect_trusted()
}
_ => {
if offset_s.len() != list_ca.len() {
return Err(PolarsError::ComputeError("the length of the slice 'offset' argument does not match that of the list column".into()));
}
if length_s.len() != list_ca.len() {
return Err(PolarsError::ComputeError("the length of the slice 'length' argument does not match that of the list column".into()));
}
let offset_ca = offset_s.cast(&DataType::Int64)?;
let offset_ca = offset_ca.i64()?;
// cast to i64 as it is more likely that it is that dtype
// instead of usize/u64 (we never need that max length)
let length_ca = length_s.cast(&DataType::Int64)?;
let length_ca = length_ca.i64().unwrap();

list_ca
.amortized_iter()
.zip(offset_ca.into_iter())
.zip(length_ca.into_iter())
.map(
|((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) {
(Some(s), Some(offset), Some(length)) => {
Some(s.as_ref().slice(offset, length as usize))
}
_ => None,
},
)
.collect_trusted()
}
};
out.rename(s.name());
Ok(out.into_series())
}

pub(super) fn concat(s: &mut [Series]) -> PolarsResult<Series> {
let mut first = std::mem::take(&mut s[0]);
let other = &s[1..];
Expand Down
11 changes: 3 additions & 8 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ pub enum FunctionExpr {
FillNull {
super_type: DataType,
},
#[cfg(feature = "is_in")]
ListContains,
#[cfg(all(feature = "rolling_window", feature = "moment"))]
// if we add more, make a sub enum
RollingSkew {
Expand Down Expand Up @@ -138,8 +136,6 @@ impl Display for FunctionExpr {
#[cfg(feature = "sign")]
Sign => "sign",
FillNull { .. } => "fill_null",
#[cfg(feature = "is_in")]
ListContains => "arr.contains",
#[cfg(all(feature = "rolling_window", feature = "moment"))]
RollingSkew { .. } => "rolling_skew",
ShiftAndFill { .. } => "shift_and_fill",
Expand Down Expand Up @@ -293,10 +289,6 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
map_as_slice!(fill_null::fill_null, &super_type)
}

#[cfg(feature = "is_in")]
ListContains => {
wrap!(list::contains)
}
#[cfg(all(feature = "rolling_window", feature = "moment"))]
RollingSkew { window_size, bias } => {
map!(rolling::rolling_skew, window_size, bias)
Expand All @@ -314,6 +306,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
use ListFunction::*;
match lf {
Concat => wrap!(list::concat),
#[cfg(feature = "is_in")]
Contains => wrap!(list::contains),
Slice => wrap!(list::slice),
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ impl FunctionExpr {
#[cfg(feature = "sign")]
Sign => with_dtype(DataType::Int64),
FillNull { super_type, .. } => with_dtype(super_type.clone()),
#[cfg(feature = "is_in")]
ListContains => with_dtype(DataType::Boolean),
#[cfg(all(feature = "rolling_window", feature = "moment"))]
RollingSkew { .. } => float_dtype(),
ShiftAndFill { .. } => same_type(),
Expand All @@ -161,6 +159,9 @@ impl FunctionExpr {
use ListFunction::*;
match l {
Concat => inner_super_type_list(),
#[cfg(feature = "is_in")]
Contains => with_dtype(DataType::Boolean),
Slice => same_type(),
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
24 changes: 12 additions & 12 deletions polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use polars_core::series::ops::NullBehavior;
use polars_ops::prelude::*;

use crate::dsl::function_expr::FunctionExpr;
use crate::prelude::function_expr::ListFunction;
use crate::prelude::*;

/// Specialized expressions for [`Series`] of [`DataType::List`].
Expand Down Expand Up @@ -188,23 +189,22 @@ impl ListNameSpace {
}

/// Slice every sublist.
pub fn slice(self, offset: i64, length: usize) -> Expr {
self.0
.map(
move |s| Ok(s.list()?.lst_slice(offset, length).into_series()),
GetOutput::same_type(),
)
.with_fmt("arr.slice")
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Slice),
&[offset, length],
false,
)
}

/// Get the head of every sublist
pub fn head(self, n: usize) -> Expr {
self.slice(0, n)
pub fn head(self, n: Expr) -> Expr {
self.slice(lit(0), n)
}

/// Get the tail of every sublist
pub fn tail(self, n: usize) -> Expr {
self.slice(-(n as i64), n)
pub fn tail(self, n: Expr) -> Expr {
self.slice(lit(0) - n.clone().cast(DataType::Int64), n)
}

#[cfg(feature = "list_to_struct")]
Expand Down Expand Up @@ -238,7 +238,7 @@ impl ListNameSpace {

Expr::Function {
input: vec![self.0, other],
function: FunctionExpr::ListContains,
function: FunctionExpr::ListExpr(ListFunction::Contains),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
13 changes: 9 additions & 4 deletions py-polars/polars/internals/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,9 @@ def shift(self, periods: int = 1) -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.lst_shift(periods))

def slice(self, offset: int, length: int | None = None) -> pli.Expr:
def slice(
self, offset: int | str | pli.Expr, length: int | str | pli.Expr | None = None
) -> pli.Expr:
"""
Slice every sublist.
Expand All @@ -533,9 +535,11 @@ def slice(self, offset: int, length: int | None = None) -> pli.Expr:
]
"""
offset = pli.expr_to_lit_or_expr(offset, str_to_lit=False)._pyexpr
length = pli.expr_to_lit_or_expr(length, str_to_lit=False)._pyexpr
return pli.wrap_expr(self._pyexpr.lst_slice(offset, length))

def head(self, n: int = 5) -> pli.Expr:
def head(self, n: int | str | pli.Expr = 5) -> pli.Expr:
"""
Slice the first `n` values of every sublist.
Expand All @@ -558,7 +562,7 @@ def head(self, n: int = 5) -> pli.Expr:
"""
return self.slice(0, n)

def tail(self, n: int = 5) -> pli.Expr:
def tail(self, n: int | str | pli.Expr = 5) -> pli.Expr:
"""
Slice the last `n` values of every sublist.
Expand All @@ -579,7 +583,8 @@ def tail(self, n: int = 5) -> pli.Expr:
]
"""
return self.slice(-n, n)
offset = -pli.expr_to_lit_or_expr(n, str_to_lit=False)
return self.slice(offset, n)

def to_struct(
self,
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,12 +1329,12 @@ impl PyExpr {
self.inner.clone().arr().shift(periods).into()
}

fn lst_slice(&self, offset: i64, length: Option<usize>) -> Self {
fn lst_slice(&self, offset: PyExpr, length: Option<PyExpr>) -> Self {
let length = match length {
Some(i) => i,
None => usize::MAX,
Some(i) => i.inner,
None => dsl::lit(i64::MAX),
};
self.inner.clone().arr().slice(offset, length).into()
self.inner.clone().arr().slice(offset.inner, length).into()
}

fn lst_eval(&self, expr: PyExpr, parallel: bool) -> Self {
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,23 @@ def groupby_list_column() -> None:
"a_list": [["b"], ["a", "a"]],
"a": ["b", "a"],
}


def test_list_slice() -> None:
    """Check `arr.slice` with column names and integer literals as offset/length."""
    df = pl.DataFrame(
        {
            "lst": [[1, 2, 3, 4], [10, 2, 1]],
            "offset": [1, 2],
            "len": [3, 2],
        }
    )

    # (offset argument, length argument, expected sliced lists)
    cases = [
        ("offset", "len", [[2, 3, 4], [1]]),
        ("offset", 1, [[2], [1]]),
        (-2, "len", [[3, 4], [2, 1]]),
    ]
    for offset, length, expected in cases:
        result = df.select([pl.col("lst").arr.slice(offset, length)])
        assert result.to_dict(False) == {"lst": expected}

0 comments on commit a65c2ef

Please sign in to comment.