From c1872ccefa3cba72bc25a93af1cb0ca4f3431b2d Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 13:59:21 -0400 Subject: [PATCH 01/11] Tests for arithmetic operations between list and numeric Series. --- .../operations/arithmetic/test_arithmetic.py | 138 +++++++++++++++++- 1 file changed, 135 insertions(+), 3 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 360def065ca1..e4b9726bb5fa 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -21,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError, SchemaError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -707,14 +707,146 @@ def test_list_arithmetic_error_cases() -> None: _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) # Wrong types: - with pytest.raises(InvalidOperationError, match="cannot cast List type"): + with pytest.raises( + InvalidOperationError, match="they and other Series are numeric" + ): _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) # Different nesting: - with pytest.raises(SchemaError, match="failed to determine supertype"): + with pytest.raises(InvalidOperationError, match="should have same dtype"): _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + # All 5 arithmetic operations: + ([[3, 4], [6]], lambda a, b: a + b, ("list", "int64")), + ([[-1, 0], [0]], lambda a, b: a - b, ("list", "int64")), + ([[2, 4], [9]], lambda a, b: a * b, ("list", "int64")), + ([[0.5, 1.0], [1.0]], lambda a, b: a / b, ("list", "int64")), + ([[1, 0], [0]], lambda a, b: a % b, ("list", "int64")), + # Different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("list", "uint8"), + ), + # Extra nesting + different types: + ( + [[[3, 4]], [[8]]], + lambda a, b: a + b, + ("nested", "int64"), + ), + # Primitive numeric on the left; only addition and multiplication are + # supported: + ([[3, 4], [6]], lambda a, b: a + b, ("int64", "list")), + ([[2, 4], [9]], lambda a, b: a * b, ("int64", "list")), + # Primitive numeric on the left with different types: + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("uint8", "list"), + ), + ( + [[2, 4], [12]], + lambda a, b: a * b, + ("uint8", "list"), + ), + ], +) +def test_list_and_numeric_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("list", [[1, 2], [3]]), + pl.Series("int64", [2, 3], dtype=pl.Int64()), + pl.Series("uint8", [2, 4], dtype=pl.UInt8()), + pl.Series("nested", [[[1, 2]], [[5]]]), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + # Null on numeric on the right: + ([[1, 2], [3]], [1, None], [[2, 3], None]), + # Null on list on the left: + ([[[1, 2]], [[3]]], [None, 1], [None, [[4]]]), + # Extra nesting: + ([[[2, None]], [[3, 6]]], [3, 4], [[[5, None]], [[7, 10]]]), + ], +) +def test_list_and_numeric_arithmetic_nulls( + a: list[Any], b: list[Any], expected: list[Any] +) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + # Swap sides: + assert_series_equal(series_b + series_a, series_expected) + assert_series_equal( + series_b._recursive_cast_to_dtype(pl.Int32()) + + series_a._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_and_numeric_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) + pl.Series("b", [1, 2]) + with pytest.raises( + InvalidOperationError, match="series of different lengths: got 3 and 2" + ): + _ = pl.Series("a", [[1, 2], [3, 4], [5, 6]]) / pl.Series("b", [1, None]) + + # Wrong types: + with pytest.raises( + InvalidOperationError, match="they and other Series are numeric" + ): + _ = pl.Series("a", [[1, 2], [3, 4]]) + pl.Series("b", ["hello", "world"]) + + # Numeric on right and list on left doesn't work for subtraction, division, + # or reminder, since they're not commutative operations and it seems + # semantically weird. + numeric = pl.Series("a", [1, 2]) + list_num = pl.Series("b", [[3, 4], [5, 6]]) + with pytest.raises(InvalidOperationError, match="operation not supported"): + numeric / list_num + with pytest.raises(InvalidOperationError, match="operation not supported"): + numeric - list_num + with pytest.raises(InvalidOperationError, match="operation not supported"): + numeric % list_num + + def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]}) From 944d4aefe77aa9e6c871f461ffcff71c7bdf16a2 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:00:16 -0400 Subject: [PATCH 02/11] Cast the leaf dtype for nested dtypes. --- py-polars/polars/series/series.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 506c52ed9f7a..f532de9011b4 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1019,8 +1019,11 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self: else: return self._from_pyseries(getattr(self._s, op_s)(_s)) else: - other = maybe_cast(other, self.dtype) - f = get_ffi_func(op_ffi, self.dtype, self._s) + dtype = self.dtype + while hasattr(dtype, "inner"): + dtype = dtype.inner + other = maybe_cast(other, dtype) + f = get_ffi_func(op_ffi, dtype, self._s) if f is None: msg = ( f"cannot do arithmetic with Series of dtype: {self.dtype!r} and argument" From e3141fdbe20aee23cca8d180f9db7c6288fdf42b Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:01:14 -0400 Subject: [PATCH 03/11] Casting logic for mixed list and numeric dtypes. --- .../src/series/arithmetic/borrowed.rs | 24 ++++++++++++---- crates/polars-plan/src/plans/aexpr/schema.rs | 28 +++++++++++++++++++ .../plans/conversion/type_coercion/binary.rs | 14 ++++++++-- 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 115626d63805..26c8282184be 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -390,23 +390,35 @@ pub(crate) fn coerce_lhs_rhs<'a>( if let Some(result) = coerce_time_units(lhs, rhs) { return Ok(result); } - let dtype = match (lhs.dtype(), rhs.dtype()) { + let (left_dtype, right_dtype) = (lhs.dtype(), rhs.dtype()); + let leaf_super_dtype = match (left_dtype, right_dtype) { #[cfg(feature = "dtype-struct")] (DataType::Struct(_), DataType::Struct(_)) => { return Ok((Cow::Borrowed(lhs), Cow::Borrowed(rhs))) }, - _ => try_get_supertype(lhs.dtype(), rhs.dtype())?, + _ => try_get_supertype(left_dtype.leaf_dtype(), right_dtype.leaf_dtype())?, }; - let left = if lhs.dtype() == &dtype { + let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); + let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); + + // If we have e.g. Array and List, we want to convert those too. + if (left_dtype.is_list() && right_dtype.is_array()) + || (left_dtype.is_array() && right_dtype.is_list()) + { + new_left_dtype = try_get_supertype(&new_left_dtype, &new_right_dtype)?; + new_right_dtype = new_left_dtype.clone(); + } + + let left = if lhs.dtype() == &new_left_dtype { Cow::Borrowed(lhs) } else { - Cow::Owned(lhs.cast(&dtype)?) + Cow::Owned(lhs.cast(&new_left_dtype)?) }; - let right = if rhs.dtype() == &dtype { + let right = if rhs.dtype() == &new_right_dtype { Cow::Borrowed(rhs) } else { - Cow::Owned(rhs.cast(&dtype)?) + Cow::Owned(rhs.cast(&new_right_dtype)?) }; Ok((left, right)) } diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 0145776684f4..32092f958a47 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -302,6 +302,15 @@ fn func_args_to_fields( .collect() } +/// Left is List(_), right is numeric primitive, we want to cast left's leaf +/// dtype. +fn try_get_list_super_type(left: &DataType, right: &DataType) -> PolarsResult { + debug_assert!(left.is_list()); + debug_assert!(right.is_numeric()); + let super_type = try_get_supertype(left.leaf_dtype(), right)?; + Ok(left.cast_leaf(super_type)) +} + fn get_arithmetic_field( left: Node, right: Node, @@ -355,6 +364,9 @@ fn get_arithmetic_field( (_, Time) | (Time, _) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, + (List(_), _) if right_type.is_numeric() => { + try_get_list_super_type(&left_field.dtype, &right_type)? + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -380,6 +392,12 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Boolean, Boolean) => IDX_DTYPE, + (List(_), _) if right_type.is_numeric() => { + try_get_list_super_type(&left_field.dtype, &right_type)? + }, + (_, List(_)) if left_field.dtype.is_numeric() => { + try_get_supertype(&left_field.dtype, right_type.leaf_dtype())? + }, (left, right) => try_get_supertype(left, right)?, } }, @@ -412,6 +430,16 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, }, + (List(_), _) if right_type.is_numeric() => { + let dtype = try_get_list_super_type(&left_field.dtype, &right_type)?; + left_field.coerce(dtype); + return Ok(left_field); + }, + (_, List(_)) if left_field.dtype.is_numeric() => { + let dtype = try_get_supertype(&left_field.dtype, right_type.leaf_dtype())?; + left_field.coerce(dtype); + return Ok(left_field); + }, _ => { // Avoid needlessly type casting numeric columns during arithmetic // with literals. diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 37d58e004ab1..20fad09c583c 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -59,9 +59,14 @@ fn process_list_arithmetic( (DataType::List(_), _) => { let leaf = type_left.leaf_dtype(); if type_right != *leaf { + let new_dtype = if type_right.is_nested() { + type_left.cast_leaf(leaf.clone()) + } else { + leaf.clone() + }; let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, - dtype: type_left.cast_leaf(leaf.clone()), + dtype: new_dtype, options: CastOptions::NonStrict, }); @@ -77,9 +82,14 @@ fn process_list_arithmetic( (_, DataType::List(_)) => { let leaf = type_right.leaf_dtype(); if type_left != *leaf { + let new_dtype = if type_left.is_nested() { + type_right.cast_leaf(leaf.clone()) + } else { + leaf.clone() + }; let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, - dtype: type_right.cast_leaf(leaf.clone()), + dtype: new_dtype, options: CastOptions::NonStrict, }); From 835f9fa0c5d584e0a0f6cd2c1b6008aed190ebcc Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:01:39 -0400 Subject: [PATCH 04/11] Logic for arithmetic between a ListChunked and a numeric Series. --- .../src/series/arithmetic/list_borrowed.rs | 116 +++++++++++++++--- 1 file changed, 102 insertions(+), 14 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 1628780d7b0e..4a6d717e3ef5 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -53,16 +53,94 @@ fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { } } +/// Arithmetic operations that can be applied to a Series +#[derive(Clone, Copy)] +enum Op { + Add, + Subtract, + Multiply, + Divide, + Remainder, +} + +impl Op { + /// Apply the operation to a pair of Series. + fn apply_with_series(&self, lhs: &Series, rhs: &Series) -> PolarsResult { + use Op::*; + + match self { + Add => lhs + rhs, + Subtract => lhs - rhs, + Multiply => lhs * rhs, + Divide => lhs / rhs, + Remainder => lhs % rhs, + } + } + + /// Apply the operation to a Series and scalar. + fn apply_with_scalar(&self, lhs: &Series, rhs: T) -> Series { + use Op::*; + + match self { + Add => lhs + rhs, + Subtract => lhs - rhs, + Multiply => lhs * rhs, + Divide => lhs / rhs, + Remainder => lhs % rhs, + } + } +} + impl ListChunked { + /// Helper function for NumOpsDispatchInner implementation for ListChunked. + /// + /// Run the given `op` on `self` and `rhs`, for cases where `rhs` has a + /// primitive numeric dtype. + fn arithm_helper_numeric(&self, rhs: &Series, op: Op) -> PolarsResult { + let mut result = AnonymousListBuilder::new( + self.name().clone(), + self.len(), + Some(self.inner_dtype().clone()), + ); + macro_rules! combine { + ($ca:expr) => {{ + self.amortized_iter() + .zip($ca.iter()) + .map(|(a, b)| { + let (Some(a_owner), Some(b)) = (a, b) else { + // Operations with nulls always result in nulls: + return Ok(None); + }; + let a = a_owner.as_ref().rechunk(); + let leaf_result = op.apply_with_scalar(&a.get_leaf_array(), b); + let result = + reshape_list_based_on(&leaf_result.chunks()[0], &a.chunks()[0]); + Ok(Some(result)) + }) + .collect::>>>>()? + }}; + } + let combined = downcast_as_macro_arg_physical!(rhs, combine); + for arr in combined.iter() { + if let Some(arr) = arr { + result.append_array(arr.as_ref()); + } else { + result.append_null(); + } + } + Ok(result.finish().into()) + } + /// Helper function for NumOpsDispatchInner implementation for ListChunked. /// /// Run the given `op` on `self` and `rhs`. - fn arithm_helper( - &self, - rhs: &Series, - op: &dyn Fn(&Series, &Series) -> PolarsResult, - has_nulls: Option, - ) -> PolarsResult { + fn arithm_helper(&self, rhs: &Series, op: Op, has_nulls: Option) -> PolarsResult { + polars_ensure!( + self.dtype().leaf_dtype().is_numeric() && rhs.dtype().leaf_dtype().is_numeric(), + InvalidOperation: "List Series can only do arithmetic operations if they and other Series are numeric, left and right dtypes are {:?} and {:?}", + self.dtype(), + rhs.dtype() + ); polars_ensure!( self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", @@ -70,6 +148,17 @@ impl ListChunked { rhs.len() ); + if rhs.dtype().is_numeric() { + return self.arithm_helper_numeric(rhs, op); + } + + polars_ensure!( + self.dtype() == rhs.dtype(), + InvalidOperation: "List Series doing arithmetic operations to each other should have same dtype; got {:?} and {:?}", + self.dtype(), + rhs.dtype() + ); + let mut has_nulls = has_nulls.unwrap_or(false); if !has_nulls { for chunk in self.chunks().iter() { @@ -118,7 +207,7 @@ impl ListChunked { // along. a_listchunked.arithm_helper(b, op, Some(true)) } else { - op(a, b) + op.apply_with_series(a, b) }; chunk_result.map(Some) }).collect::>>>()?; @@ -139,8 +228,7 @@ impl ListChunked { InvalidOperation: "can only do arithmetic operations on lists of the same size" ); - let result = op(&l_leaf_array, &r_leaf_array)?; - + let result = op.apply_with_series(&l_leaf_array, &r_leaf_array)?; // We now need to wrap the Arrow arrays with the metadata that turns // them into lists: // TODO is there a way to do this without cloning the underlying data? @@ -160,18 +248,18 @@ impl ListChunked { impl NumOpsDispatchInner for ListType { fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + lhs.arithm_helper(rhs, Op::Add, None) } fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + lhs.arithm_helper(rhs, Op::Subtract, None) } fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + lhs.arithm_helper(rhs, Op::Multiply, None) } fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + lhs.arithm_helper(rhs, Op::Divide, None) } fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + lhs.arithm_helper(rhs, Op::Remainder, None) } } From ef43739ecb7fa0a03843d100027ead651ecdc7e5 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:02:16 -0400 Subject: [PATCH 05/11] Handle addition and multipliction cases where the rhs Series is numeric and the ls is a list. --- crates/polars-core/src/series/arithmetic/borrowed.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 26c8282184be..a40e5dd6d2d5 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -532,6 +532,12 @@ impl Add for &Series { (DataType::Struct(_), DataType::Struct(_)) => { _struct_arithmetic(self, rhs, |a, b| a.add(b)) }, + (left_dtype, DataType::List(_)) if left_dtype.is_numeric() => { + // Lists have implementation logic for rhs numeric: + let mut result = (rhs + self)?; + result.rename(self.name().clone()); + Ok(result) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.add_to(rhs.as_ref()) @@ -584,6 +590,12 @@ impl Mul for &Series { let out = rhs.multiply(self)?; Ok(out.with_name(self.name().clone())) }, + (left_dtype, DataType::List(_)) if left_dtype.is_numeric() => { + // Lists have implementation logic for rhs numeric: + let mut result = (rhs * self)?; + result.rename(self.name().clone()); + Ok(result) + }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; lhs.multiply(rhs.as_ref()) From 3ddfbc6e467f45bd6969ffa6464dcf824b028aa2 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:09:20 -0400 Subject: [PATCH 06/11] Pacify pyright. --- .../tests/unit/operations/arithmetic/test_arithmetic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index e4b9726bb5fa..8ff7821c57e1 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -840,11 +840,11 @@ def test_list_and_numeric_arithmetic_error_cases() -> None: numeric = pl.Series("a", [1, 2]) list_num = pl.Series("b", [[3, 4], [5, 6]]) with pytest.raises(InvalidOperationError, match="operation not supported"): - numeric / list_num + _ = numeric / list_num with pytest.raises(InvalidOperationError, match="operation not supported"): - numeric - list_num + _ = numeric - list_num with pytest.raises(InvalidOperationError, match="operation not supported"): - numeric % list_num + _ = numeric % list_num def test_schema_owned_arithmetic_5669() -> None: From 9f984ee0a98ea2475edfead27162c70dff7009d2 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 24 Sep 2024 14:10:02 -0400 Subject: [PATCH 07/11] Revert unnecessary change. --- py-polars/polars/series/series.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index f532de9011b4..506c52ed9f7a 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1019,11 +1019,8 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self: else: return self._from_pyseries(getattr(self._s, op_s)(_s)) else: - dtype = self.dtype - while hasattr(dtype, "inner"): - dtype = dtype.inner - other = maybe_cast(other, dtype) - f = get_ffi_func(op_ffi, dtype, self._s) + other = maybe_cast(other, self.dtype) + f = get_ffi_func(op_ffi, self.dtype, self._s) if f is None: msg = ( f"cannot do arithmetic with Series of dtype: {self.dtype!r} and argument" From d25c9b864a987ab9546a6bcd859ed1ee0f02542e Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 2 Oct 2024 15:55:38 +1000 Subject: [PATCH 08/11] fix incorrect arithmetic, remove type coercion --- crates/polars-plan/src/plans/aexpr/schema.rs | 34 +++++----- .../plans/conversion/type_coercion/binary.rs | 64 ------------------- .../operations/arithmetic/test_arithmetic.py | 12 ++++ 3 files changed, 27 insertions(+), 83 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 65fa58452312..ad1d2535486e 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -319,15 +319,6 @@ fn func_args_to_fields( .collect() } -/// Left is List(_), right is numeric primitive, we want to cast left's leaf -/// dtype. -fn try_get_list_super_type(left: &DataType, right: &DataType) -> PolarsResult { - debug_assert!(left.is_list()); - debug_assert!(right.is_numeric()); - let super_type = try_get_supertype(left.leaf_dtype(), right)?; - Ok(left.cast_leaf(super_type)) -} - fn get_arithmetic_field( left: Node, right: Node, @@ -381,8 +372,11 @@ fn get_arithmetic_field( (_, Time) | (Time, _) => { polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, - (List(_), _) if right_type.is_numeric() => { - try_get_list_super_type(&left_field.dtype, &right_type)? + (list_dtype @ List(_), prim_dtype) if prim_dtype.is_primitive() => { + list_dtype.cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?) + }, + (prim_dtype, list_dtype @ List(_)) if prim_dtype.is_primitive() => { + list_dtype.cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?) }, (left, right) => try_get_supertype(left, right)?, } @@ -409,11 +403,11 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, (Boolean, Boolean) => IDX_DTYPE, - (List(_), _) if right_type.is_numeric() => { - try_get_list_super_type(&left_field.dtype, &right_type)? + (list_dtype @ List(_), prim_dtype) if prim_dtype.is_primitive() => { + list_dtype.cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?) }, - (_, List(_)) if left_field.dtype.is_numeric() => { - try_get_supertype(&left_field.dtype, right_type.leaf_dtype())? + (prim_dtype, list_dtype @ List(_)) if prim_dtype.is_primitive() => { + list_dtype.cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?) }, (left, right) => try_get_supertype(left, right)?, } @@ -447,13 +441,15 @@ fn get_arithmetic_field( polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type) }, }, - (List(_), _) if right_type.is_numeric() => { - let dtype = try_get_list_super_type(&left_field.dtype, &right_type)?; + (list_dtype @ List(_), prim_dtype) if prim_dtype.is_primitive() => { + let dtype = list_dtype + .cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?); left_field.coerce(dtype); return Ok(left_field); }, - (_, List(_)) if left_field.dtype.is_numeric() => { - let dtype = try_get_supertype(&left_field.dtype, right_type.leaf_dtype())?; + (prim_dtype, list_dtype @ List(_)) if prim_dtype.is_primitive() => { + let dtype = list_dtype + .cast_leaf(try_get_supertype(list_dtype.leaf_dtype(), prim_dtype)?); left_field.coerce(dtype); return Ok(left_field); }, diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 20fad09c583c..5c261cf2978d 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -47,65 +47,6 @@ fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { } } -fn process_list_arithmetic( - type_left: DataType, - type_right: DataType, - node_left: Node, - node_right: Node, - op: Operator, - expr_arena: &mut Arena, -) -> PolarsResult> { - match (&type_left, &type_right) { - (DataType::List(_), _) => { - let leaf = type_left.leaf_dtype(); - if type_right != *leaf { - let new_dtype = if type_right.is_nested() { - type_left.cast_leaf(leaf.clone()) - } else { - leaf.clone() - }; - let new_node_right = expr_arena.add(AExpr::Cast { - expr: node_right, - dtype: new_dtype, - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: node_left, - op, - right: new_node_right, - })) - } else { - Ok(None) - } - }, - (_, DataType::List(_)) => { - let leaf = type_right.leaf_dtype(); - if type_left != *leaf { - let new_dtype = if type_left.is_nested() { - type_right.cast_leaf(leaf.clone()) - } else { - leaf.clone() - }; - let new_node_left = expr_arena.add(AExpr::Cast { - expr: node_left, - dtype: new_dtype, - options: CastOptions::NonStrict, - }); - - Ok(Some(AExpr::BinaryExpr { - left: new_node_left, - op, - right: node_right, - })) - } else { - Ok(None) - } - }, - _ => unreachable!(), - } -} - #[cfg(feature = "dtype-struct")] // Ensure we don't cast to supertype // otherwise we will fill a struct with null fields @@ -275,11 +216,6 @@ pub(super) fn process_binary( (String, a) | (a, String) if a.is_numeric() => { polars_bail!(InvalidOperation: "arithmetic on string and numeric not allowed, try an explicit cast first") }, - (List(_), _) | (_, List(_)) => { - return process_list_arithmetic( - type_left, type_right, node_left, node_right, op, expr_arena, - ) - }, (Datetime(_, _), _) | (_, Datetime(_, _)) | (Date, _) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 8ff7821c57e1..a965b1616abf 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -780,6 +780,18 @@ def test_list_and_numeric_arithmetic_same_size( ) +def test_list_add_supertype() -> None: + a = pl.Series("a", [[1], [2]], dtype=pl.List(pl.Int8)) + b = pl.Series("b", [[1], [999]], dtype=pl.List(pl.Int64)) + + df = pl.DataFrame([a, b]) + + assert_series_equal( + df.select(x=pl.col("a") + pl.col("b")).to_series(), + pl.Series("x", [[2], [1001]], dtype=pl.List(pl.Int64)), + ) + + @pytest.mark.parametrize( ("a", "b", "expected"), [ From cbc2de19930267e7fc7675dc90f9e20027c11026 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 2 Oct 2024 16:19:22 +1000 Subject: [PATCH 09/11] make `Op` generic? --- .../src/series/arithmetic/list_borrowed.rs | 72 +++++++++++++------ 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 4a6d717e3ef5..7fc6625842fd 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -64,31 +64,57 @@ enum Op { } impl Op { - /// Apply the operation to a pair of Series. - fn apply_with_series(&self, lhs: &Series, rhs: &Series) -> PolarsResult { - use Op::*; - - match self { - Add => lhs + rhs, - Subtract => lhs - rhs, - Multiply => lhs * rhs, - Divide => lhs / rhs, - Remainder => lhs % rhs, + fn apply(&self, lhs: T, rhs: U) -> >::Output + where + T: Add + Sub + Mul + Div + Rem, + { + { + // This should be all const, optimized away + assert_eq!( + [core::mem::align_of::<>::Output>(); 4], + [ + core::mem::align_of::<>::Output>(), + core::mem::align_of::<>::Output>(), + core::mem::align_of::<>::Output>(), + core::mem::align_of::<>::Output>(), + ] + ); } - } - /// Apply the operation to a Series and scalar. - fn apply_with_scalar(&self, lhs: &Series, rhs: T) -> Series { - use Op::*; + { + // Safety: All operations return the same type + macro_rules! wrap { + ($e:expr) => { + unsafe { core::mem::transmute_copy(&$e) } + }; + } - match self { - Add => lhs + rhs, - Subtract => lhs - rhs, - Multiply => lhs * rhs, - Divide => lhs / rhs, - Remainder => lhs % rhs, + use Op::*; + match self { + Add => lhs + rhs, + Subtract => wrap!(lhs - rhs), + Multiply => wrap!(lhs * rhs), + Divide => wrap!(lhs / rhs), + Remainder => wrap!(lhs % rhs), + } } } + + // Apply the operation to a pair of Series. + // fn apply(&self, lhs: &Series, rhs: T) -> <&Series as Add>::Output + // where + // for<'a> &'a Series: Add, + // { + // use Op::*; + + // match self { + // Add => lhs + rhs, + // Subtract => lhs - rhs, + // Multiply => lhs * rhs, + // Divide => lhs / rhs, + // Remainder => lhs % rhs, + // } + // } } impl ListChunked { @@ -112,7 +138,7 @@ impl ListChunked { return Ok(None); }; let a = a_owner.as_ref().rechunk(); - let leaf_result = op.apply_with_scalar(&a.get_leaf_array(), b); + let leaf_result = op.apply(&a.get_leaf_array(), b); let result = reshape_list_based_on(&leaf_result.chunks()[0], &a.chunks()[0]); Ok(Some(result)) @@ -207,7 +233,7 @@ impl ListChunked { // along. a_listchunked.arithm_helper(b, op, Some(true)) } else { - op.apply_with_series(a, b) + op.apply(a, b) }; chunk_result.map(Some) }).collect::>>>()?; @@ -228,7 +254,7 @@ impl ListChunked { InvalidOperation: "can only do arithmetic operations on lists of the same size" ); - let result = op.apply_with_series(&l_leaf_array, &r_leaf_array)?; + let result = op.apply(&l_leaf_array, &r_leaf_array)?; // We now need to wrap the Arrow arrays with the metadata that turns // them into lists: // TODO is there a way to do this without cloning the underlying data? From 3d146086abc6d7de337be1df666a0c1648df8c66 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 2 Oct 2024 16:19:51 +1000 Subject: [PATCH 10/11] remove unused --- .../src/series/arithmetic/list_borrowed.rs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index 7fc6625842fd..f6022a3db4cf 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -99,22 +99,6 @@ impl Op { } } } - - // Apply the operation to a pair of Series. - // fn apply(&self, lhs: &Series, rhs: T) -> <&Series as Add>::Output - // where - // for<'a> &'a Series: Add, - // { - // use Op::*; - - // match self { - // Add => lhs + rhs, - // Subtract => lhs - rhs, - // Multiply => lhs * rhs, - // Divide => lhs / rhs, - // Remainder => lhs % rhs, - // } - // } } impl ListChunked { From e916b56ce91fb8d1205c92cdb73fe4970aa7139e Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 2 Oct 2024 16:35:40 +1000 Subject: [PATCH 11/11] transmute_copy soundness --- crates/polars-core/src/series/arithmetic/list_borrowed.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs index f6022a3db4cf..76ee0d98f5b8 100644 --- a/crates/polars-core/src/series/arithmetic/list_borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -85,7 +85,9 @@ impl Op { // Safety: All operations return the same type macro_rules! wrap { ($e:expr) => { - unsafe { core::mem::transmute_copy(&$e) } + // Safety: This performs a `Copy`, but `$e` could be a `Series`, + // so we need to wrap in `ManuallyDrop` to avoid double-free. + unsafe { core::mem::transmute_copy(&core::mem::ManuallyDrop::new($e)) } }; }