Skip to content

Commit

Permalink
perf: Coerce sorted flag during concat on unit arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Mar 17, 2024
1 parent be92470 commit 846f60b
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 24 deletions.
8 changes: 6 additions & 2 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if self.is_sorted_any() {
else if self.null_count() == 0 {
Some(0)
} else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.null_count()
Expand All @@ -256,7 +258,9 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if self.is_sorted_any() {
else if self.null_count() == 0 {
Some(self.len() - 1)
} else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.len() - 1
Expand Down
64 changes: 48 additions & 16 deletions crates/polars-core/src/chunked_array/ops/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,17 @@ where
}
},
(true, true) => {
// both arrays have non-null values
if !ca.is_sorted_any()
|| !other.is_sorted_any()
|| ca.is_sorted_flag() != other.is_sorted_flag()
// both arrays have non-null values.
// for arrays of unit length we can ignore the sorted flag, as it is
// not necessarily set.
if !(ca.is_sorted_any() || ca.len() == 1)
|| !(other.is_sorted_any() || other.len() == 1)
|| !(
// We will coerce for single values
ca.len() - ca.null_count() == 1
|| other.len() - other.null_count() == 1
|| ca.is_sorted_flag() == other.is_sorted_flag()
)
{
IsSorted::Not
} else {
Expand All @@ -68,7 +75,7 @@ where
let l_val = unsafe { ca.value_unchecked(l_idx) };
let r_val = unsafe { other.value_unchecked(r_idx) };

let keep_sorted =
let null_pos_check =
// check null positions
// lhs does not end in nulls
(1 + l_idx == ca.len())
Expand All @@ -77,18 +84,43 @@ where
// if there are nulls, they are all on one end
&& !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len());

let keep_sorted = keep_sorted
// compare values
&& if ca.is_sorted_ascending_flag() {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

if keep_sorted {
ca.is_sorted_flag()
} else {
if !null_pos_check {
IsSorted::Not
} else {
#[allow(unused_assignments)]
let mut out = IsSorted::Not;

#[allow(clippy::never_loop)]
loop {
match (
ca.len() - ca.null_count() == 1,
other.len() - other.null_count() == 1,
) {
(true, true) => {
out = [IsSorted::Descending, IsSorted::Ascending]
[l_val.tot_le(&r_val) as usize];
break;
},
(true, false) => out = other.is_sorted_flag(),
_ => out = ca.is_sorted_flag(),
}

debug_assert!(!matches!(out, IsSorted::Not));

let check = if matches!(out, IsSorted::Ascending) {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

if !check {
out = IsSorted::Not
}

break;
}

out
}
}
},
Expand Down
91 changes: 85 additions & 6 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from polars.testing import assert_frame_equal, assert_series_equal


def is_sorted_any(s: pl.Series) -> bool:
return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]


def is_not_sorted(s: pl.Series) -> bool:
return not is_sorted_any(s)


def test_sort_dates_multiples() -> None:
df = pl.DataFrame(
[
Expand Down Expand Up @@ -799,12 +807,6 @@ def test_sorted_flag_14552() -> None:


def test_sorted_flag_concat_15072() -> None:
def is_sorted_any(s: pl.Series) -> bool:
return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]

def is_not_sorted(s: pl.Series) -> bool:
return not is_sorted_any(s)

# Both all-null
a = pl.Series("x", [None, None], dtype=pl.Int8)
b = pl.Series("x", [None, None], dtype=pl.Int8)
Expand Down Expand Up @@ -903,3 +905,80 @@ def is_not_sorted(s: pl.Series) -> bool:
out = pl.concat((s, s.clear()))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]


@pytest.mark.parametrize("unit_descending", [True, False])
def test_sorted_flag_concat_unit(unit_descending: bool) -> None:
unit = pl.Series([1]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [1, 2, 3]
assert out.flags["SORTED_ASC"]

out = pl.concat((b, a))
assert out.to_list() == [2, 3, 1]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [1, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, 1]
assert out.flags["SORTED_DESC"]

# unit with nulls first
unit = pl.Series([None, 1]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [None, 1, 2, 3]
assert out.flags["SORTED_ASC"]

out = pl.concat((b, a))
assert out.to_list() == [2, 3, None, 1]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [None, 1, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, None, 1]
assert is_not_sorted(out)

# unit with nulls last
unit = pl.Series([1, None]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [1, None, 2, 3]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [2, 3, 1, None]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [1, None, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, 1, None]
assert out.flags["SORTED_DESC"]

0 comments on commit 846f60b

Please sign in to comment.