-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Improved list arithmetic support #19162
Conversation
let mut new_left_dtype = left_dtype.cast_leaf(leaf_super_dtype.clone()); | ||
let mut new_right_dtype = right_dtype.cast_leaf(leaf_super_dtype); | ||
|
||
// Cast List<->Array to List |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We currently already cast list<->array to list as a supertype from #12016
@@ -327,6 +328,26 @@ pub trait SeriesTrait: | |||
/// Aggregate all chunks to a contiguous array of memory. | |||
fn rechunk(&self) -> Series; | |||
|
|||
fn rechunk_validity(&self) -> Option<Bitmap> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied from impl<T> ChunkedArray<T>
so that it can be used without downcasting
@@ -47,55 +47,6 @@ fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { | |||
} | |||
} | |||
|
|||
fn process_list_arithmetic( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove IR checking / modifying logic from type_coercion - I think this should already be handled by get_arithmetic_field
, get_truediv_field
etc.
@@ -1085,7 +1087,8 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame: | |||
int_casts = [ | |||
col(column).cast(tp) | |||
for i, (column, tp) in enumerate(self.schema.items()) | |||
if tp.is_integer() and orig_dtypes[i].is_integer() | |||
if tp.is_integer() | |||
and (orig_dtypes[i].is_integer() or orig_dtypes[i] == Null) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a fix after something I did broke this test case -
polars/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Lines 287 to 288 in 2e0438b
assert_frame_equal(df_expected, op(df, None)) | |
assert_series_equal(s_expected, op(s, None)) |
as_float = self._recursive_cast_to_dtype(Float64()) | ||
|
||
return as_float._arithmetic(other, "div", "div_<>") | ||
return self._arithmetic(other, "div", "div_<>") |
This comment was marked as outdated.
This comment was marked as outdated.
Sorry, something went wrong.
), | ||
], | ||
) | ||
def test_list_arithmetic_same_size( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will check these tests back in later to reduce the size of this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just add comment saying "no need to review these functions, they're cut/pasted from test_arithmetic.py"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the tip - I realized I can actually leave those tests here for now instead of removing them
72e5b6e
to
12494af
Compare
BROADCAST_SERIES_COMBINATIONS, | ||
) | ||
@pytest.mark.parametrize("exec_op", EXEC_OP_COMBINATIONS) | ||
def test_list_arithmetic_values( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Magical test case parametrization that runs every codepath in BinaryListNumericOpHelper
😄
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #19162 +/- ##
==========================================
+ Coverage 79.67% 79.72% +0.04%
==========================================
Files 1532 1533 +1
Lines 209200 209915 +715
Branches 2417 2415 -2
==========================================
+ Hits 166687 167357 +670
- Misses 41965 42010 +45
Partials 548 548 ☔ View full report in Codecov by Sentry. |
3bf866d
to
09b0d7f
Compare
for o in others { | ||
let slc = o.as_slice(); | ||
l = slc[l].to_usize(); | ||
r = slc[r].to_usize(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not 100% sure, but I feel like this should be r + 1
. Might be completely wrong though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be correct as it is - Arrow list offsets are defined as
1st row : offsets[0]..offsets[1]
2nd row : offsets[1]..offsets[2]
..and so on
for o in &offsets[1..] { | ||
let slc = o.as_slice(); | ||
l = slc[l].to_usize(); | ||
r = slc[r].to_usize(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idem.
| (_, Time) | ||
| (_, Date) | ||
| (_, Datetime(_, _)) => polars_bail!(opq = div, self.dtype(), rhs.dtype()), | ||
_ => match (self.dtype(), rhs.dtype()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why is this a nested match?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to remove
// list lengths. | ||
let mut mismatch_pos = 0; | ||
|
||
with_match_numeric_list_op!(&self.op, self.swapped, |$OP| { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kinda feel like this should not be here. Instead, this should be in polars-compute
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to move the entire file, but I ran into some difficulty as polars-compute
doesn't have access to the Series
struct - I think we can leave it here for now?
I also don't want to split out parts of the logic in this file as it isn't really used anywhere else
|
||
/// Reduce monomorphization | ||
#[inline(never)] | ||
fn combine_validities_list_to_list_no_broadcast( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I want to keep this here for now - it's a very specialized function that's only used in this file. If we need to use it somewhere else later I can move it then
@@ -56,6 +57,25 @@ impl Series { | |||
} | |||
} | |||
|
|||
/// TODO: Move this somewhere else? | |||
pub fn list_offsets_and_validities_recursive( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This always outputs Vecs with 1 element.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not necessarily - it returns Vecs with a number of elements corresponding to the nesting level - so e.g. for a List(List(Float64))
-
print((pl.Series([[[1]]]) / pl.Series([1])).dtype)
List(List(Float64))
[crates/polars-core/src/series/arithmetic/list_borrowed.rs:93:17] lhs.list_offsets_and_validities_recursive().0 = [
OffsetsBuffer(
[
0,
1,
],
),
OffsetsBuffer(
[
0,
1,
],
),
]
09b0d7f
to
2ea1524
Compare
This comment has been minimized.
This comment has been minimized.
This comment was marked as outdated.
This comment was marked as outdated.
1662a10
to
8e7554b
Compare
Thanks for doing this! |
Introduces numeric list kernels that operate directly on the list offsets and leaf arrays.
Notes
Fixes #19010
Fixes #19025