From 883c08eea1be091ec53a2220ea2d1261c1adaeb8 Mon Sep 17 00:00:00 2001 From: xudong963 Date: Mon, 4 Jul 2022 17:33:25 +0800 Subject: [PATCH] fix clippy, add annotations, remove tpch q8 --- .../transforms/hash_join/join_hash_table.rs | 129 +++++++----------- .../transforms/hash_join/result_blocks.rs | 49 ++++--- .../processors/transforms/transform_apply.rs | 2 +- query/src/sql/exec/pipeline_builder.rs | 2 +- .../optimizer/heuristic/subquery_rewriter.rs | 15 +- query/src/sql/planner/plans/logical_join.rs | 1 + query/src/sql/planner/plans/scalar.rs | 4 +- query/src/sql/planner/semantic/type_check.rs | 16 ++- .../0_stateless/13_tpch/13_0008_q8.result | 2 - .../suites/0_stateless/13_tpch/13_0008_q8.sql | 38 ------ 10 files changed, 106 insertions(+), 152 deletions(-) diff --git a/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs b/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs index efec533311551..e5461a592efcd 100644 --- a/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs +++ b/query/src/pipelines/new/processors/transforms/hash_join/join_hash_table.rs @@ -105,29 +105,29 @@ pub enum HashTable { #[derive(Clone, Eq, PartialEq, Debug)] pub enum MarkerKind { - TRUE, - FALSE, - NULL, + True, + False, + Null, } -pub struct JoinHashTable { - /// Reference count - ref_count: Mutex, - is_finished: Mutex, - +pub struct HashJoinUtil { pub(crate) build_keys: Vec>, pub(crate) probe_keys: Vec>, + pub(crate) join_type: JoinType, + pub(crate) other_predicate: Option>, + pub(crate) marker: RwLock>, + pub(crate) marker_index: Option, +} +pub struct JoinHashTable { pub(crate) ctx: Arc, - + /// Reference count + ref_count: Mutex, + is_finished: Mutex, /// A shared big hash table stores all the rows from build side pub(crate) hash_table: RwLock, pub(crate) row_space: RowSpace, - pub(crate) join_type: JoinType, - pub(crate) other_predicate: Option>, - - pub(crate) marker: RwLock>, - pub(crate) marker_index: Option, + pub(crate) hash_join_util: HashJoinUtil, } impl JoinHashTable { @@ -143,127 +143,107 @@ impl JoinHashTable { let hash_key_types: Vec = build_keys.iter().map(|expr| expr.data_type()).collect(); let method = DataBlock::choose_hash_method_with_types(&hash_key_types)?; + let hash_join_util = HashJoinUtil { + build_keys: build_keys + .iter() + .map(Evaluator::eval_physical_scalar) + .collect::>()?, + probe_keys: probe_keys + .iter() + .map(Evaluator::eval_physical_scalar) + .collect::>()?, + join_type, + other_predicate: other_predicate + .map(Evaluator::eval_physical_scalar) + .transpose()?, + marker: RwLock::new(vec![]), + marker_index, + }; Ok(match method { HashMethodKind::SingleString(_) | HashMethodKind::Serializer(_) => { Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::SerializerHashTable(SerializerHashTable { hash_table: HashMap::>::create(), hash_method: HashMethodSerializer::default(), }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?) } HashMethodKind::KeysU8(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU8HashTable(KeyU8HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU16(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU16HashTable(KeyU16HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU32(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU32HashTable(KeyU32HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU64(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU64HashTable(KeyU64HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU128(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU128HashTable(KeyU128HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU256(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU256HashTable(KeyU256HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), HashMethodKind::KeysU512(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - join_type, HashTable::KeyU512HashTable(KeyU512HashTable { hash_table: HashMap::>::create(), hash_method, }), - build_keys, - probe_keys, - other_predicate, build_schema, - marker_index, + hash_join_util, )?), }) } pub fn try_create( ctx: Arc, - join_type: JoinType, hash_table: HashTable, - build_keys: &[PhysicalScalar], - probe_keys: &[PhysicalScalar], - other_predicate: Option<&PhysicalScalar>, mut build_data_schema: DataSchemaRef, - marker_index: Option, + hash_join_util: HashJoinUtil, ) -> Result { - if join_type == JoinType::Left { + if hash_join_util.join_type == JoinType::Left { let mut nullable_field = Vec::with_capacity(build_data_schema.fields().len()); for field in build_data_schema.fields().iter() { nullable_field.push(DataField::new_nullable( @@ -277,22 +257,9 @@ impl JoinHashTable { row_space: RowSpace::new(build_data_schema), ref_count: Mutex::new(0), is_finished: Mutex::new(false), - build_keys: build_keys - .iter() - .map(Evaluator::eval_physical_scalar) - .collect::>()?, - probe_keys: probe_keys - .iter() - .map(Evaluator::eval_physical_scalar) - .collect::>()?, - other_predicate: other_predicate - .map(Evaluator::eval_physical_scalar) - .transpose()?, + hash_join_util, ctx, hash_table: RwLock::new(hash_table), - join_type, - marker: RwLock::new(vec![]), - marker_index, }) } @@ -363,6 +330,7 @@ impl JoinHashTable { ) -> Result> { let func_ctx = self.ctx.try_get_function_context()?; let probe_keys = self + .hash_join_util .probe_keys .iter() .map(|expr| Ok(expr.eval(&func_ctx, input)?.vector().clone())) @@ -430,6 +398,7 @@ impl HashJoinState for JoinHashTable { fn build(&self, input: DataBlock) -> Result<()> { let func_ctx = self.ctx.try_get_function_context()?; let build_cols = self + .hash_join_util .build_keys .iter() .map(|expr| Ok(expr.eval(&func_ctx, &input)?.vector().clone())) @@ -452,12 +421,12 @@ impl HashJoinState for JoinHashTable { } fn probe(&self, input: &DataBlock, probe_state: &mut ProbeState) -> Result> { - match self.join_type { + match self.hash_join_util.join_type { JoinType::Inner | JoinType::Semi | JoinType::Anti | JoinType::Left | JoinType::Mark => { self.probe_join(input, probe_state) } JoinType::Cross => self.probe_cross_join(input, probe_state), - _ => unimplemented!("{} is unimplemented", self.join_type), + _ => unimplemented!("{} is unimplemented", self.hash_join_util.join_type), } } @@ -486,20 +455,20 @@ impl HashJoinState for JoinHashTable { fn finish(&self) -> Result<()> { let chunks = self.row_space.chunks.read().unwrap(); - let mut marker = self.marker.write(); + let mut marker = self.hash_join_util.marker.write(); for chunk_index in 0..chunks.len() { let chunk = &chunks[chunk_index]; let mut columns = vec![]; if let Some(cols) = chunk.cols.as_ref() { columns = Vec::with_capacity(cols.len()); for col in cols.iter() { - if self.join_type == JoinType::Mark { + if self.hash_join_util.join_type == JoinType::Mark { assert_eq!(cols.len(), 1); for row_idx in 0..col.len() { if col.get(row_idx) == DataValue::Null { - marker.push(MarkerKind::NULL); + marker.push(MarkerKind::Null); } else { - marker.push(MarkerKind::FALSE); + marker.push(MarkerKind::False); } } } diff --git a/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs b/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs index 349d57683cbc2..fc3a448b75ea2 100644 --- a/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs +++ b/query/src/pipelines/new/processors/transforms/hash_join/result_blocks.rs @@ -25,6 +25,7 @@ use common_datavalues::DataSchemaRef; use common_datavalues::NullableColumn; use common_datavalues::NullableType; use common_datavalues::Series; +use common_exception::ErrorCode; use common_exception::Result; use super::JoinHashTable; @@ -53,7 +54,7 @@ impl JoinHashTable { let build_indexs = &mut probe_state.build_indexs; let mut results: Vec = vec![]; - match self.join_type { + match self.hash_join_util.join_type { JoinType::Inner => { for (i, key) in keys.iter().enumerate() { let probe_result_ptr = hash_table.find_key(key); @@ -74,7 +75,7 @@ impl JoinHashTable { let probe_block = DataBlock::block_take_by_indices(input, probe_indexs)?; let merged_block = self.merge_eq_block(&build_block, &probe_block)?; - match &self.other_predicate { + match &self.hash_join_util.other_predicate { Some(other_predicate) => { let func_ctx = self.ctx.try_get_function_context()?; let filter_vector = other_predicate.eval(&func_ctx, &merged_block)?; @@ -87,7 +88,7 @@ impl JoinHashTable { } } JoinType::Semi => { - if self.other_predicate.is_none() { + if self.hash_join_util.other_predicate.is_none() { let result = self.semi_anti_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -102,7 +103,7 @@ impl JoinHashTable { } } JoinType::Anti => { - if self.other_predicate.is_none() { + if self.hash_join_util.other_predicate.is_none() { let result = self.semi_anti_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -119,7 +120,7 @@ impl JoinHashTable { // probe_blocks left join build blocks JoinType::Left => { - if self.other_predicate.is_none() { + if self.hash_join_util.other_predicate.is_none() { let result = self.left_join::(hash_table, probe_state, keys, input)?; return Ok(vec![result]); @@ -130,6 +131,7 @@ impl JoinHashTable { } Mark => { let mut has_null = false; + // Check if there is any null in the probe block. if let Some(validity) = input.column(0).validity().1 { if validity.null_count() > 0 { has_null = true; @@ -139,35 +141,40 @@ impl JoinHashTable { let probe_result_ptr = hash_table.find_key(key); if let Some(v) = probe_result_ptr { let probe_result_ptrs = v.get_value(); - let mut marker = self.marker.write(); + let mut marker = self.hash_join_util.marker.write(); for ptr in probe_result_ptrs { - marker[ptr.row_index as usize] = MarkerKind::TRUE; + // If find join partner, set the marker to true. + marker[ptr.row_index as usize] = MarkerKind::True; } } } - let mut marker = self.marker.write(); + let mut marker = self.hash_join_util.marker.write(); let mut validity = MutableBitmap::new(); let mut boolean_bit_map = MutableBitmap::new(); for m in marker.iter_mut() { - if m == &mut MarkerKind::FALSE && has_null { - *m = MarkerKind::NULL; + if m == &mut MarkerKind::False && has_null { + *m = MarkerKind::Null; } - if m == &mut MarkerKind::NULL { + if m == &mut MarkerKind::Null { validity.push(false); } else { validity.push(true); } - if m == &mut MarkerKind::FALSE || m == &mut MarkerKind::NULL { - boolean_bit_map.push(false); - } else { + if m == &mut MarkerKind::True { boolean_bit_map.push(true); + } else { + boolean_bit_map.push(false); } } // transfer marker to a Nullable(BooleanColumn) let boolean_column = BooleanColumn::from_arrow_data(boolean_bit_map.into()); let marker_column = Self::set_validity(&boolean_column.arc(), &validity.into())?; let marker_schema = DataSchema::new(vec![DataField::new( - &self.marker_index.unwrap().to_string(), + &self + .hash_join_util + .marker_index + .ok_or_else(|| ErrorCode::LogicalError("Invalid mark join"))? + .to_string(), NullableType::new_impl(BooleanType::new_impl()), )]); let marker_block = @@ -267,8 +274,10 @@ impl JoinHashTable { let build_block = self.row_space.gather(build_indexs)?; let merged_block = self.merge_eq_block(&build_block, &probe_block)?; - let (bm, all_true, all_false) = - self.get_other_filters(&merged_block, self.other_predicate.as_ref().unwrap())?; + let (bm, all_true, all_false) = self.get_other_filters( + &merged_block, + self.hash_join_util.other_predicate.as_ref().unwrap(), + )?; let mut bm = match (bm, all_true, all_false) { (Some(b), _, _) => b.into_mut().right().unwrap(), @@ -356,8 +365,10 @@ impl JoinHashTable { return Ok(merged_block); } - let (bm, all_true, all_false) = - self.get_other_filters(&merged_block, self.other_predicate.as_ref().unwrap())?; + let (bm, all_true, all_false) = self.get_other_filters( + &merged_block, + self.hash_join_util.other_predicate.as_ref().unwrap(), + )?; if all_true { return Ok(merged_block); diff --git a/query/src/pipelines/new/processors/transforms/transform_apply.rs b/query/src/pipelines/new/processors/transforms/transform_apply.rs index d8b25667d7622..b6e141ac8569a 100644 --- a/query/src/pipelines/new/processors/transforms/transform_apply.rs +++ b/query/src/pipelines/new/processors/transforms/transform_apply.rs @@ -70,7 +70,7 @@ impl OuterRefRewriter { probe_keys: probe_keys.clone(), other_conditions: other_conditions.clone(), join_type: join_type.clone(), - marker_index: marker_index.clone(), + marker_index: *marker_index, }), PhysicalPlan::Limit { input, diff --git a/query/src/sql/exec/pipeline_builder.rs b/query/src/sql/exec/pipeline_builder.rs index 5f5557b3b03e5..f01e7947de12d 100644 --- a/query/src/sql/exec/pipeline_builder.rs +++ b/query/src/sql/exec/pipeline_builder.rs @@ -164,7 +164,7 @@ impl PipelineBuilder { probe_keys, other_conditions, join_type.clone(), - marker_index.clone(), + *marker_index, build_side_pipeline, pipeline, )?; diff --git a/query/src/sql/optimizer/heuristic/subquery_rewriter.rs b/query/src/sql/optimizer/heuristic/subquery_rewriter.rs index f80687065fffa..e943c49dfdfd0 100644 --- a/query/src/sql/optimizer/heuristic/subquery_rewriter.rs +++ b/query/src/sql/optimizer/heuristic/subquery_rewriter.rs @@ -305,7 +305,7 @@ impl SubqueryRewriter { .next() .ok_or_else(|| ErrorCode::LogicalError("Invalid subquery"))?; let column_name = format!("subquery_{}", index); - let right_condition = Scalar::BoundColumnRef(BoundColumnRef { + let left_condition = Scalar::BoundColumnRef(BoundColumnRef { column: ColumnBinding { table_name: None, column_name, @@ -315,14 +315,21 @@ impl SubqueryRewriter { }, }); if prop.outer_columns.is_empty() { + // Add a marker column to save comparison result. + // The column is Nullable(Boolean), the data value is TRUE, FALSE, or NULL. + // If subquery contains NULL, the comparison result is TRUE or NULL. Such as t1.a => {1, 3, 4}, select t1.a in (1, 2, NULL) from t1; The sql will return {true, null, null}. + // If subquery doesn't contain NULL, the comparison result is FALSE, TRUE, or NULL. let marker_index = self.metadata.write().add_column( "marker".to_string(), NullableType::new_impl(BooleanType::new_impl()), None, ); + // Consider the sql: select * from t1 where t1.a = any(select t2.a from t2); + // Will be transferred to:select t1.a, t2.a, marker_index from t2, t1 where t2.a = t1.a; + // Note that subquery is the left table, and it'll be the probe side. let mark_join = LogicalInnerJoin { right_conditions: vec![*subquery.child_expr.as_ref().unwrap().clone()], - left_conditions: vec![right_condition], + left_conditions: vec![left_condition], other_conditions: vec![], join_type: JoinType::Mark, marker_index: Some(marker_index), @@ -333,7 +340,9 @@ impl SubqueryRewriter { UnnestResult::MarkJoin { marker_index }, )) } else { - todo!() + Err(ErrorCode::LogicalError( + "Unsupported subquery type: Correlated AnySubquery", + )) } } _ => Err(ErrorCode::LogicalError(format!( diff --git a/query/src/sql/planner/plans/logical_join.rs b/query/src/sql/planner/plans/logical_join.rs index 0a4583dcc55ce..a8a2f9a0879a3 100644 --- a/query/src/sql/planner/plans/logical_join.rs +++ b/query/src/sql/planner/plans/logical_join.rs @@ -76,6 +76,7 @@ pub struct LogicalInnerJoin { pub right_conditions: Vec, pub other_conditions: Vec, pub join_type: JoinType, + // marker_index is for MarkJoin only. pub marker_index: Option, } diff --git a/query/src/sql/planner/plans/scalar.rs b/query/src/sql/planner/plans/scalar.rs index 03881d3d55865..65bc64b1e9842 100644 --- a/query/src/sql/planner/plans/scalar.rs +++ b/query/src/sql/planner/plans/scalar.rs @@ -526,9 +526,9 @@ pub enum SubqueryType { pub struct SubqueryExpr { pub typ: SubqueryType, pub subquery: SExpr, - // The expr that is used to compare the result of the subquery (IN/ANY/ALL) + // The expr that is used to compare the result of the subquery (IN/ANY/ALL), such as `t1.a in (select t2.a from t2)`, t1.a is `child_expr`. pub child_expr: Option>, - // Comparison operator for Any/All, such as a = Any (...) + // Comparison operator for Any/All, such as t1.a = Any (...), `compare_op` is `=`. pub compare_op: Option, pub data_type: DataTypeImpl, pub allow_multi_rows: bool, diff --git a/query/src/sql/planner/semantic/type_check.rs b/query/src/sql/planner/semantic/type_check.rs index f1327580c691a..ad14c9934300c 100644 --- a/query/src/sql/planner/semantic/type_check.rs +++ b/query/src/sql/planner/semantic/type_check.rs @@ -528,6 +528,7 @@ impl<'a> TypeChecker<'a> { expr, span, } => { + // Not in subquery will be transformed to not(Expr = Any(...)) if *not { return self .resolve_function( @@ -537,7 +538,7 @@ impl<'a> TypeChecker<'a> { subquery: subquery.clone(), not: false, expr: expr.clone(), - span: span.clone(), + span: *span, }], required_type, ) @@ -1129,10 +1130,13 @@ impl<'a> TypeChecker<'a> { let bind_context = BindContext::with_parent(Box::new(self.bind_context.clone())); let (s_expr, output_context) = binder.bind_query(&bind_context, subquery).await?; - if typ == SubqueryType::Scalar && output_context.columns.len() > 1 { - return Err(ErrorCode::SemanticError( - "Scalar subquery must return only one column", - )); + if (typ == SubqueryType::Scalar || typ == SubqueryType::Any) + && output_context.columns.len() > 1 + { + return Err(ErrorCode::SemanticError(format!( + "Subquery must return only one column, but got {} columns", + output_context.columns.len() + ))); } let mut data_type = output_context.columns[0].data_type.clone(); @@ -1145,7 +1149,7 @@ impl<'a> TypeChecker<'a> { assert_eq!(output_context.columns.len(), 1); let (mut scalar, scalar_data_type) = self.resolve(&expr, None).await?; if scalar_data_type != data_type { - // Make compare type keep consistent + // Make comparison scalar type keep consistent let coercion_type = merge_types(&scalar_data_type, &data_type)?; scalar = wrap_cast_if_needed(scalar, &coercion_type); data_type = coercion_type; diff --git a/tests/suites/0_stateless/13_tpch/13_0008_q8.result b/tests/suites/0_stateless/13_tpch/13_0008_q8.result index 72aaa69296e42..e69de29bb2d1d 100644 --- a/tests/suites/0_stateless/13_tpch/13_0008_q8.result +++ b/tests/suites/0_stateless/13_tpch/13_0008_q8.result @@ -1,2 +0,0 @@ -1995 0.0 -1996 0.0 diff --git a/tests/suites/0_stateless/13_tpch/13_0008_q8.sql b/tests/suites/0_stateless/13_tpch/13_0008_q8.sql index 34ad0d47beedd..e69de29bb2d1d 100644 --- a/tests/suites/0_stateless/13_tpch/13_0008_q8.sql +++ b/tests/suites/0_stateless/13_tpch/13_0008_q8.sql @@ -1,38 +0,0 @@ -set enable_planner_v2 = 1; -select - o_year, - sum(case - when nation = 'BRAZIL' then volume - else 0 - end) / sum(volume) as mkt_share -from - ( - select - extract(year from o_orderdate) as o_year, - l_extendedprice * (1 - l_discount) as volume, - n2.n_name as nation - from - part, - supplier, - lineitem, - orders, - customer, - nation n1, - nation n2, - region - where - p_partkey = l_partkey - and s_suppkey = l_suppkey - and l_orderkey = o_orderkey - and o_custkey = c_custkey - and c_nationkey = n1.n_nationkey - and n1.n_regionkey = r_regionkey - and r_name = 'AMERICA' - and s_nationkey = n2.n_nationkey - and o_orderdate between to_date('1995-01-01') and to_date('1996-12-31') - and p_type = 'ECONOMY ANODIZED STEEL' - ) as all_nations -group by - o_year -order by - o_year;