Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support array arithmetic for equally sized shapes #16791

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,17 @@ impl DataType {
prev
}

/// Cast the leaf types of Lists/Arrays and keep the nesting.
pub fn cast_leaf(&self, to: DataType) -> DataType {
use DataType::*;
match self {
List(inner) => List(Box::new(inner.cast_leaf(to))),
#[cfg(feature = "dtype-array")]
Array(inner, size) => Array(Box::new(inner.cast_leaf(to)), *size),
_ => to,
}
}

/// Convert to the physical data type
#[must_use]
pub fn to_physical(&self) -> DataType {
Expand Down
57 changes: 57 additions & 0 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,63 @@ impl NumOpsDispatchInner for BooleanType {
}
}

/// Collect the fixed widths of a (possibly nested) `Array` dtype, outermost first.
///
/// When `infer` is true, a leading `-1` is pushed before the widths (the caller in
/// this file passes the result to `reshape_array` as the target dimensions).
///
/// Every level of `Array` nesting contributes one entry, so
/// `Array(Array(_, 2), 3)` yields `[3, 2]` (or `[-1, 3, 2]` with `infer`).
/// A non-`Array` dtype contributes no widths.
#[cfg(feature = "dtype-array")]
fn array_shape(dt: &DataType, infer: bool) -> Vec<i64> {
    let mut shape = Vec::new();
    if infer {
        shape.push(-1);
    }

    // Walk down through every nested `Array` level; stopping after the first
    // level would drop the inner dimensions of multi-dimensional arrays.
    let mut current = dt;
    while let DataType::Array(inner, width) = current {
        shape.push(*width as i64);
        current = &**inner;
    }
    shape
}

#[cfg(feature = "dtype-array")]
impl ArrayChunked {
    /// Apply the binary operation `op` to the leaf values of `self` and `rhs`,
    /// then reshape the result using the shape derived from `self`'s dtype.
    ///
    /// `rhs` is used as-is when it is a numeric series of length 1; otherwise its
    /// dtype must equal `self`'s dtype and its leaf array is used.
    ///
    /// # Errors
    /// Returns an `InvalidOperation` error when `rhs` is not a length-1 numeric
    /// series and its dtype differs from `self`'s. Errors from `op` and from
    /// the final reshape are propagated.
    fn arithm_helper(
        &self,
        rhs: &Series,
        op: &dyn Fn(Series, Series) -> PolarsResult<Series>,
    ) -> PolarsResult<Series> {
        let lhs_values = self.clone().into_series().get_leaf_array();
        let target_shape = array_shape(self.dtype(), true);

        let rhs_is_unit_numeric = rhs.dtype().is_numeric() && rhs.len() == 1;
        let rhs_values = match rhs_is_unit_numeric {
            true => rhs.clone(),
            false => {
                polars_ensure!(self.dtype() == rhs.dtype(), InvalidOperation: "can only do arithmetic of array's of the same type and shape; got {} and {}", self.dtype(), rhs.dtype());
                rhs.get_leaf_array()
            },
        };

        let flat_result = op(lhs_values, rhs_values)?;
        flat_result.reshape_array(&target_shape)
    }
}

/// Arithmetic dispatch for fixed-size-list (array) columns: each operation is
/// forwarded to `ArrayChunked::arithm_helper` with the matching `Series` operation.
#[cfg(feature = "dtype-array")]
impl NumOpsDispatchInner for FixedSizeListType {
    /// Addition, routed through `arithm_helper`.
    fn add_to(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
        ArrayChunked::arithm_helper(lhs, rhs, &|a: Series, b: Series| a.add_to(&b))
    }

    /// Subtraction, routed through `arithm_helper`.
    fn subtract(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
        ArrayChunked::arithm_helper(lhs, rhs, &|a: Series, b: Series| a.subtract(&b))
    }

    /// Multiplication, routed through `arithm_helper`.
    fn multiply(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
        ArrayChunked::arithm_helper(lhs, rhs, &|a: Series, b: Series| a.multiply(&b))
    }

    /// Division, routed through `arithm_helper`.
    fn divide(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
        ArrayChunked::arithm_helper(lhs, rhs, &|a: Series, b: Series| a.divide(&b))
    }

    /// Remainder, routed through `arithm_helper`.
    fn remainder(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<Series> {
        ArrayChunked::arithm_helper(lhs, rhs, &|a: Series, b: Series| a.remainder(&b))
    }
}

#[cfg(feature = "checked_arithmetic")]
pub mod checked {
use num_traits::{CheckedDiv, One, ToPrimitive, Zero};
Expand Down
35 changes: 31 additions & 4 deletions crates/polars-core/src/series/arithmetic/owned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pub fn coerce_lhs_rhs_owned(lhs: Series, rhs: Series) -> PolarsResult<(Series, S
Ok((left, right))
}

/// Returns `true` when `lhs` is not a logical dtype and the physical dtypes of
/// both `lhs` and `rhs` are numeric.
fn is_eligible(lhs: &DataType, rhs: &DataType) -> bool {
    if lhs.is_logical() {
        return false;
    }
    let lhs_numeric = lhs.to_physical().is_numeric();
    lhs_numeric && rhs.to_physical().is_numeric()
}

#[cfg(feature = "performant")]
fn apply_operation_mut<T, F>(mut lhs: Series, mut rhs: Series, op: F) -> Series
where
Expand All @@ -43,10 +47,7 @@ macro_rules! impl_operation {
#[cfg(feature = "performant")]
{
// only physical numeric values take the mutable path
if !self.dtype().is_logical()
&& self.dtype().to_physical().is_numeric()
&& rhs.dtype().to_physical().is_numeric()
{
if is_eligible(self.dtype(), rhs.dtype()) {
let (lhs, rhs) = coerce_lhs_rhs_owned(self, rhs).unwrap();
let (lhs, rhs) = align_chunks_binary_owned_series(lhs, rhs);
use DataType::*;
Expand Down Expand Up @@ -84,3 +85,29 @@ impl_operation!(Add, add, |a, b| a.add(b));
impl_operation!(Sub, sub, |a, b| a.sub(b));
impl_operation!(Mul, mul, |a, b| a.mul(b));
impl_operation!(Div, div, |a, b| a.div(b));

impl Series {
    /// Add `other` to `self`, consuming both.
    ///
    /// Uses the by-value `+` operator when `is_eligible` accepts the two dtypes;
    /// otherwise falls back to the fallible `try_add`.
    pub fn try_add_owned(self, other: Self) -> PolarsResult<Self> {
        if !is_eligible(self.dtype(), other.dtype()) {
            return self.try_add(&other);
        }
        Ok(self + other)
    }

    /// Subtract `other` from `self`, consuming both.
    ///
    /// Uses the by-value `-` operator when `is_eligible` accepts the two dtypes;
    /// otherwise falls back to the fallible `try_sub`.
    pub fn try_sub_owned(self, other: Self) -> PolarsResult<Self> {
        if !is_eligible(self.dtype(), other.dtype()) {
            return self.try_sub(&other);
        }
        Ok(self - other)
    }

    /// Multiply `self` by `other`, consuming both.
    ///
    /// Uses the by-value `*` operator when `is_eligible` accepts the two dtypes;
    /// otherwise falls back to the fallible `try_mul`.
    pub fn try_mul_owned(self, other: Self) -> PolarsResult<Self> {
        if !is_eligible(self.dtype(), other.dtype()) {
            return self.try_mul(&other);
        }
        Ok(self * other)
    }
}
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ impl private::PrivateSeries for SeriesWrap<ArrayChunked> {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}

/// Addition for array series: forwards to `add_to` on the wrapped `ArrayChunked` (`self.0`).
// NOTE(review): presumably resolves to the `NumOpsDispatch` implementation backed by
// `FixedSizeListType` — confirm against the trait impls.
fn add_to(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.add_to(rhs)
}

/// Subtraction for array series: forwards to `subtract` on the wrapped `ArrayChunked` (`self.0`).
fn subtract(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.subtract(rhs)
}

/// Multiplication for array series: forwards to `multiply` on the wrapped `ArrayChunked` (`self.0`).
fn multiply(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.multiply(rhs)
}
/// Division for array series: forwards to `divide` on the wrapped `ArrayChunked` (`self.0`).
fn divide(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.divide(rhs)
}
/// Remainder for array series: forwards to `remainder` on the wrapped `ArrayChunked` (`self.0`).
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.remainder(rhs)
}
}

impl SeriesTrait for SeriesWrap<ArrayChunked> {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,10 @@ impl Series {
true
}

pub fn from_arrow_chunks(name: &str, arrays: Vec<ArrayRef>) -> PolarsResult<Series> {
Self::try_from((name, arrays))
}

pub fn from_arrow(name: &str, array: ArrayRef) -> PolarsResult<Series> {
Self::try_from((name, array))
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod extend;
mod null;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod reshape;

#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ use std::collections::VecDeque;
use arrow::array::*;
use arrow::legacy::kernels::list::array_to_unit_list;
use arrow::offset::Offsets;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::datatypes::{DataType, ListChunked};
use polars_core::prelude::{IntoSeries, Series};
use polars_error::{polars_bail, polars_ensure, PolarsResult};
#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;

use crate::prelude::*;
use crate::chunked_array::builder::get_list_builder;
use crate::datatypes::{DataType, ListChunked};
use crate::prelude::{IntoSeries, Series, *};

fn reshape_fast_path(name: &str, s: &Series) -> Series {
let mut ca = match s.dtype() {
Expand All @@ -30,10 +29,10 @@ fn reshape_fast_path(name: &str, s: &Series) -> Series {
ca.into_series()
}

pub trait SeriesReshape: SeriesSealed {
impl Series {
/// Recurse nested types until we are at the leaf array.
fn get_leaf_array(&self) -> Series {
let s = self.as_series();
pub fn get_leaf_array(&self) -> Series {
let s = self;
match s.dtype() {
#[cfg(feature = "dtype-array")]
DataType::Array(dtype, _) => {
Expand Down Expand Up @@ -62,8 +61,8 @@ pub trait SeriesReshape: SeriesSealed {

/// Convert the values of this Series to a ListChunked with a length of 1,
/// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
fn implode(&self) -> PolarsResult<ListChunked> {
let s = self.as_series();
pub fn implode(&self) -> PolarsResult<ListChunked> {
let s = self;
let s = s.rechunk();
let values = s.array_ref(0);

Expand All @@ -89,7 +88,7 @@ pub trait SeriesReshape: SeriesSealed {
}

#[cfg(feature = "dtype-array")]
fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Series> {
pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Series> {
let mut dims = dimensions.iter().copied().collect::<VecDeque<_>>();

let leaf_array = self.get_leaf_array();
Expand Down Expand Up @@ -136,8 +135,8 @@ pub trait SeriesReshape: SeriesSealed {
})
}

fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Series> {
let s = self.as_series();
pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Series> {
let s = self;

if dimensions.is_empty() {
polars_bail!(ComputeError: "reshape `dimensions` cannot be empty")
Expand Down Expand Up @@ -212,13 +211,10 @@ pub trait SeriesReshape: SeriesSealed {
}
}

impl SeriesReshape for Series {}

#[cfg(test)]
mod test {
use polars_core::prelude::*;

use super::*;
use crate::prelude::*;

#[test]
fn test_to_list() -> PolarsResult<()> {
Expand Down
14 changes: 7 additions & 7 deletions crates/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,31 @@ pub(crate) mod private {
}

fn subtract(&self, _rhs: &Series) -> PolarsResult<Series> {
invalid_operation_panic!(sub, self)
polars_bail!(opq = subtract, self._dtype());
}
fn add_to(&self, _rhs: &Series) -> PolarsResult<Series> {
invalid_operation_panic!(add, self)
polars_bail!(opq = add, self._dtype());
}
fn multiply(&self, _rhs: &Series) -> PolarsResult<Series> {
invalid_operation_panic!(mul, self)
polars_bail!(opq = multiply, self._dtype());
}
fn divide(&self, _rhs: &Series) -> PolarsResult<Series> {
invalid_operation_panic!(div, self)
polars_bail!(opq = divide, self._dtype());
}
fn remainder(&self, _rhs: &Series) -> PolarsResult<Series> {
invalid_operation_panic!(rem, self)
polars_bail!(opq = remainder, self._dtype());
}
#[cfg(feature = "algorithm_group_by")]
fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsProxy> {
invalid_operation_panic!(group_tuples, self)
polars_bail!(opq = group_tuples, self._dtype());
}
#[cfg(feature = "zip_with")]
fn zip_with_same_type(
&self,
_mask: &BooleanChunked,
_other: &Series,
) -> PolarsResult<Series> {
invalid_operation_panic!(zip_with_same_type, self)
polars_bail!(opq = zip_with_same_type, self._dtype());
}

#[allow(unused_variables)]
Expand Down
12 changes: 9 additions & 3 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ impl BinaryExpr {
/// Can partially do operations in place.
fn apply_operator_owned(left: Series, right: Series, op: Operator) -> PolarsResult<Series> {
match op {
Operator::Plus => Ok(left + right),
Operator::Minus => Ok(left - right),
Operator::Multiply => Ok(left * right),
Operator::Plus => left.try_add_owned(right),
Operator::Minus => left.try_sub_owned(right),
Operator::Multiply => left.try_mul_owned(right),
_ => apply_operator(&left, &right, op),
}
}
Expand All @@ -61,6 +61,12 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu
#[cfg(feature = "dtype-decimal")]
Decimal(_, _) => left.try_div(right),
Date | Datetime(_, _) | Float32 | Float64 => left.try_div(right),
#[cfg(feature = "dtype-array")]
dt @ Array(_, _) => {
let left_dt = dt.cast_leaf(Float64);
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)?.try_div(&right.cast(&right_dt)?)
},
_ => left.cast(&Float64)?.try_div(&right.cast(&Float64)?),
},
Operator::FloorDivide => {
Expand Down
1 change: 0 additions & 1 deletion crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub(crate) use gather::*;
pub(crate) use literal::*;
use polars_core::prelude::*;
use polars_io::predicates::PhysicalIoExpr;
use polars_ops::prelude::*;
use polars_plan::prelude::*;
#[cfg(feature = "dynamic_group_by")]
pub(crate) use rolling::RollingExpr;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::chunked_array::list::sum_mean::sum_with_nulls;
#[cfg(feature = "diff")]
use crate::prelude::diff;
use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical};
use crate::series::{ArgAgg, SeriesReshape};
use crate::series::ArgAgg;

pub(super) fn has_inner_nulls(ca: &ListChunked) -> bool {
for arr in ca.downcast_iter() {
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ mod rank;
mod reinterpret;
#[cfg(feature = "replace")]
mod replace;
mod reshape;
#[cfg(feature = "rle")]
mod rle;
#[cfg(feature = "rolling_window")]
Expand Down Expand Up @@ -138,7 +137,6 @@ pub use unique::*;
pub use various::*;
mod not;
pub use not::*;
pub use reshape::*;

pub trait SeriesSealed {
fn as_series(&self) -> &Series;
Expand Down
1 change: 0 additions & 1 deletion crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use polars_core::export::regex;
use polars_core::prelude::*;
use polars_error::to_compute_err;
use polars_lazy::prelude::*;
use polars_ops::series::SeriesReshape;
use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use polars_time::Duration;
Expand Down
1 change: 0 additions & 1 deletion py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ impl PySeries {
}

fn reshape(&self, dims: Vec<i64>, is_list: bool) -> PyResult<Self> {
use polars_ops::prelude::SeriesReshape;
let out = if is_list {
self.series.reshape_list(&dims)
} else {
Expand Down
Loading
Loading