Skip to content

Commit

Permalink
perf: Improve Bitmap construction performance (#15570)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 10, 2024
1 parent c2bcd12 commit a919601
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 68 deletions.
65 changes: 27 additions & 38 deletions crates/polars-arrow/src/bitmap/bitmap_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,33 @@ use super::Bitmap;
use crate::bitmap::MutableBitmap;
use crate::trusted_len::TrustedLen;

/// Creates a [Vec<u8>] from an [`Iterator`] of [`BitChunk`].
/// # Safety
/// The iterator must be [`TrustedLen`].
pub unsafe fn from_chunk_iter_unchecked<T: BitChunk, I: Iterator<Item = T>>(
iterator: I,
) -> Vec<u8> {
let (_, upper) = iterator.size_hint();
let upper = upper.expect("try_from_trusted_len_iter requires an upper limit");
let len = upper * std::mem::size_of::<T>();

let mut buffer = Vec::with_capacity(len);

let mut dst = buffer.as_mut_ptr();
for item in iterator {
let bytes = item.to_ne_bytes();
for i in 0..std::mem::size_of::<T>() {
std::ptr::write(dst, bytes[i]);
dst = dst.add(1);
}
}
assert_eq!(
dst.offset_from(buffer.as_ptr()) as usize,
len,
"Trusted iterator length was not accurately reported"
);
buffer.set_len(len);
buffer
#[inline(always)]
pub(crate) fn push_bitchunk<T: BitChunk>(buffer: &mut Vec<u8>, value: T) {
buffer.extend(value.to_ne_bytes())
}

/// Creates a [`Vec<u8>`] from a [`TrustedLen`] of [`BitChunk`].
pub fn chunk_iter_to_vec<T: BitChunk, I: TrustedLen<Item = T>>(iter: I) -> Vec<u8> {
unsafe { from_chunk_iter_unchecked(iter) }
let cap = iter.size_hint().0 * std::mem::size_of::<T>();
let mut buffer = Vec::with_capacity(cap);
for v in iter {
push_bitchunk(&mut buffer, v)
}
buffer
}

fn chunk_iter_to_vec_and_remainder<T: BitChunk, I: TrustedLen<Item = T>>(
iter: I,
remainder: T,
) -> Vec<u8> {
let cap = (iter.size_hint().0 + 1) * std::mem::size_of::<T>();
let mut buffer = Vec::with_capacity(cap);
for v in iter {
push_bitchunk(&mut buffer, v)
}
push_bitchunk(&mut buffer, remainder);
debug_assert_eq!(buffer.len(), cap);
buffer
}

/// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`].
Expand All @@ -62,9 +57,8 @@ where
.zip(a3_chunks)
.zip(a4_chunks)
.map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4));
let buffer =
chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3, rem_a4))));

let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3, rem_a4));
let length = a1.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -90,8 +84,7 @@ where
.zip(a3_chunks)
.map(|((a1, a2), a3)| op(a1, a2, a3));

let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3))));

let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3));
let length = a1.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -112,8 +105,7 @@ where
.zip(rhs_chunks)
.map(|(left, right)| op(left, right));

let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_lhs, rem_rhs))));

let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_lhs, rem_rhs));
let length = lhs.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -125,10 +117,7 @@ where
F: Fn(u64) -> u64,
{
let rem = op(iter.remainder());

let iterator = iter.map(op).chain(std::iter::once(rem));

let buffer = chunk_iter_to_vec(iterator);
let buffer = chunk_iter_to_vec_and_remainder(iter.map(op), rem);

Bitmap::from_u8_vec(buffer, length)
}
Expand Down
60 changes: 59 additions & 1 deletion crates/polars-arrow/src/compute/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::borrow::Borrow;
use std::ops::{BitAnd, BitOr};

use polars_error::{polars_ensure, PolarsResult};

use crate::array::Array;
use crate::bitmap::{and_not, ternary, Bitmap};
use crate::bitmap::{and_not, push_bitchunk, ternary, Bitmap};

pub fn combine_validities_and3(
opt1: Option<&Bitmap>,
Expand Down Expand Up @@ -49,6 +50,63 @@ pub fn combine_validities_and_not(
}
}

pub fn combine_validities_and_many<B: Borrow<Bitmap>>(bitmaps: &[Option<B>]) -> Option<Bitmap> {
let mut bitmaps = bitmaps
.iter()
.flatten()
.map(|b| b.borrow())
.collect::<Vec<_>>();

match bitmaps.len() {
0 => None,
1 => bitmaps.pop().cloned(),
2 => combine_validities_and(bitmaps.pop(), bitmaps.pop()),
3 => combine_validities_and3(bitmaps.pop(), bitmaps.pop(), bitmaps.pop()),
_ => {
let mut iterators = bitmaps
.iter()
.map(|v| v.fast_iter_u64())
.collect::<Vec<_>>();
let mut buffer = Vec::with_capacity(iterators.first().unwrap().size_hint().0 + 2);

'rows: loop {
// All ones so as identity for & operation
let mut out = u64::MAX;
for iter in iterators.iter_mut() {
if let Some(v) = iter.next() {
out &= v
} else {
break 'rows;
}
}
push_bitchunk(&mut buffer, out);
}

// All ones so as identity for & operation
let mut out = [u64::MAX, u64::MAX];
let mut len = 0;
for iter in iterators.into_iter() {
let (rem, rem_len) = iter.remainder();
len = rem_len;

for (out, rem) in out.iter_mut().zip(rem) {
*out &= rem;
}
}
push_bitchunk(&mut buffer, out[0]);
if len > 64 {
push_bitchunk(&mut buffer, out[1]);
}
let bitmap = Bitmap::from_u8_vec(buffer, bitmaps[0].len());
if bitmap.unset_bits() == bitmap.len() {
None
} else {
Some(bitmap)
}
},
}
}

// Errors iff the two arrays have a different length.
#[inline]
pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> PolarsResult<()> {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-arrow/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ pub trait NativeType:
+ std::ops::IndexMut<usize, Output = u8>
+ for<'a> TryFrom<&'a [u8]>
+ std::fmt::Debug
+ Default;
+ Default
+ IntoIterator<Item = u8>;

/// To bytes in little endian
fn to_le_bytes(&self) -> Self::Bytes;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and_many;
use compare_inner::NullOrderCmp;
use polars_row::{convert_columns, EncodingField, RowsEncoded};
use polars_utils::iter::EnumerateIdxTrait;
Expand Down Expand Up @@ -121,7 +121,7 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls(
.collect::<Vec<_>>();
let rows = _get_rows_encoded_unordered(&sliced)?;

let validity = sliced
let validities = sliced
.iter()
.flat_map(|s| {
let s = s.rechunk();
Expand All @@ -131,7 +131,9 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls(
.into_iter()
.map(|arr| arr.validity().cloned())
})
.fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref()));
.collect::<Vec<_>>();

let validity = combine_validities_and_many(&validities);
Ok(rows.into_array().with_validity_typed(validity))
});
let chunks = POOL.install(|| chunks.collect::<PolarsResult<Vec<_>>>());
Expand Down
10 changes: 2 additions & 8 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ use polars_utils::hashing::BytesHash;
use rayon::prelude::*;

use super::IntoDf;
const LHS_NAME: &str = "POLARS_K_L";
const RHS_NAME: &str = "POLARS_K_R";

pub trait DataFrameJoinOps: IntoDf {
/// Generic join method. Can be used to join on multiple columns.
Expand Down Expand Up @@ -260,12 +258,8 @@ pub trait DataFrameJoinOps: IntoDf {
};
}

let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?
.into_series()
.with_name(LHS_NAME);
let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?
.into_series()
.with_name(RHS_NAME);
let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?.into_series();
let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?.into_series();
let names_right = selected_right.iter().map(|s| s.name()).collect::<Vec<_>>();

// Multiple keys.
Expand Down
17 changes: 4 additions & 13 deletions crates/polars-ops/src/series/ops/fused.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::array::PrimitiveArray;
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and3;
use polars_core::prelude::*;
use polars_core::utils::align_chunks_ternary;
use polars_core::with_match_physical_numeric_polars_type;
Expand All @@ -11,10 +11,7 @@ fn fma_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down Expand Up @@ -65,10 +62,7 @@ fn fsm_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down Expand Up @@ -118,10 +112,7 @@ fn fms_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-pipe/src/executors/sinks/joins/row_values.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, BinaryArray, StaticArray};
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and_many;
use polars_core::error::PolarsResult;
use polars_row::RowsEncoded;

Expand Down Expand Up @@ -80,11 +80,12 @@ impl RowValues {
Ok(if join_nulls {
array
} else {
let validity = self
let validities = self
.join_columns_material
.iter()
.map(|arr| arr.validity().cloned())
.fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref()));
.map(|arr| arr.validity())
.collect::<Vec<_>>();
let validity = combine_validities_and_many(&validities);
array.with_validity_typed(validity)
})
}
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,3 +835,25 @@ def test_join_list_non_numeric() -> None:
"lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],
"count": [1, 2, 1],
}


@pytest.mark.slow()
def test_join_4_columns_with_validity() -> None:
# join on 4 columns so we trigger combine validities
# use 138 as that is 2 u64 and a remainder
a = pl.DataFrame(
{"a": [None if a % 6 == 0 else a for a in range(138)]}
).with_columns(
b=pl.col("a"),
c=pl.col("a"),
d=pl.col("a"),
)

assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=True).shape == (
644,
4,
)
assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=False).shape == (
115,
4,
)

0 comments on commit a919601

Please sign in to comment.