Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cc #18

Closed
wants to merge 1 commit into from
Closed

Cc #18

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if self.is_sorted_any() {
else if self.null_count() == 0 {
Some(0)
} else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.null_count()
Expand All @@ -256,7 +258,9 @@ impl<T: PolarsDataType> ChunkedArray<T> {
None
}
// We now know there is at least 1 non-null item in the array, and self.len() > 0
else if self.is_sorted_any() {
else if self.null_count() == 0 {
Some(self.len() - 1)
} else if self.is_sorted_any() {
let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } {
// nulls are all at the start
self.len() - 1
Expand Down
64 changes: 48 additions & 16 deletions crates/polars-core/src/chunked_array/ops/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,17 @@ where
}
},
(true, true) => {
// both arrays have non-null values
if !ca.is_sorted_any()
|| !other.is_sorted_any()
|| ca.is_sorted_flag() != other.is_sorted_flag()
// both arrays have non-null values.
// for arrays of unit length we can ignore the sorted flag, as it is
// not necessarily set.
if !(ca.is_sorted_any() || ca.len() == 1)
|| !(other.is_sorted_any() || other.len() == 1)
|| !(
// We will coerce for single values
ca.len() - ca.null_count() == 1
|| other.len() - other.null_count() == 1
|| ca.is_sorted_flag() == other.is_sorted_flag()
)
{
IsSorted::Not
} else {
Expand All @@ -68,7 +75,7 @@ where
let l_val = unsafe { ca.value_unchecked(l_idx) };
let r_val = unsafe { other.value_unchecked(r_idx) };

let keep_sorted =
let null_pos_check =
// check null positions
// lhs does not end in nulls
(1 + l_idx == ca.len())
Expand All @@ -77,18 +84,43 @@ where
// if there are nulls, they are all on one end
&& !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len());

let keep_sorted = keep_sorted
// compare values
&& if ca.is_sorted_ascending_flag() {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

if keep_sorted {
ca.is_sorted_flag()
} else {
if !null_pos_check {
IsSorted::Not
} else {
#[allow(unused_assignments)]
let mut out = IsSorted::Not;

#[allow(clippy::never_loop)]
loop {
match (
ca.len() - ca.null_count() == 1,
other.len() - other.null_count() == 1,
) {
(true, true) => {
out = [IsSorted::Descending, IsSorted::Ascending]
[l_val.tot_le(&r_val) as usize];
break;
},
(true, false) => out = other.is_sorted_flag(),
_ => out = ca.is_sorted_flag(),
}

debug_assert!(!matches!(out, IsSorted::Not));

let check = if matches!(out, IsSorted::Ascending) {
l_val.tot_le(&r_val)
} else {
l_val.tot_ge(&r_val)
};

if !check {
out = IsSorted::Not
}

break;
}

out
}
}
},
Expand Down
91 changes: 85 additions & 6 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from polars.testing import assert_frame_equal, assert_series_equal


def is_sorted_any(s: pl.Series) -> bool:
return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]


def is_not_sorted(s: pl.Series) -> bool:
return not is_sorted_any(s)


def test_sort_dates_multiples() -> None:
df = pl.DataFrame(
[
Expand Down Expand Up @@ -799,12 +807,6 @@ def test_sorted_flag_14552() -> None:


def test_sorted_flag_concat_15072() -> None:
def is_sorted_any(s: pl.Series) -> bool:
return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"]

def is_not_sorted(s: pl.Series) -> bool:
return not is_sorted_any(s)

# Both all-null
a = pl.Series("x", [None, None], dtype=pl.Int8)
b = pl.Series("x", [None, None], dtype=pl.Int8)
Expand Down Expand Up @@ -903,3 +905,80 @@ def is_not_sorted(s: pl.Series) -> bool:
out = pl.concat((s, s.clear()))
assert_series_equal(out, s)
assert out.flags["SORTED_ASC"]


@pytest.mark.parametrize("unit_descending", [True, False])
def test_sorted_flag_concat_unit(unit_descending: bool) -> None:
unit = pl.Series([1]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [1, 2, 3]
assert out.flags["SORTED_ASC"]

out = pl.concat((b, a))
assert out.to_list() == [2, 3, 1]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [1, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, 1]
assert out.flags["SORTED_DESC"]

# unit with nulls first
unit = pl.Series([None, 1]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [None, 1, 2, 3]
assert out.flags["SORTED_ASC"]

out = pl.concat((b, a))
assert out.to_list() == [2, 3, None, 1]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [None, 1, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, None, 1]
assert is_not_sorted(out)

# unit with nulls last
unit = pl.Series([1, None]).set_sorted(descending=unit_descending)

a = unit
b = pl.Series([2, 3]).set_sorted()

out = pl.concat((a, b))
assert out.to_list() == [1, None, 2, 3]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [2, 3, 1, None]
assert is_not_sorted(out)

a = unit
b = pl.Series([3, 2]).set_sorted(descending=True)

out = pl.concat((a, b))
assert out.to_list() == [1, None, 3, 2]
assert is_not_sorted(out)

out = pl.concat((b, a))
assert out.to_list() == [3, 2, 1, None]
assert out.flags["SORTED_DESC"]
Loading