Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement single inequality joins for join_where #18727

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 242 additions & 11 deletions crates/polars-ops/src/frame/join/iejoin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
mod filtered_bit_array;
mod l1_l2;

use std::cmp::min;

use filtered_bit_array::FilteredBitArray;
use l1_l2::*;
use polars_core::chunked_array::ChunkedArray;
use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType};
use polars_core::frame::DataFrame;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::{_set_partition_size, split};
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use polars_error::{polars_err, PolarsResult};
use polars_utils::binary_search::ExponentialSearch;
use polars_utils::itertools::Itertools;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::total_ord::TotalEq;
use polars_utils::total_ord::{TotalEq, TotalOrd};
use polars_utils::IdxSize;
use rayon::prelude::*;
#[cfg(feature = "serde")]
Expand All @@ -40,7 +43,7 @@ impl InequalityOperator {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IEJoinOptions {
pub operator1: InequalityOperator,
pub operator2: InequalityOperator,
pub operator2: Option<InequalityOperator>,
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -61,10 +64,7 @@ fn ie_join_impl_t<T: PolarsNumericType>(
let mut left_row_idx: Vec<IdxSize> = vec![];
let mut right_row_idx: Vec<IdxSize> = vec![];

let slice_end = match slice {
Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)),
_ => None,
};
let slice_end = slice_end_index(slice);
let mut match_count = 0;

let ca: &ChunkedArray<T> = x.as_ref().as_ref();
Expand Down Expand Up @@ -130,6 +130,78 @@ fn ie_join_impl_t<T: PolarsNumericType>(
Ok((left_row_idx, right_row_idx))
}

/// Computes the matching `(left, right)` row-index pairs for a join on a single
/// inequality by merging two pre-ordered inputs.
///
/// # Arguments
/// * `slice` - optional `(offset, len)` of the final result. It is only used to
///   stop emitting pairs early (via `slice_end_index`); the returned vectors are
///   NOT sliced here.
/// * `left_order` / `right_order` - when `Some`, maps a position in the ordered
///   series back to the row index that is emitted; when `None`, the position
///   itself is emitted as the row index.
/// * `left_ordered` / `right_ordered` - the values to compare, downcast to
///   `ChunkedArray<T>`. Values are read with `value_unchecked`, and non-null-ness
///   is only checked through `debug_assert!`, so callers must pass null-free data.
/// * `pred` - the inequality predicate, called as `pred(left_value, right_value)`.
///
/// # Ordering contract
/// The right-hand cursor never moves backwards and a match emits every remaining
/// right row, so the result is only correct if the inputs are ordered such that
/// (a) once `pred` holds for a right row it holds for all following right rows,
/// and (b) a right row rejected for one left row is rejected for all following
/// left rows. NOTE(review): this ordering is established by the caller, not
/// verified here.
///
/// # Returns
/// Two equally long vectors of left and right row indexes. The visible code path
/// always returns `Ok`.
fn piecewise_merge_join_impl_t<T, P>(
    slice: Option<(i64, usize)>,
    left_order: Option<&[IdxSize]>,
    right_order: Option<&[IdxSize]>,
    left_ordered: Series,
    right_ordered: Series,
    mut pred: P,
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
    T: PolarsNumericType,
    P: FnMut(&T::Native, &T::Native) -> bool,
{
    // `slice_end_index` only yields non-negative bounds. Convert with saturation
    // so a bound that does not fit in `usize` (32-bit targets) means "no
    // effective limit" instead of being silently truncated by an `as` cast.
    let slice_end = slice_end_index(slice).map(|end| usize::try_from(end).unwrap_or(usize::MAX));

    let mut left_row_idx: Vec<IdxSize> = vec![];
    let mut right_row_idx: Vec<IdxSize> = vec![];

    let left_ca: &ChunkedArray<T> = left_ordered.as_ref().as_ref();
    let right_ca: &ChunkedArray<T> = right_ordered.as_ref().as_ref();
    let left_len = left_ca.len();
    let right_len = right_ca.len();

    debug_assert!(left_order.is_none_or(|order| order.len() == left_len));
    debug_assert!(right_order.is_none_or(|order| order.len() == right_len));

    let mut left_idx = 0;
    let mut right_idx = 0;
    let mut match_count: usize = 0;

    while left_idx < left_len {
        debug_assert!(left_ca.get(left_idx).is_some());
        // SAFETY: `left_idx < left_len` is guaranteed by the loop condition.
        let left_val = unsafe { left_ca.value_unchecked(left_idx) };
        while right_idx < right_len {
            debug_assert!(right_ca.get(right_idx).is_some());
            // SAFETY: `right_idx < right_len` is guaranteed by the loop condition.
            let right_val = unsafe { right_ca.value_unchecked(right_idx) };
            if !pred(&left_val, &right_val) {
                right_idx += 1;
                continue;
            }

            // If the predicate is true, then it will also be true for all
            // remaining rows from the right side.
            let left_row = match left_order {
                None => left_idx as IdxSize,
                Some(order) => order[left_idx],
            };
            // Never emit more pairs than the slice end requires. The result is
            // always >= `right_idx`, so the subtraction below cannot underflow.
            let right_end_idx = match slice_end {
                None => right_len,
                Some(end) => min(
                    right_len,
                    end.saturating_sub(match_count).saturating_add(right_idx),
                ),
            };
            let n_matches = right_end_idx - right_idx;

            // Emit the whole run in bulk instead of pushing pair by pair.
            left_row_idx.resize(left_row_idx.len() + n_matches, left_row);
            match right_order {
                None => {
                    right_row_idx.extend((right_idx..right_end_idx).map(|idx| idx as IdxSize))
                },
                Some(order) => right_row_idx.extend_from_slice(&order[right_idx..right_end_idx]),
            }
            match_count += n_matches;
            break;
        }
        if right_idx == right_len {
            // We've reached the end of the right side
            // so there can be no more matches for LHS rows
            break;
        }
        if slice_end.is_some_and(|end| match_count >= end) {
            // Enough pairs were produced to cover the requested slice.
            break;
        }
        left_idx += 1;
    }

    Ok((left_row_idx, right_row_idx))
}

pub(super) fn iejoin_par(
left: &DataFrame,
right: &DataFrame,
Expand Down Expand Up @@ -206,7 +278,7 @@ pub(super) fn iejoin_par(
};

if include_block {
let (l, r) = unsafe {
let (mut l, mut r) = unsafe {
(
selected_left
.iter()
Expand All @@ -218,9 +290,21 @@ pub(super) fn iejoin_par(
.collect_vec(),
)
};
let sorted_flag = if l1_descending {
IsSorted::Descending
} else {
IsSorted::Ascending
};
// We sorted using the first series
l[0].set_sorted_flag(sorted_flag);
r[0].set_sorted_flag(sorted_flag);

// Compute the row indexes
let (idx_l, idx_r) = iejoin_tuples(l, r, options, None)?;
let (idx_l, idx_r) = if options.operator2.is_some() {
iejoin_tuples(l, r, options, None)
} else {
piecewise_merge_join_tuples(l, r, options, None)
}?;

if idx_l.is_empty() {
return Ok(None);
Expand Down Expand Up @@ -264,8 +348,11 @@ pub(super) fn iejoin(
suffix: Option<PlSmallStr>,
slice: Option<(i64, usize)>,
) -> PolarsResult<DataFrame> {
let (left_row_idx, right_row_idx) =
iejoin_tuples(selected_left, selected_right, options, slice)?;
let (left_row_idx, right_row_idx) = if options.operator2.is_some() {
iejoin_tuples(selected_left, selected_right, options, slice)
} else {
piecewise_merge_join_tuples(selected_left, selected_right, options, slice)
}?;
unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) }
}

Expand Down Expand Up @@ -308,7 +395,12 @@ fn iejoin_tuples(
};

let op1 = options.operator1;
let op2 = options.operator2;
let op2 = match options.operator2 {
None => {
return Err(polars_err!(ComputeError: "IEJoin requires two inequality operators"));
},
Some(op2) => op2,
};

// Determine the sort order based on the comparison operators used.
// We want to sort L1 so that "x[i] op1 x[j]" is true for j > i,
Expand Down Expand Up @@ -381,3 +473,142 @@ fn iejoin_tuples(
};
Ok((left_row_idx, right_row_idx))
}

/// Piecewise merge join: computes the matching row indexes for a join that has
/// only a single inequality predicate.
///
/// Requires exactly one selected expression per side and `options.operator2`
/// to be `None`; any violation yields a `ComputeError`. Both inputs are brought
/// into the order demanded by `options.operator1` (null rows are dropped from
/// the ordering), merged, and the resulting index arrays are finally cut down
/// to `slice` when one is given.
fn piecewise_merge_join_tuples(
    selected_left: Vec<Series>,
    selected_right: Vec<Series>,
    options: &IEJoinOptions,
    slice: Option<(i64, usize)>,
) -> PolarsResult<(IdxCa, IdxCa)> {
    if selected_left.len() != 1 {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the left DataFrame"),
        );
    }
    if selected_right.len() != 1 {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join requires exactly one expression from the right DataFrame"),
        );
    }
    if options.operator2.is_some() {
        return Err(
            polars_err!(ComputeError: "Piecewise merge join expects only one inequality operator"),
        );
    }

    let cmp_op = options.operator1;
    // Ordering requirements of the merge:
    // - left: once the condition is false for a RHS row, it stays false for
    //   that RHS row on every following LHS row.
    // - right: once the condition is true for a LHS row, it stays true for
    //   that LHS row on every following RHS row.
    // This is the same direction as the l1 order used in iejoin_par, so
    // slices handed over by a parallel join need no re-sorting.
    let sort_descending = match cmp_op {
        InequalityOperator::Gt | InequalityOperator::GtEq => true,
        InequalityOperator::Lt | InequalityOperator::LtEq => false,
    };

    let lhs = selected_left[0].to_physical_repr().into_owned();
    let mut rhs = selected_right[0].to_physical_repr().into_owned();
    if rhs.dtype().matches_schema_type(lhs.dtype())? {
        rhs = rhs.cast(lhs.dtype())?;
    }

    /// Returns `series` in the requested order together with the permutation
    /// that produced it, or `None` as permutation when the input could be
    /// used untouched.
    fn sort_non_null(series: Series, descending: bool) -> (Series, Option<IdxCa>) {
        let wanted_flag = match descending {
            true => IsSorted::Descending,
            false => IsSorted::Ascending,
        };
        let already_ordered = series.is_sorted_flag() == wanted_flag || series.len() <= 1;
        if already_ordered && !series.has_nulls() {
            // Fast path, no need to re-sort
            return (series, None);
        }

        let sort_options = SortOptions::default()
            .with_nulls_last(false)
            .with_order_descending(descending);

        // Nulls are sorted to the front; skip them because they can never be
        // part of a match.
        let null_count = series.null_count();
        let order = series
            .arg_sort(sort_options)
            .slice(null_count as i64, series.len() - null_count)
            .rechunk();
        let ordered = unsafe { series.take_unchecked(&order) };
        (ordered, Some(order))
    }

    let (left_ordered, left_perm) = sort_non_null(lhs, sort_descending);
    debug_assert!(left_perm
        .as_ref()
        .is_none_or(|perm| perm.chunks().len() == 1));
    let left_order = left_perm
        .as_ref()
        .map(|perm| perm.downcast_get(0).unwrap().values().as_slice());

    let (right_ordered, right_perm) = sort_non_null(rhs, sort_descending);
    debug_assert!(right_perm
        .as_ref()
        .is_none_or(|perm| perm.chunks().len() == 1));
    let right_order = right_perm
        .as_ref()
        .map(|perm| perm.downcast_get(0).unwrap().values().as_slice());

    // Dispatch on the physical numeric type and pick the total-order
    // comparison that corresponds to the requested operator.
    let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(left_ordered.dtype(), |$T| {
        match cmp_op {
            InequalityOperator::Lt => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |a, b| a.tot_lt(b),
            ),
            InequalityOperator::LtEq => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |a, b| a.tot_le(b),
            ),
            InequalityOperator::Gt => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |a, b| a.tot_gt(b),
            ),
            InequalityOperator::GtEq => piecewise_merge_join_impl_t::<$T, _>(
                slice,
                left_order,
                right_order,
                left_ordered,
                right_ordered,
                |a, b| a.tot_ge(b),
            ),
        }
    })?;

    debug_assert_eq!(left_row_idx.len(), right_row_idx.len());
    let left_rows = IdxCa::from_vec("".into(), left_row_idx);
    let right_rows = IdxCa::from_vec("".into(), right_row_idx);
    Ok(match slice {
        Some((offset, len)) => (left_rows.slice(offset, len), right_rows.slice(offset, len)),
        None => (left_rows, right_rows),
    })
}

/// Returns the exclusive end position (`offset + len`, saturating at
/// `i64::MAX`) of `slice`.
///
/// Yields `None` when no slice is given or when the offset is negative; callers
/// in this file treat `None` as "no early stop possible".
fn slice_end_index(slice: Option<(i64, usize)>) -> Option<i64> {
    let (offset, len) = slice?;
    if offset < 0 {
        return None;
    }
    Some(offset.saturating_add_unsigned(len as u64))
}
37 changes: 21 additions & 16 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,12 @@ fn resolve_join_where(
&suffix,
);
join_node
}
// TODO! once we support single IEjoin predicates, we must add a branch for the singe ie_pred case.
else if ie_right_on.len() >= 2 {
} else if ie_right_on.len() >= 2 {
// Do an IEjoin.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: ie_op[1],
operator2: Some(ie_op[1]),
});

let join_node = resolve_join(
Expand All @@ -390,31 +388,38 @@ fn resolve_join_where(
remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
join_node
} else if ie_right_on.len() == 1 {
// For a single inequality comparison, we use the piecewise merge join algorithm
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: None,
});

resolve_join(
Either::Right(input_left),
Either::Right(input_right),
ie_left_on,
ie_right_on,
vec![],
options.clone(),
ctxt,
)?
} else {
// No predicates found that are supported in a fast algorithm.
// Do a cross join and follow up with filters.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

let join_node = resolve_join(
resolve_join(
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?;
// TODO: This can be removed once we support the single IEjoin.
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
)?
};

let IR::Join {
Expand Down
Loading