Skip to content

Commit

Permalink
[FEAT] Support null safe equal join in native execution
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Nov 5, 2024
1 parent c86e2ff commit 4e7489c
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 16 deletions.
10 changes: 5 additions & 5 deletions src/daft-core/src/utils/dyn_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ pub fn build_dyn_compare(

pub fn build_dyn_multi_array_compare(
schema: &Schema,
nulls_equal: bool,
nans_equal: bool,
nulls_equal: &[bool],
nans_equal: &[bool],
) -> DaftResult<MultiDynArrayComparator> {
let mut fn_list = Vec::with_capacity(schema.len());
for field in schema.fields.values() {
for (idx, field) in schema.fields.values().enumerate() {
fn_list.push(build_dyn_compare(
&field.dtype,
&field.dtype,
nulls_equal,
nans_equal,
nulls_equal[idx],
nans_equal[idx],
)?);
}
let combined_fn = Box::new(
Expand Down
9 changes: 7 additions & 2 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ pub fn physical_plan_to_pipeline(
right,
left_on,
right_on,
null_equals_null,
join_type,
schema,
}) => {
Expand Down Expand Up @@ -368,9 +369,13 @@ pub fn physical_plan_to_pipeline(
.zip(key_schema.fields.values())
.map(|(e, f)| e.clone().cast(&f.dtype))
.collect::<Vec<_>>();

// we should move to a builder pattern
let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?;
let build_sink = HashJoinBuildSink::new(
key_schema,
casted_build_on,
null_equals_null.clone(),
join_type,
)?;
let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?;
let build_node =
BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed();
Expand Down
11 changes: 10 additions & 1 deletion src/daft-local-execution/src/sinks/hash_join_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ impl ProbeTableState {
fn new(
key_schema: &SchemaRef,
projection: Vec<ExprRef>,
nulls_equal_aware: Option<&Vec<bool>>,
join_type: &JoinType,
) -> DaftResult<Self> {
let track_indices = !matches!(join_type, JoinType::Anti | JoinType::Semi);
Ok(Self::Building {
probe_table_builder: Some(make_probeable_builder(key_schema.clone(), track_indices)?),
probe_table_builder: Some(make_probeable_builder(
key_schema.clone(),
nulls_equal_aware,
track_indices,
)?),
projection,
tables: Vec::new(),
})
Expand Down Expand Up @@ -83,18 +88,21 @@ impl BlockingSinkState for ProbeTableState {
pub struct HashJoinBuildSink {
key_schema: SchemaRef,
projection: Vec<ExprRef>,
nulls_equal_aware: Option<Vec<bool>>,
join_type: JoinType,
}

impl HashJoinBuildSink {
pub(crate) fn new(
key_schema: SchemaRef,
projection: Vec<ExprRef>,
nulls_equal_aware: Option<Vec<bool>>,
join_type: &JoinType,
) -> DaftResult<Self> {
Ok(Self {
key_schema,
projection,
nulls_equal_aware,
join_type: *join_type,
})
}
Expand Down Expand Up @@ -144,6 +152,7 @@ impl BlockingSink for HashJoinBuildSink {
Ok(Box::new(ProbeTableState::new(
&self.key_schema,
self.projection.clone(),
self.nulls_equal_aware.as_ref(),
&self.join_type,
)?))
}
Expand Down
3 changes: 3 additions & 0 deletions src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ impl LocalPhysicalPlan {
right: LocalPhysicalPlanRef,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_null: Option<Vec<bool>>,
join_type: JoinType,
schema: SchemaRef,
) -> LocalPhysicalPlanRef {
Expand All @@ -267,6 +268,7 @@ impl LocalPhysicalPlan {
right,
left_on,
right_on,
null_equals_null,
join_type,
schema,
})
Expand Down Expand Up @@ -449,6 +451,7 @@ pub struct HashJoin {
pub right: LocalPhysicalPlanRef,
pub left_on: Vec<ExprRef>,
pub right_on: Vec<ExprRef>,
pub null_equals_null: Option<Vec<bool>>,
pub join_type: JoinType,
pub schema: SchemaRef,
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
right,
join.left_on.clone(),
join.right_on.clone(),
join.null_equals_nulls.clone(),
join.join_type,
join.output_schema.clone(),
))
Expand Down
11 changes: 9 additions & 2 deletions src/daft-table/src/probeable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ struct ArrowTableEntry(Vec<Box<dyn arrow2::array::Array>>);

pub fn make_probeable_builder(
schema: SchemaRef,
nulls_equal_aware: Option<&Vec<bool>>,
track_indices: bool,
) -> DaftResult<Box<dyn ProbeableBuilder>> {
if track_indices {
Ok(Box::new(ProbeTableBuilder(ProbeTable::new(schema)?)))
Ok(Box::new(ProbeTableBuilder(ProbeTable::new(
schema,
nulls_equal_aware,
)?)))
} else {
Ok(Box::new(ProbeSetBuilder(ProbeSet::new(schema)?)))
Ok(Box::new(ProbeSetBuilder(ProbeSet::new(
schema,
nulls_equal_aware,
)?)))
}
}

Expand Down
20 changes: 17 additions & 3 deletions src/daft-table/src/probeable/probe_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftResult;
use common_error::{DaftError, DaftResult};
use daft_core::{
array::ops::as_arrow::AsArrow,
prelude::SchemaRef,
Expand Down Expand Up @@ -31,12 +31,26 @@ impl ProbeSet {

const DEFAULT_SIZE: usize = 20;

pub(crate) fn new(schema: SchemaRef) -> DaftResult<Self> {
pub(crate) fn new(
schema: SchemaRef,
nulls_equal_aware: Option<&Vec<bool>>,
) -> DaftResult<Self> {
let hash_table = HashMap::<IndexHash, (), IdentityBuildHasher>::with_capacity_and_hasher(
Self::DEFAULT_SIZE,
Default::default(),
);
let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?;
if let Some(null_equal_aware) = nulls_equal_aware {
if null_equal_aware.len() != schema.len() {
return Err(DaftError::InternalError(
format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}",
schema.len(), null_equal_aware.len())));
}
}
let default_nulls_equal = vec![false; schema.len()];
let nulls_equal = nulls_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref());
let nans_equal = &vec![false; schema.len()];
let compare_fn =
build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?;
Ok(Self {
schema,
hash_table,
Expand Down
17 changes: 14 additions & 3 deletions src/daft-table/src/probeable/probe_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::Arc,
};

use common_error::DaftResult;
use common_error::{DaftError, DaftResult};
use daft_core::{
array::ops::as_arrow::AsArrow,
prelude::SchemaRef,
Expand Down Expand Up @@ -32,13 +32,24 @@ impl ProbeTable {

const DEFAULT_SIZE: usize = 20;

pub(crate) fn new(schema: SchemaRef) -> DaftResult<Self> {
pub(crate) fn new(schema: SchemaRef, null_equal_aware: Option<&Vec<bool>>) -> DaftResult<Self> {
let hash_table =
HashMap::<IndexHash, Vec<u64>, IdentityBuildHasher>::with_capacity_and_hasher(
Self::DEFAULT_SIZE,
Default::default(),
);
let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?;
if let Some(null_equal_aware) = null_equal_aware {
if null_equal_aware.len() != schema.len() {
return Err(DaftError::InternalError(
format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}",
schema.len(), null_equal_aware.len())));
}
}
let default_nulls_equal = vec![false; schema.len()];
let nulls_equal = null_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref());
let nans_equal = &vec![false; schema.len()];
let compare_fn =
build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?;
Ok(Self {
schema,
hash_table,
Expand Down

0 comments on commit 4e7489c

Please sign in to comment.