fix(rust, python): fix groupby rolling by_key if groups are empty (#6333)
ritchie46 authored Jan 20, 2023
1 parent 843c603 commit a7fffe3
Showing 7 changed files with 73 additions and 8 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
@@ -34,11 +34,11 @@ once_cell = "1"

[workspace.dependencies.arrow]
package = "arrow2"
git = "https://github.com/jorgecarleitao/arrow2"
# git = "https://github.com/ritchie46/arrow2"
rev = "218b7cf93a1fe713bd4fd4641cfc53a3c10977aa"
# git = "https://github.com/jorgecarleitao/arrow2"
git = "https://github.com/ritchie46/arrow2"
# rev = "218b7cf93a1fe713bd4fd4641cfc53a3c10977aa"
# path = "../arrow2"
# branch = "polars_2022-12-30"
branch = "mmap_slice"
version = "0.15"
default-features = false
features = [
17 changes: 17 additions & 0 deletions polars/polars-core/src/chunked_array/from.rs
@@ -115,4 +115,21 @@ where
out.compute_len();
out
}

/// Create a temporary [`ChunkedArray`] from a slice.
///
/// # Safety
/// The lifetime will be bound to the lifetime of the slice.
/// This will not be checked by the borrow checker.
pub unsafe fn borrowed_from_slice(name: &str, values: &[T::Native]) -> Self {
let arr = Box::new(PrimitiveArray::borrowed_from_slice(values));
let mut out = ChunkedArray {
field: Arc::new(Field::new(name, T::get_dtype())),
chunks: vec![arr],
phantom: PhantomData,
..Default::default()
};
out.compute_len();
out
}
}
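A hedged usage sketch of the new constructor (not part of this commit): per the safety note above, the caller rather than the compiler must guarantee that the slice outlives the temporary array, so the borrowed view is confined to a single scope here. This assumes the generic impl covers `UInt32Chunked` and that `ChunkAgg::sum` is in scope via the prelude.

```rust
use polars_core::prelude::*;

// Sum a borrowed slice through a temporary, zero-copy ChunkedArray view.
fn sum_view(values: &[u32]) -> Option<u32> {
    // SAFETY: `values` outlives `ca`, and `ca` never escapes this function,
    // so the unchecked borrow cannot dangle.
    let ca: UInt32Chunked = unsafe { UInt32Chunked::borrowed_from_slice("tmp", values) };
    ca.sum()
}
```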
9 changes: 9 additions & 0 deletions polars/polars-core/src/series/mod.rs
@@ -429,6 +429,15 @@ impl Series {
Ok(self.finish_take_threaded(series?, rechunk))
}

/// Take by index if ChunkedArray contains a single chunk.
///
/// # Safety
/// This doesn't check any bounds. Null validity is checked.
pub unsafe fn take_unchecked_from_slice(&self, idx: &[IdxSize]) -> PolarsResult<Series> {
let idx = IdxCa::borrowed_from_slice("", idx);
self.take_unchecked(&idx)
}

/// Take by index if ChunkedArray contains a single chunk.
///
/// # Safety
2 changes: 1 addition & 1 deletion polars/polars-io/src/ipc/ipc_file.rs
@@ -314,7 +314,7 @@ where
fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
let mut ipc_writer = write::FileWriter::try_new(
&mut self.writer,
&df.schema().to_arrow(),
df.schema().to_arrow(),
None,
WriteOptions {
compression: self.compression.map(|c| c.into()),
19 changes: 17 additions & 2 deletions polars/polars-lazy/src/physical_plan/executors/groupby_rolling.rs
@@ -46,8 +46,23 @@ impl GroupByRollingExec {
// the ordering has changed due to the groupby
if !keys.is_empty() {
unsafe {
for key in keys.iter_mut() {
*key = key.agg_first(groups);
match groups {
GroupsProxy::Idx(groups) => {
let first = groups.first();
// we don't use agg_first here, because the group
// can be empty, but we still want to know the first value
// of that group
for key in keys.iter_mut() {
*key = key.take_unchecked_from_slice(first).unwrap();
}
}
GroupsProxy::Slice { groups, .. } => {
for key in keys.iter_mut() {
let iter = &mut groups.iter().map(|[first, _len]| *first as usize)
as &mut dyn TakeIterator<Item = usize>;
*key = key.take_iter_unchecked(iter);
}
}
}
}
};
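A minimal standalone sketch of the idea behind the Idx arm above (the function name and free-standing shape are assumptions, not code from this commit): `agg_first` has nothing to return for an empty group, whereas gathering the key column by each group's recorded first index yields exactly one key row per rolling window, empty or not, keeping the keys aligned with the per-window aggregation results.

```rust
use polars_core::prelude::*;

// Gather one key value per window by taking the row at each group's
// recorded first index; this also works when a window matched zero rows.
fn keys_per_window(key: &Series, first_indices: &[IdxSize]) -> PolarsResult<Series> {
    // SAFETY: the indices come from the groupby that produced the windows,
    // so they are in bounds for `key`, and the slice outlives the temporary
    // index view created internally.
    unsafe { key.take_unchecked_from_slice(first_indices) }
}
```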
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock


24 changes: 24 additions & 0 deletions py-polars/tests/unit/test_datelike.py
@@ -2542,3 +2542,27 @@ def test_datetime_cum_agg_schema() -> None:
datetime(2023, 1, 4, 0, 0),
],
}


def test_rolling_groupby_empty_groups_by_take_6330() -> None:
df = pl.DataFrame({"Event": ["Rain", "Sun"]}).join(
pl.DataFrame(
{
"Date": [1, 2, 3, 4],
}
),
how="cross",
)
assert (
df.groupby_rolling(
index_column="Date",
period="2i",
offset="-2i",
by="Event",
closed="left",
).agg([pl.count()])
).to_dict(False) == {
"Event": ["Rain", "Rain", "Rain", "Rain", "Sun", "Sun", "Sun", "Sun"],
"Date": [1, 2, 3, 4, 1, 2, 3, 4],
"count": [0, 1, 2, 2, 0, 1, 2, 2],
}
