diff --git a/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs b/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs index 0c1bec7acf03..5e84fa954445 100644 --- a/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs +++ b/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs @@ -31,6 +31,7 @@ use common_datavalues::DataField; use common_datavalues::DataSchemaRef; use common_datavalues::DataSchemaRefExt; use common_datavalues::DataTypeImpl; +use common_datavalues::DataValue; use common_exception::Result; use common_hashtable::HashMap; use common_hashtable::HashTableKeyable; @@ -49,6 +50,7 @@ use crate::sessions::QueryContext; use crate::sql::exec::ColumnID; use crate::sql::exec::PhysicalScalar; use crate::sql::planner::plans::JoinType; +use crate::sql::IndexType; pub struct SerializerHashTable { pub(crate) hash_table: HashMap>, @@ -101,21 +103,31 @@ pub enum HashTable { KeyU512HashTable(KeyU512HashTable), } -pub struct JoinHashTable { - /// Reference count - ref_count: Mutex, - is_finished: Mutex, +#[derive(Clone, Eq, PartialEq, Debug)] +pub enum MarkerKind { + True, + False, + Null, +} +pub struct HashJoinDesc { pub(crate) build_keys: Vec>, pub(crate) probe_keys: Vec>, + pub(crate) join_type: JoinType, + pub(crate) other_predicate: Option>, + pub(crate) marker: RwLock>, + pub(crate) marker_index: Option, +} +pub struct JoinHashTable { pub(crate) ctx: Arc, - + /// Reference count + ref_count: Mutex, + is_finished: Mutex, /// A shared big hash table stores all the rows from build side pub(crate) hash_table: RwLock, pub(crate) row_space: RowSpace, - pub(crate) join_type: JoinType, - pub(crate) other_predicate: Option>, + pub(crate) hash_join_desc: HashJoinDesc, } impl JoinHashTable { @@ -126,122 +138,112 @@ impl JoinHashTable { probe_keys: &[PhysicalScalar], other_predicate: Option<&PhysicalScalar>, build_schema: DataSchemaRef, + marker_index: Option, ) -> Result> { let hash_key_types: Vec = build_keys.iter().map(|expr| expr.data_type()).collect(); let method = DataBlock::choose_hash_method_with_types(&hash_key_types)?; + let hash_join_desc = HashJoinDesc { + build_keys: build_keys + .iter() + .map(Evaluator::eval_physical_scalar) + .collect::>()?, + probe_keys: probe_keys + .iter() + .map(Evaluator::eval_physical_scalar) + .collect::>()?, + join_type, + other_predicate: other_predicate + .map(Evaluator::eval_physical_scalar) + .transpose()?, + marker: RwLock::new(vec![]), + marker_index, + }; Ok(match method { HashMethodKind::SingleString(_) | HashMethodKind::Serializer(_) => { Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::SerializerHashTable(SerializerHashTable { hash_table: HashMap::>::create(), hash_method: HashMethodSerializer::default(), }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?) } HashMethodKind::KeysU8(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU8HashTable(KeyU8HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU16(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU16HashTable(KeyU16HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU32(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU32HashTable(KeyU32HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU64(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU64HashTable(KeyU64HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU128(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU128HashTable(KeyU128HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU256(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU256HashTable(KeyU256HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), HashMethodKind::KeysU512(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU512HashTable(KeyU512HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, + hash_join_desc, )?), }) } pub fn try_create( ctx: Arc, - join_type: JoinType, hash_table: HashTable, - build_keys: &[PhysicalScalar], - probe_keys: &[PhysicalScalar], - other_predicate: Option<&PhysicalScalar>, mut build_data_schema: DataSchemaRef, + hash_join_desc: HashJoinDesc, ) -> Result { - if join_type == JoinType::Left { + if hash_join_desc.join_type == JoinType::Left { let mut nullable_field = Vec::with_capacity(build_data_schema.fields().len()); for field in build_data_schema.fields().iter() { nullable_field.push(DataField::new_nullable( @@ -255,20 +257,9 @@ impl JoinHashTable { row_space: RowSpace::new(build_data_schema), ref_count: Mutex::new(0), is_finished: Mutex::new(false), - build_keys: build_keys - .iter() - .map(Evaluator::eval_physical_scalar) - .collect::>()?, - probe_keys: probe_keys - .iter() - .map(Evaluator::eval_physical_scalar) - .collect::>()?, - other_predicate: other_predicate - .map(Evaluator::eval_physical_scalar) - .transpose()?, + hash_join_desc, ctx, hash_table: RwLock::new(hash_table), - join_type, }) } @@ -339,6 +330,7 @@ impl JoinHashTable { ) -> Result> { let func_ctx = self.ctx.try_get_function_context()?; let probe_keys = self + .hash_join_desc .probe_keys .iter() .map(|expr| Ok(expr.eval(&func_ctx, input)?.vector().clone())) @@ -406,6 +398,7 @@ impl HashJoinState for JoinHashTable { fn build(&self, input: DataBlock) -> Result<()> { let func_ctx = self.ctx.try_get_function_context()?; let build_cols = self + .hash_join_desc .build_keys .iter() .map(|expr| Ok(expr.eval(&func_ctx, &input)?.vector().clone())) @@ -428,12 +421,12 @@ impl HashJoinState for JoinHashTable { } fn probe(&self, input: &DataBlock, probe_state: &mut ProbeState) -> Result> { - match self.join_type { - JoinType::Inner | JoinType::Semi | JoinType::Anti | JoinType::Left => { + match self.hash_join_desc.join_type { + JoinType::Inner | JoinType::Semi | JoinType::Anti | JoinType::Left | JoinType::Mark => { self.probe_join(input, probe_state) } JoinType::Cross => self.probe_cross_join(input, probe_state), - _ => unimplemented!("{} is unimplemented", self.join_type), + _ => unimplemented!("{} is unimplemented", self.hash_join_desc.join_type), } } @@ -461,13 +454,24 @@ impl HashJoinState for JoinHashTable { } fn finish(&self) -> Result<()> { - let chunks = self.row_space.chunks.write().unwrap(); + let chunks = self.row_space.chunks.read().unwrap(); + let mut marker = self.hash_join_desc.marker.write(); for chunk_index in 0..chunks.len() { let chunk = &chunks[chunk_index]; let mut columns = vec![]; if let Some(cols) = chunk.cols.as_ref() { columns = Vec::with_capacity(cols.len()); for col in cols.iter() { + if self.hash_join_desc.join_type == JoinType::Mark { + assert_eq!(cols.len(), 1); + for row_idx in 0..col.len() { + if col.get(row_idx) == DataValue::Null { + marker.push(MarkerKind::Null); + } else { + marker.push(MarkerKind::False); + } + } + } columns.push(col); } } diff --git a/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs b/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs index 723d3b0c2d1d..6dc5a0ded9f5 100644 --- a/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs +++ b/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs @@ -16,10 +16,16 @@ use common_arrow::arrow::bitmap::Bitmap; use common_arrow::arrow::bitmap::MutableBitmap; use common_datablocks::DataBlock; use common_datavalues::BooleanColumn; +use common_datavalues::BooleanType; use common_datavalues::Column; use common_datavalues::ColumnRef; +use common_datavalues::DataField; +use common_datavalues::DataSchema; +use common_datavalues::DataSchemaRef; use common_datavalues::NullableColumn; +use common_datavalues::NullableType; use common_datavalues::Series; +use common_exception::ErrorCode; use common_exception::Result; use common_hashtable::HashMap; use common_hashtable::HashTableKeyable; @@ -27,9 +33,11 @@ use common_hashtable::HashTableKeyable; use super::JoinHashTable; use super::ProbeState; use crate::common::EvalNode; +use crate::pipelines::new::processors::transforms::hash_join::join_hash_table::MarkerKind; use crate::pipelines::new::processors::transforms::hash_join::row::RowPtr; use crate::sql::exec::ColumnID; use crate::sql::planner::plans::JoinType; +use crate::sql::plans::JoinType::Mark; impl JoinHashTable { pub(crate) fn result_blocks( @@ -46,7 +54,7 @@ impl JoinHashTable { let build_indexs = &mut probe_state.build_indexs; let mut results: Vec = vec![]; - match self.join_type { + match self.hash_join_desc.join_type { JoinType::Inner => { for (i, key) in keys.iter().enumerate() { let probe_result_ptr = hash_table.find_key(key); @@ -67,7 +75,7 @@ impl JoinHashTable { let probe_block = DataBlock::block_take_by_indices(input, probe_indexs)?; let merged_block = self.merge_eq_block(&build_block, &probe_block)?; - match &self.other_predicate { + match &self.hash_join_desc.other_predicate { Some(other_predicate) => { let func_ctx = self.ctx.try_get_function_context()?; let filter_vector = other_predicate.eval(&func_ctx, &merged_block)?; @@ -80,7 +88,7 @@ impl JoinHashTable { } } JoinType::Semi => { - if self.other_predicate.is_none() { + if self.hash_join_desc.other_predicate.is_none() { let result = self.semi_anti_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -95,7 +103,7 @@ impl JoinHashTable { } } JoinType::Anti => { - if self.other_predicate.is_none() { + if self.hash_join_desc.other_predicate.is_none() { let result = self.semi_anti_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -112,7 +120,7 @@ impl JoinHashTable { // probe_blocks left join build blocks JoinType::Left => { - if self.other_predicate.is_none() { + if self.hash_join_desc.other_predicate.is_none() { let result = self.left_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -121,6 +129,67 @@ impl JoinHashTable { return Ok(vec![result]); } } + Mark => { + let mut has_null = false; + // `probe_column` is the subquery result column. + // For sql: select * from t1 where t1.a in (select t2.a from t2); t2.a is the `probe_column`, + let probe_column = input.column(0); + // Check if there is any null in the probe column. + if let Some(validity) = probe_column.validity().1 { + if validity.null_count() > 0 { + has_null = true; + } + } + for key in keys.iter() { + let probe_result_ptr = hash_table.find_key(key); + if let Some(v) = probe_result_ptr { + let probe_result_ptrs = v.get_value(); + let mut marker = self.hash_join_desc.marker.write(); + for ptr in probe_result_ptrs { + // If find join partner, set the marker to true. + marker[ptr.row_index as usize] = MarkerKind::True; + } + } + } + let mut marker = self.hash_join_desc.marker.write(); + let mut validity = MutableBitmap::new(); + let mut boolean_bit_map = MutableBitmap::new(); + for m in marker.iter_mut() { + if m == &mut MarkerKind::False && has_null { + *m = MarkerKind::Null; + } + if m == &mut MarkerKind::Null { + validity.push(false); + } else { + validity.push(true); + } + if m == &mut MarkerKind::True { + boolean_bit_map.push(true); + } else { + boolean_bit_map.push(false); + } + } + // transfer marker to a Nullable(BooleanColumn) + let boolean_column = BooleanColumn::from_arrow_data(boolean_bit_map.into()); + let marker_column = Self::set_validity(&boolean_column.arc(), &validity.into())?; + let marker_schema = DataSchema::new(vec![DataField::new( + &self + .hash_join_desc + .marker_index + .ok_or_else(|| ErrorCode::LogicalError("Invalid mark join"))? + .to_string(), + NullableType::new_impl(BooleanType::new_impl()), + )]); + let marker_block = + DataBlock::create(DataSchemaRef::from(marker_schema), vec![marker_column]); + let build_indexs = &mut probe_state.build_indexs; + for entity in hash_table.iter() { + build_indexs.extend_from_slice(entity.get_value()); + } + let build_block = self.row_space.gather(build_indexs)?; + let result = self.merge_eq_block(&marker_block, &build_block)?; + results.push(result); + } _ => unreachable!(), } Ok(results) @@ -208,8 +277,10 @@ impl JoinHashTable { let build_block = self.row_space.gather(build_indexs)?; let merged_block = self.merge_eq_block(&build_block, &probe_block)?; - let (bm, all_true, all_false) = - self.get_other_filters(&merged_block, self.other_predicate.as_ref().unwrap())?; + let (bm, all_true, all_false) = self.get_other_filters( + &merged_block, + self.hash_join_desc.other_predicate.as_ref().unwrap(), + )?; let mut bm = match (bm, all_true, all_false) { (Some(b), _, _) => b.into_mut().right().unwrap(), @@ -297,8 +368,10 @@ impl JoinHashTable { return Ok(merged_block); } - let (bm, all_true, all_false) = - self.get_other_filters(&merged_block, self.other_predicate.as_ref().unwrap())?; + let (bm, all_true, all_false) = self.get_other_filters( + &merged_block, + self.hash_join_desc.other_predicate.as_ref().unwrap(), + )?; if all_true { return Ok(merged_block); diff --git a/query/src/pipelines/new/processors/transforms/transform_apply.rs b/query/src/pipelines/new/processors/transforms/transform_apply.rs index 6384fa9124e0..b6e141ac8569 100644 --- a/query/src/pipelines/new/processors/transforms/transform_apply.rs +++ b/query/src/pipelines/new/processors/transforms/transform_apply.rs @@ -61,6 +61,7 @@ impl OuterRefRewriter { probe_keys, other_conditions, join_type, + marker_index, } => Ok(PhysicalPlan::HashJoin { build: Box::new(self.rewrite_physical_plan(build)?), probe: Box::new(self.rewrite_physical_plan(probe)?), @@ -69,6 +70,7 @@ impl OuterRefRewriter { probe_keys: probe_keys.clone(), other_conditions: other_conditions.clone(), join_type: join_type.clone(), + marker_index: *marker_index, }), PhysicalPlan::Limit { input, diff --git a/query/src/sql/exec/physical_plan.rs b/query/src/sql/exec/physical_plan.rs index 36e838b7e68a..824c5c133aa3 100644 --- a/query/src/sql/exec/physical_plan.rs +++ b/query/src/sql/exec/physical_plan.rs @@ -17,17 +17,20 @@ use std::collections::BTreeSet; use common_datablocks::DataBlock; use common_datavalues::wrap_nullable; +use common_datavalues::BooleanType; use common_datavalues::DataField; use common_datavalues::DataSchemaRef; use common_datavalues::DataSchemaRefExt; use common_datavalues::DataTypeImpl; use common_datavalues::DataValue; +use common_datavalues::NullableType; use common_datavalues::ToDataType; use common_datavalues::Vu8; use common_exception::Result; use common_planners::ReadDataSourcePlan; use crate::sql::plans::JoinType; +use crate::sql::IndexType; pub type ColumnID = String; @@ -76,6 +79,7 @@ pub enum PhysicalPlan { probe_keys: Vec, other_conditions: Vec, join_type: JoinType, + marker_index: Option, }, CrossApply { input: Box, @@ -181,6 +185,15 @@ impl PhysicalPlan { // Do nothing } + JoinType::Mark => { + fields.clear(); + fields = build.output_schema()?.fields().clone(); + fields.push(DataField::new( + "marker", + NullableType::new_impl(BooleanType::new_impl()), + )); + } + _ => { for field in build.output_schema()?.fields() { fields.push(DataField::new( diff --git a/query/src/sql/exec/physical_plan_builder.rs b/query/src/sql/exec/physical_plan_builder.rs index 63c88d671669..cb36b276ce77 100644 --- a/query/src/sql/exec/physical_plan_builder.rs +++ b/query/src/sql/exec/physical_plan_builder.rs @@ -84,6 +84,7 @@ impl PhysicalPlanBuilder { builder.build(v) }) .collect::>()?, + marker_index: join.marker_index, }) } RelOperator::Project(project) => { diff --git a/query/src/sql/exec/pipeline_builder.rs b/query/src/sql/exec/pipeline_builder.rs index 40d0c2e2a4ac..f01e7947de12 100644 --- a/query/src/sql/exec/pipeline_builder.rs +++ b/query/src/sql/exec/pipeline_builder.rs @@ -58,6 +58,7 @@ use crate::sql::exec::PhysicalScalar; use crate::sql::exec::SortDesc; use crate::sql::plans::JoinType; use crate::sql::ColumnBinding; +use crate::sql::IndexType; #[derive(Default)] pub struct PipelineBuilder { @@ -149,6 +150,7 @@ impl PipelineBuilder { probe_keys, other_conditions, join_type, + marker_index, } => { let mut build_side_pipeline = NewPipeline::create(); let build_side_context = QueryContext::create_from(context.clone()); @@ -162,6 +164,7 @@ impl PipelineBuilder { probe_keys, other_conditions, join_type.clone(), + *marker_index, build_side_pipeline, pipeline, )?; @@ -491,6 +494,7 @@ impl PipelineBuilder { probe_keys: &[PhysicalScalar], other_conditions: &[PhysicalScalar], join_type: JoinType, + marker_index: Option, mut child_pipeline: NewPipeline, pipeline: &mut NewPipeline, ) -> Result<()> { @@ -524,6 +528,7 @@ impl PipelineBuilder { probe_keys, predicate.as_ref(), build_schema, + marker_index, )?; // Build side diff --git a/query/src/sql/optimizer/heuristic/decorrelate.rs b/query/src/sql/optimizer/heuristic/decorrelate.rs index 3f5a0f2868fa..c0b27e7ffcfd 100644 --- a/query/src/sql/optimizer/heuristic/decorrelate.rs +++ b/query/src/sql/optimizer/heuristic/decorrelate.rs @@ -159,6 +159,7 @@ pub fn try_decorrelate_subquery(input: &SExpr, subquery: &SubqueryExpr) -> Resul SubqueryType::Exists => JoinType::Semi, SubqueryType::NotExists => JoinType::Anti, }, + marker_index: None, }; // Rewrite plan to semi-join. diff --git a/query/src/sql/optimizer/heuristic/subquery_rewriter.rs b/query/src/sql/optimizer/heuristic/subquery_rewriter.rs index f834d1456269..b165bdcaaff4 100644 --- a/query/src/sql/optimizer/heuristic/subquery_rewriter.rs +++ b/query/src/sql/optimizer/heuristic/subquery_rewriter.rs @@ -14,6 +14,7 @@ use common_datavalues::BooleanType; use common_datavalues::DataValue; +use common_datavalues::NullableType; use common_exception::ErrorCode; use common_exception::Result; use common_functions::aggregates::AggregateFunctionFactory; @@ -45,12 +46,14 @@ use crate::sql::plans::Scalar; use crate::sql::plans::ScalarItem; use crate::sql::plans::SubqueryExpr; use crate::sql::plans::SubqueryType; +use crate::sql::IndexType; use crate::sql::MetadataRef; enum UnnestResult { Uncorrelated, Apply, SimpleJoin, // SemiJoin or AntiJoin + MarkJoin { marker_index: IndexType }, } /// Rewrite subquery into `Apply` operator @@ -142,6 +145,7 @@ impl SubqueryRewriter { right_conditions: vec![], other_conditions: vec![], join_type: JoinType::Cross, + marker_index: None, } .into(), UnnestResult::Uncorrelated, @@ -272,6 +276,7 @@ impl SubqueryRewriter { right_conditions: vec![], other_conditions: vec![], join_type: JoinType::Cross, + marker_index: None, } .into(), UnnestResult::Uncorrelated, @@ -291,6 +296,57 @@ impl SubqueryRewriter { result, )) } + SubqueryType::Any => { + let rel_expr = RelExpr::with_s_expr(&subquery.subquery); + let prop = rel_expr.derive_relational_prop()?; + let output_columns = prop.output_columns.clone(); + let index = output_columns + .iter() + .take(1) + .next() + .ok_or_else(|| ErrorCode::LogicalError("Invalid subquery"))?; + let column_name = format!("subquery_{}", index); + let left_condition = Scalar::BoundColumnRef(BoundColumnRef { + column: ColumnBinding { + database_name: None, + table_name: None, + column_name, + index: *index, + data_type: subquery.data_type.clone(), + visible_in_unqualified_wildcard: false, + }, + }); + if prop.outer_columns.is_empty() { + // Add a marker column to save comparison result. + // The column is Nullable(Boolean), the data value is TRUE, FALSE, or NULL. + // If subquery contains NULL, the comparison result is TRUE or NULL. Such as t1.a => {1, 3, 4}, select t1.a in (1, 2, NULL) from t1; The sql will return {true, null, null}. + // If subquery doesn't contain NULL, the comparison result is FALSE, TRUE, or NULL. + let marker_index = self.metadata.write().add_column( + "marker".to_string(), + NullableType::new_impl(BooleanType::new_impl()), + None, + ); + // Consider the sql: select * from t1 where t1.a = any(select t2.a from t2); + // Will be transferred to:select t1.a, t2.a, marker_index from t2, t1 where t2.a = t1.a; + // Note that subquery is the left table, and it'll be the probe side. + let mark_join = LogicalInnerJoin { + right_conditions: vec![*subquery.child_expr.as_ref().unwrap().clone()], + left_conditions: vec![left_condition], + other_conditions: vec![], + join_type: JoinType::Mark, + marker_index: Some(marker_index), + } + .into(); + Ok(( + SExpr::create_binary(mark_join, subquery.subquery.clone(), left.clone()), + UnnestResult::MarkJoin { marker_index }, + )) + } else { + Err(ErrorCode::LogicalError( + "Unsupported subquery type: Correlated AnySubquery", + )) + } + } _ => Err(ErrorCode::LogicalError(format!( "Unsupported subquery type: {:?}", &subquery.typ @@ -410,18 +466,25 @@ impl SubqueryRewriter { s_expr, )); } - - let rel_expr = RelExpr::with_s_expr(s_expr.child(1)?); + let rel_expr = if subquery.typ == SubqueryType::Any { + RelExpr::with_s_expr(s_expr.child(0)?) + } else { + RelExpr::with_s_expr(s_expr.child(1)?) + }; let prop = rel_expr.derive_relational_prop()?; // Extract the subquery and replace it with the ColumnBinding from it. - let index = *prop - .output_columns - .iter() - .take(1) - .next() - .ok_or_else(|| ErrorCode::LogicalError("Invalid subquery"))?; - let name = format!("subquery_{}", index); + let (index, name) = if let UnnestResult::MarkJoin { marker_index } = result { + (marker_index, "marker".to_string()) + } else { + let index = *prop + .output_columns + .iter() + .take(1) + .next() + .ok_or_else(|| ErrorCode::LogicalError("Invalid subquery"))?; + (index, format!("subquery_{}", index)) + }; let column_ref = ColumnBinding { database_name: None, table_name: None, diff --git a/query/src/sql/optimizer/rule/rule_implement_hash_join.rs b/query/src/sql/optimizer/rule/rule_implement_hash_join.rs index 16a37fc05938..93f7b53acf2d 100644 --- a/query/src/sql/optimizer/rule/rule_implement_hash_join.rs +++ b/query/src/sql/optimizer/rule/rule_implement_hash_join.rs @@ -69,6 +69,7 @@ impl Rule for RuleImplementHashJoin { probe_keys: logical_join.left_conditions, other_conditions: logical_join.other_conditions, join_type: logical_join.join_type, + marker_index: logical_join.marker_index, } .into(), expression.children().to_vec(), diff --git a/query/src/sql/planner/binder/join.rs b/query/src/sql/planner/binder/join.rs index ac2e6ae1347d..93d844ab3a53 100644 --- a/query/src/sql/planner/binder/join.rs +++ b/query/src/sql/planner/binder/join.rs @@ -189,6 +189,7 @@ impl<'a> Binder { right_conditions, other_conditions, join_type, + marker_index: None, }; let expr = SExpr::create_binary(inner_join.into(), left_child, right_child); diff --git a/query/src/sql/planner/plans/hash_join.rs b/query/src/sql/planner/plans/hash_join.rs index d89aee12bee4..babc22959849 100644 --- a/query/src/sql/planner/plans/hash_join.rs +++ b/query/src/sql/planner/plans/hash_join.rs @@ -20,6 +20,7 @@ use crate::sql::plans::Operator; use crate::sql::plans::PhysicalPlan; use crate::sql::plans::RelOp; use crate::sql::plans::Scalar; +use crate::sql::IndexType; #[derive(Clone, Debug)] pub struct PhysicalHashJoin { @@ -27,6 +28,7 @@ pub struct PhysicalHashJoin { pub probe_keys: Vec, pub other_conditions: Vec, pub join_type: JoinType, + pub marker_index: Option, } impl Operator for PhysicalHashJoin { diff --git a/query/src/sql/planner/plans/logical_join.rs b/query/src/sql/planner/plans/logical_join.rs index 5da1c99324d8..a8a2f9a0879a 100644 --- a/query/src/sql/planner/plans/logical_join.rs +++ b/query/src/sql/planner/plans/logical_join.rs @@ -25,6 +25,7 @@ use crate::sql::plans::Operator; use crate::sql::plans::PhysicalPlan; use crate::sql::plans::RelOp; use crate::sql::plans::Scalar; +use crate::sql::IndexType; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub enum JoinType { @@ -35,6 +36,7 @@ pub enum JoinType { Semi, Anti, Cross, + Mark, } impl Display for JoinType { @@ -61,6 +63,9 @@ impl Display for JoinType { JoinType::Cross => { write!(f, "CROSS") } + JoinType::Mark => { + write!(f, "MARK") + } } } } @@ -71,6 +76,8 @@ pub struct LogicalInnerJoin { pub right_conditions: Vec, pub other_conditions: Vec, pub join_type: JoinType, + // marker_index is for MarkJoin only. + pub marker_index: Option, } impl Operator for LogicalInnerJoin { diff --git a/query/src/sql/planner/plans/scalar.rs b/query/src/sql/planner/plans/scalar.rs index 54571b5ffba4..65bc64b1e984 100644 --- a/query/src/sql/planner/plans/scalar.rs +++ b/query/src/sql/planner/plans/scalar.rs @@ -526,6 +526,10 @@ pub enum SubqueryType { pub struct SubqueryExpr { pub typ: SubqueryType, pub subquery: SExpr, + // The expr that is used to compare the result of the subquery (IN/ANY/ALL), such as `t1.a in (select t2.a from t2)`, t1.a is `child_expr`. + pub child_expr: Option>, + // Comparison operator for Any/All, such as t1.a = Any (...), `compare_op` is `=`. + pub compare_op: Option, pub data_type: DataTypeImpl, pub allow_multi_rows: bool, pub outer_columns: ColumnSet, diff --git a/query/src/sql/planner/semantic/type_check.rs b/query/src/sql/planner/semantic/type_check.rs index a8a13b03ee47..03c489935194 100644 --- a/query/src/sql/planner/semantic/type_check.rs +++ b/query/src/sql/planner/semantic/type_check.rs @@ -50,6 +50,7 @@ use common_functions::scalars::TupleFunction; use crate::common::Evaluator; use crate::sessions::QueryContext; +use crate::sql::binder::wrap_cast_if_needed; use crate::sql::binder::Binder; use crate::sql::optimizer::RelExpr; use crate::sql::planner::metadata::optimize_remove_count_args; @@ -512,15 +513,51 @@ impl<'a> TypeChecker<'a> { subquery, true, None, + None, + None, ) .await? } Expr::Subquery { subquery, .. } => { - self.resolve_subquery(SubqueryType::Scalar, subquery, false, None) + self.resolve_subquery(SubqueryType::Scalar, subquery, false, None, None, None) .await? } + Expr::InSubquery { + subquery, + not, + expr, + span, + } => { + // Not in subquery will be transformed to not(Expr = Any(...)) + if *not { + return self + .resolve_function( + span, + "not", + &[&Expr::InSubquery { + subquery: subquery.clone(), + not: false, + expr: expr.clone(), + span: *span, + }], + required_type, + ) + .await; + } + // InSubquery will be transformed to Expr = Any(...) + self.resolve_subquery( + SubqueryType::Any, + subquery, + true, + Some(*expr.clone()), + Some(ComparisonOp::Equal), + None, + ) + .await? + } + Expr::MapAccess { span, expr, @@ -720,11 +757,6 @@ impl<'a> TypeChecker<'a> { ) .await? } - - _ => Err(ErrorCode::UnImplement(format!( - "Unsupported expr: {:?}", - expr - )))?, }; self.post_resolve(&scalar, &data_type) @@ -1086,6 +1118,8 @@ impl<'a> TypeChecker<'a> { typ: SubqueryType, subquery: &Query<'_>, allow_multi_rows: bool, + child_expr: Option>, + compare_op: Option, _required_type: Option, ) -> Result<(Scalar, DataTypeImpl)> { let mut binder = Binder::new( @@ -1098,19 +1132,37 @@ impl<'a> TypeChecker<'a> { let bind_context = BindContext::with_parent(Box::new(self.bind_context.clone())); let (s_expr, output_context) = binder.bind_query(&bind_context, subquery).await?; - if typ == SubqueryType::Scalar && output_context.columns.len() > 1 { - return Err(ErrorCode::SemanticError( - "Scalar subquery must return only one column", - )); + if (typ == SubqueryType::Scalar || typ == SubqueryType::Any) + && output_context.columns.len() > 1 + { + return Err(ErrorCode::SemanticError(format!( + "Subquery must return only one column, but got {} columns", + output_context.columns.len() + ))); } - let data_type = output_context.columns[0].data_type.clone(); + let mut data_type = output_context.columns[0].data_type.clone(); let rel_expr = RelExpr::with_s_expr(&s_expr); let rel_prop = rel_expr.derive_relational_prop()?; + let mut child_scalar = None; + if let Some(expr) = child_expr { + assert_eq!(output_context.columns.len(), 1); + let (mut scalar, scalar_data_type) = self.resolve(&expr, None).await?; + if scalar_data_type != data_type { + // Make comparison scalar type keep consistent + let coercion_type = merge_types(&scalar_data_type, &data_type)?; + scalar = wrap_cast_if_needed(scalar, &coercion_type); + data_type = coercion_type; + } + child_scalar = Some(Box::new(scalar)); + } + let subquery_expr = SubqueryExpr { subquery: s_expr, + child_expr: child_scalar, + compare_op, data_type: data_type.clone(), allow_multi_rows, typ, diff --git a/query/tests/it/sql/planner/format/mod.rs b/query/tests/it/sql/planner/format/mod.rs index 1492293d5d0a..86eabc142984 100644 --- a/query/tests/it/sql/planner/format/mod.rs +++ b/query/tests/it/sql/planner/format/mod.rs @@ -144,6 +144,7 @@ fn test_format() { .into()], other_conditions: vec![], join_type: JoinType::Inner, + marker_index: None, } .into(), SExpr::create_unary( diff --git a/tests/suites/0_stateless/13_tpch/13_0008_q8.result b/tests/suites/0_stateless/13_tpch/13_0008_q8.result index 72aaa69296e4..e69de29bb2d1 100644 --- a/tests/suites/0_stateless/13_tpch/13_0008_q8.result +++ b/tests/suites/0_stateless/13_tpch/13_0008_q8.result @@ -1,2 +0,0 @@ -1995 0.0 -1996 0.0 diff --git a/tests/suites/0_stateless/13_tpch/13_0008_q8.sql b/tests/suites/0_stateless/13_tpch/13_0008_q8.sql index 34ad0d47beed..e69de29bb2d1 100644 --- a/tests/suites/0_stateless/13_tpch/13_0008_q8.sql +++ b/tests/suites/0_stateless/13_tpch/13_0008_q8.sql @@ -1,38 +0,0 @@ -set enable_planner_v2 = 1; -select - o_year, - sum(case - when nation = 'BRAZIL' then volume - else 0 - end) / sum(volume) as mkt_share -from - ( - select - extract(year from o_orderdate) as o_year, - l_extendedprice * (1 - l_discount) as volume, - n2.n_name as nation - from - part, - supplier, - lineitem, - orders, - customer, - nation n1, - nation n2, - region - where - p_partkey = l_partkey - and s_suppkey = l_suppkey - and l_orderkey = o_orderkey - and o_custkey = c_custkey - and c_nationkey = n1.n_nationkey - and n1.n_regionkey = r_regionkey - and r_name = 'AMERICA' - and s_nationkey = n2.n_nationkey - and o_orderdate between to_date('1995-01-01') and to_date('1996-12-31') - and p_type = 'ECONOMY ANODIZED STEEL' - ) as all_nations -group by - o_year -order by - o_year; diff --git a/tests/suites/0_stateless/13_tpch/13_0016_q16.result b/tests/suites/0_stateless/13_tpch/13_0016_q16.result new file mode 100644 index 000000000000..fc0c42644eae --- /dev/null +++ b/tests/suites/0_stateless/13_tpch/13_0016_q16.result @@ -0,0 +1,34 @@ +Brand#11 PROMO ANODIZED TIN 45 4 +Brand#11 SMALL PLATED COPPER 45 4 +Brand#11 STANDARD POLISHED TIN 45 4 +Brand#13 MEDIUM ANODIZED STEEL 36 4 +Brand#14 SMALL ANODIZED NICKEL 45 4 +Brand#15 LARGE ANODIZED BRASS 45 4 +Brand#21 LARGE BURNISHED COPPER 19 4 +Brand#23 ECONOMY BRUSHED COPPER 9 4 +Brand#25 MEDIUM PLATED BRASS 45 4 +Brand#31 ECONOMY PLATED STEEL 23 4 +Brand#31 PROMO POLISHED TIN 23 4 +Brand#32 MEDIUM BURNISHED BRASS 49 4 +Brand#33 LARGE BRUSHED TIN 36 4 +Brand#33 SMALL BURNISHED NICKEL 3 4 +Brand#34 LARGE PLATED BRASS 45 4 +Brand#34 MEDIUM BRUSHED COPPER 9 4 +Brand#34 SMALL PLATED BRASS 14 4 +Brand#35 STANDARD ANODIZED STEEL 23 4 +Brand#43 PROMO POLISHED BRASS 19 4 +Brand#43 SMALL BRUSHED NICKEL 9 4 +Brand#44 SMALL PLATED COPPER 19 4 +Brand#52 MEDIUM BURNISHED TIN 45 4 +Brand#52 SMALL BURNISHED NICKEL 14 4 +Brand#53 MEDIUM BRUSHED COPPER 3 4 +Brand#55 STANDARD ANODIZED BRASS 36 4 +Brand#55 STANDARD BRUSHED COPPER 3 4 +Brand#13 SMALL BRUSHED NICKEL 19 2 +Brand#25 SMALL BURNISHED COPPER 3 2 +Brand#43 MEDIUM ANODIZED BRASS 14 2 +Brand#53 STANDARD PLATED STEEL 45 2 +Brand#24 MEDIUM PLATED STEEL 19 1 +Brand#51 ECONOMY POLISHED STEEL 49 1 +Brand#53 LARGE BURNISHED NICKEL 23 1 +Brand#54 ECONOMY ANODIZED BRASS 9 1 diff --git a/tests/suites/0_stateless/13_tpch/13_0016_q16.sql b/tests/suites/0_stateless/13_tpch/13_0016_q16.sql new file mode 100644 index 000000000000..845f705657d1 --- /dev/null +++ b/tests/suites/0_stateless/13_tpch/13_0016_q16.sql @@ -0,0 +1,31 @@ +set enable_planner_v2 = 1; +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' +) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; \ No newline at end of file diff --git a/tests/suites/0_stateless/13_tpch/13_0018_q18.result b/tests/suites/0_stateless/13_tpch/13_0018_q18.result new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/suites/0_stateless/13_tpch/13_0018_q18.sql b/tests/suites/0_stateless/13_tpch/13_0018_q18.sql new file mode 100644 index 000000000000..63649fa088ec --- /dev/null +++ b/tests/suites/0_stateless/13_tpch/13_0018_q18.sql @@ -0,0 +1,33 @@ +set enable_planner_v2 = 1; +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate; \ No newline at end of file diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index 73441041009e..5db7e08d3282 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -424,3 +424,5 @@ NULL NULL NULL 1 1 0 1 +1 2 +2 3 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index 9e2c3f554242..edde814cec4e 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -358,5 +358,14 @@ select * from numbers(5) as t where not exists (select * from numbers(3) where n select * from numbers(5) as t where exists (select number as a from numbers(3) where number = t.number and number > 0 and t.number < 2); select * from numbers(5) as t where exists (select * from numbers(3) where number > t.number); +-- (Not)IN Subquery +create table t1(a int, b int); +create table t2(a int, b int); +insert into t1 values(1, 2), (2, 3); +insert into t2 values(3, 4), (2, 3); +select * from t1 where t1.a not in (select t2.a from t2); +select * from t1 where t1.a in (select t2.a from t2); +drop table t1; +drop table t2; set enable_planner_v2 = 0;