diff --git a/Cargo.lock b/Cargo.lock index c3398fea60fa..8af927d697e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1866,6 +1866,7 @@ dependencies = [ "tokio-stream", "toml", "tonic", + "twox-hash", "typetag", "url", "uuid", diff --git a/common/datablocks/src/kernels/data_block_group_by.rs b/common/datablocks/src/kernels/data_block_group_by.rs index b53990cbf8d0..1d96aa4784ab 100644 --- a/common/datablocks/src/kernels/data_block_group_by.rs +++ b/common/datablocks/src/kernels/data_block_group_by.rs @@ -14,6 +14,7 @@ use common_datavalues::remove_nullable; use common_datavalues::DataType; +use common_datavalues::DataTypeImpl; use common_datavalues::TypeID; use common_exception::Result; @@ -28,6 +29,7 @@ use crate::HashMethod; use crate::HashMethodSingleString; impl DataBlock { + // TODO(leiysky): replace with `DataBlock::choose_hash_method_with_types` and deprecate this method pub fn choose_hash_method( block: &DataBlock, column_names: &[String], @@ -68,6 +70,44 @@ impl DataBlock { } } + pub fn choose_hash_method_with_types( + hash_key_types: &[DataTypeImpl], + ) -> Result { + if hash_key_types.len() == 1 { + let typ = &hash_key_types[0]; + if typ.data_type_id() == TypeID::String { + return Ok(HashMethodKind::SingleString( + HashMethodSingleString::default(), + )); + } + } + + let mut group_key_len = 0; + for typ in hash_key_types { + let typ = remove_nullable(typ); + + if typ.data_type_id().is_numeric() || typ.data_type_id().is_date_or_date_time() { + group_key_len += typ.data_type_id().numeric_byte_size()?; + + //extra one byte for null flag + if typ.is_nullable() { + group_key_len += 1; + } + } else { + return Ok(HashMethodKind::Serializer(HashMethodSerializer::default())); + } + } + + match group_key_len { + 1 => Ok(HashMethodKind::KeysU8(HashMethodKeysU8::default())), + 2 => Ok(HashMethodKind::KeysU16(HashMethodKeysU16::default())), + 3..=4 => Ok(HashMethodKind::KeysU32(HashMethodKeysU32::default())), + 5..=8 => Ok(HashMethodKind::KeysU64(HashMethodKeysU64::default())), + // TODO support u128, u256 + _ => Ok(HashMethodKind::Serializer(HashMethodSerializer::default())), + } + } + pub fn group_by_blocks(block: &DataBlock, column_names: &[String]) -> Result> { let method = Self::choose_hash_method(block, column_names)?; Ok(match method { diff --git a/query/Cargo.toml b/query/Cargo.toml index 7fe49282a9fd..283626410c99 100644 --- a/query/Cargo.toml +++ b/query/Cargo.toml @@ -111,6 +111,7 @@ time = "0.3.9" tokio-rustls = "0.23.3" tokio-stream = { version = "0.1.8", features = ["net"] } tonic = "=0.6.2" +twox-hash = "1.6.2" typetag = "0.1.8" uuid = { version = "0.8.2", features = ["serde", "v4"] } walkdir = "2.3.2" diff --git a/query/src/interpreters/interpreter_select_v2.rs b/query/src/interpreters/interpreter_select_v2.rs index c000bfe50c10..2826c109e331 100644 --- a/query/src/interpreters/interpreter_select_v2.rs +++ b/query/src/interpreters/interpreter_select_v2.rs @@ -21,6 +21,7 @@ use common_tracing::tracing; use crate::interpreters::stream::ProcessorExecutorStream; use crate::interpreters::Interpreter; use crate::interpreters::InterpreterPtr; +use crate::pipelines::new::executor::PipelineExecutor; use crate::pipelines::new::executor::PipelinePullingExecutor; use crate::sessions::QueryContext; use crate::sql::Planner; @@ -52,9 +53,17 @@ impl Interpreter for SelectInterpreterV2 { _input_stream: Option, ) -> Result { let mut planner = Planner::new(self.ctx.clone()); - let pipeline = planner.plan_sql(self.query.as_str()).await?; + let (root_pipeline, pipelines) = planner.plan_sql(self.query.as_str()).await?; let async_runtime = self.ctx.get_storage_runtime(); - let executor = PipelinePullingExecutor::try_create(async_runtime, pipeline)?; + + // Spawn sub-pipelines + for pipeline in pipelines { + let executor = PipelineExecutor::create(async_runtime.clone(), pipeline)?; + executor.execute()?; + } + + // Spawn root pipeline + let executor = PipelinePullingExecutor::try_create(async_runtime, root_pipeline)?; let executor_stream = Box::pin(ProcessorExecutorStream::create(executor)?); Ok(Box::pin(self.ctx.try_create_abortable(executor_stream)?)) } diff --git a/query/src/pipelines/new/processors/mod.rs b/query/src/pipelines/new/processors/mod.rs index 9c76bfaf69bb..565db8ac0c64 100644 --- a/query/src/pipelines/new/processors/mod.rs +++ b/query/src/pipelines/new/processors/mod.rs @@ -47,8 +47,11 @@ pub use sources::SyncSourcer; pub use transforms::AggregatorParams; pub use transforms::AggregatorTransformParams; pub use transforms::BlockCompactor; +pub use transforms::ChainHashTable; pub use transforms::ExpressionTransform; +pub use transforms::HashJoinState; pub use transforms::ProjectionTransform; +pub use transforms::SinkBuildHashTable; pub use transforms::SortMergeCompactor; pub use transforms::SubQueriesPuller; pub use transforms::TransformAddOn; @@ -59,6 +62,7 @@ pub use transforms::TransformCompact; pub use transforms::TransformCreateSets; pub use transforms::TransformDummy; pub use transforms::TransformFilter; +pub use transforms::TransformHashJoinProbe; pub use transforms::TransformHaving; pub use transforms::TransformLimit; pub use transforms::TransformLimitBy; diff --git a/query/src/pipelines/new/processors/transforms/hash_join/hash.rs b/query/src/pipelines/new/processors/transforms/hash_join/hash.rs new file mode 100644 index 000000000000..ae9eeb967003 --- /dev/null +++ b/query/src/pipelines/new/processors/transforms/hash_join/hash.rs @@ -0,0 +1,59 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::hash::Hasher; + +use common_datavalues::ColumnRef; +use common_datavalues::ColumnWithField; +use common_datavalues::DataField; +use common_datavalues::PrimitiveColumn; +use common_datavalues::Series; +use common_exception::Result; +use common_functions::scalars::FunctionContext; +use common_functions::scalars::FunctionFactory; +use twox_hash::XxHash64; + +pub type HashVector = Vec; + +pub struct HashUtil; + +impl HashUtil { + pub fn compute_hash(column: &ColumnRef) -> Result { + let hash_function = FunctionFactory::instance().get("xxhash64", &[&column.data_type()])?; + let field = DataField::new("", column.data_type()); + let result = hash_function.eval( + FunctionContext::default(), + &[ColumnWithField::new(column.clone(), field)], + column.len(), + )?; + + let result = Series::remove_nullable(&result); + let result = Series::check_get::>(&result)?; + Ok(result.values().to_vec()) + } + + pub fn combine_hashes(inputs: &[HashVector], size: usize) -> HashVector { + static XXHASH_SEED: u64 = 0; + + let mut result = Vec::with_capacity(size); + result.resize(size, XxHash64::with_seed(XXHASH_SEED)); + for input in inputs.iter() { + assert_eq!(input.len(), size); + for i in 0..size { + result[i].write_u64(input[i]); + } + } + result.into_iter().map(|h| h.finish()).collect() + } +} diff --git a/query/src/pipelines/new/processors/transforms/hash_join/hash_table.rs b/query/src/pipelines/new/processors/transforms/hash_join/hash_table.rs new file mode 100644 index 000000000000..cf6632d86633 --- /dev/null +++ b/query/src/pipelines/new/processors/transforms/hash_join/hash_table.rs @@ -0,0 +1,255 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::RwLock; + +use common_datablocks::DataBlock; +use common_datavalues::ColumnRef; +use common_datavalues::DataSchemaRef; +use common_exception::Result; +use common_planners::Expression; + +use crate::common::ExpressionEvaluator; +use crate::pipelines::new::processors::transforms::hash_join::hash::HashUtil; +use crate::pipelines::new::processors::transforms::hash_join::hash::HashVector; +use crate::pipelines::new::processors::transforms::hash_join::row::compare_and_combine; +use crate::pipelines::new::processors::transforms::hash_join::row::RowPtr; +use crate::pipelines::new::processors::transforms::hash_join::row::RowSpace; +use crate::sessions::QueryContext; + +/// Concurrent hash table for hash join. +pub trait HashJoinState: Send + Sync { + /// Build hash table with input DataBlock + fn build(&self, input: DataBlock) -> Result<()>; + + /// Probe the hash table and retrieve matched rows as DataBlocks + fn probe(&self, input: &DataBlock) -> Result>; + + /// Attach to state + fn attach(&self) -> Result<()>; + + /// Detach to state + fn detach(&self) -> Result<()>; + + /// Is building finished. + fn is_finished(&self) -> Result; + + /// Finish building hash table, will be called only once as soon as all handles + /// have been detached from current state. + fn finish(&self) -> Result<()>; +} + +pub struct ChainHashTable { + /// Reference count + ref_count: Mutex, + is_finished: Mutex, + + build_expressions: Vec, + probe_expressions: Vec, + + ctx: Arc, + + /// A shared big hash table stores all the rows from build side + hash_table: RwLock>>, + row_space: RowSpace, +} + +impl ChainHashTable { + pub fn try_create( + build_expressions: Vec, + probe_expressions: Vec, + build_data_schema: DataSchemaRef, + _probe_data_schema: DataSchemaRef, + ctx: Arc, + ) -> Result { + Ok(Self { + row_space: RowSpace::new(build_data_schema), + ref_count: Mutex::new(0), + is_finished: Mutex::new(false), + build_expressions, + probe_expressions, + ctx, + hash_table: RwLock::new(vec![]), + }) + } + + fn get_matched_ptrs(&self, hash_key: u64) -> Vec { + let hash_table = self.hash_table.read().unwrap(); + let mut ptr: Option = hash_table[hash_key as usize]; + let mut result: Vec = vec![]; + + while let Some(v) = ptr { + result.push(v); + ptr = self.row_space.get_next(&v); + } + + result + } + + fn hash(&self, columns: &[ColumnRef], row_count: usize) -> Result { + let hash_values = columns + .iter() + .map(|col| HashUtil::compute_hash(col)) + .collect::>>()?; + Ok(HashUtil::combine_hashes(&hash_values, row_count)) + } + + fn apply_capacity(hash_vector: &HashVector, capacity: usize) -> HashVector { + // TODO: implement in a more efficient way + let mut result = HashVector::with_capacity(capacity); + for hash in hash_vector { + result.push(*hash % (capacity as u64)); + } + result + } +} + +impl HashJoinState for ChainHashTable { + fn build(&self, input: DataBlock) -> Result<()> { + let build_keys = self + .build_expressions + .iter() + .map(|expr| { + ExpressionEvaluator::eval(self.ctx.try_get_function_context()?, expr, &input) + }) + .collect::>>()?; + + let hash_values = self.hash(&build_keys, input.num_rows())?; + + self.row_space.push(input, hash_values)?; + + Ok(()) + } + + fn probe(&self, input: &DataBlock) -> Result> { + let probe_keys = self + .probe_expressions + .iter() + .map(|expr| { + ExpressionEvaluator::eval(self.ctx.try_get_function_context()?, expr, input) + }) + .collect::>>()?; + + let hash_values = self.hash(&probe_keys, input.num_rows())?; + let hash_values = + ChainHashTable::apply_capacity(&hash_values, self.hash_table.read().unwrap().len()); + + let mut results: Vec = vec![]; + for (i, hash_value) in hash_values.iter().enumerate().take(input.num_rows()) { + let probe_result_ptrs = self.get_matched_ptrs(*hash_value); + if probe_result_ptrs.is_empty() { + // No matched row for current probe row + continue; + } + let result_block = self.row_space.gather(&probe_result_ptrs)?; + + let probe_block = DataBlock::block_take_by_indices(input, &[i as u32])?; + let mut replicated_probe_block = DataBlock::empty(); + for (i, col) in probe_block.columns().iter().enumerate() { + let replicated_col = col.replicate(&[result_block.num_rows()]); + replicated_probe_block = replicated_probe_block + .add_column(replicated_col, probe_block.schema().field(i).clone())?; + } + + let build_keys = self + .build_expressions + .iter() + .map(|expr| { + ExpressionEvaluator::eval( + self.ctx.try_get_function_context()?, + expr, + &result_block, + ) + }) + .collect::>>()?; + + let probe_keys = self + .probe_expressions + .iter() + .map(|expr| { + ExpressionEvaluator::eval( + self.ctx.try_get_function_context()?, + expr, + &replicated_probe_block, + ) + }) + .collect::>>()?; + + let output = compare_and_combine( + replicated_probe_block, + result_block, + &build_keys, + &probe_keys, + self.ctx.clone(), + )?; + results.push(output); + } + + Ok(results) + } + + fn attach(&self) -> Result<()> { + let mut count = self.ref_count.lock().unwrap(); + *count += 1; + Ok(()) + } + + fn detach(&self) -> Result<()> { + let mut count = self.ref_count.lock().unwrap(); + *count -= 1; + if *count == 0 { + self.finish()?; + let mut is_finished = self.is_finished.lock().unwrap(); + *is_finished = true; + Ok(()) + } else { + Ok(()) + } + } + + fn is_finished(&self) -> Result { + Ok(*self.is_finished.lock().unwrap()) + } + + fn finish(&self) -> Result<()> { + let mut hash_table = self.hash_table.write().unwrap(); + hash_table.resize(self.row_space.num_rows(), None); + + { + let mut chunks = self.row_space.chunks.write().unwrap(); + for chunk_index in 0..chunks.len() { + let chunk = &chunks[chunk_index]; + let hash_values = + ChainHashTable::apply_capacity(&chunk.hash_values, hash_table.len()); + for (row_index, hash_value) in hash_values.iter().enumerate().take(chunk.num_rows()) + { + let ptr = RowPtr { + chunk_index: chunk_index as u32, + row_index: row_index as u32, + }; + + if let Some(previous_ptr) = &hash_table[*hash_value as usize] { + chunks[ptr.chunk_index as usize].next_ptr[ptr.row_index as usize] = + Some(*previous_ptr); + } + hash_table[*hash_value as usize] = Some(ptr); + } + } + } + + Ok(()) + } +} diff --git a/query/src/pipelines/new/processors/transforms/hash_join/mod.rs b/query/src/pipelines/new/processors/transforms/hash_join/mod.rs new file mode 100644 index 000000000000..fe74090800cd --- /dev/null +++ b/query/src/pipelines/new/processors/transforms/hash_join/mod.rs @@ -0,0 +1,20 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod hash; +mod hash_table; +mod row; + +pub use hash_table::ChainHashTable; +pub use hash_table::HashJoinState; diff --git a/query/src/pipelines/new/processors/transforms/hash_join/row.rs b/query/src/pipelines/new/processors/transforms/hash_join/row.rs new file mode 100644 index 000000000000..55c0736ad40b --- /dev/null +++ b/query/src/pipelines/new/processors/transforms/hash_join/row.rs @@ -0,0 +1,179 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::sync::RwLock; + +use common_datablocks::DataBlock; +use common_datavalues::ColumnRef; +use common_datavalues::DataField; +use common_datavalues::DataSchema; +use common_datavalues::DataSchemaRef; +use common_exception::Result; +use common_planners::Expression; + +use crate::common::ExpressionEvaluator; +use crate::pipelines::new::processors::transforms::hash_join::hash::HashVector; +use crate::sessions::QueryContext; + +pub struct Chunk { + pub data_block: DataBlock, + pub hash_values: HashVector, + pub next_ptr: Vec>, +} + +impl Chunk { + pub fn num_rows(&self) -> usize { + self.data_block.num_rows() + } +} + +#[derive(Clone, Copy)] +pub struct RowPtr { + pub chunk_index: u32, + pub row_index: u32, +} + +pub struct RowSpace { + pub data_schema: DataSchemaRef, + pub chunks: RwLock>, +} + +impl RowSpace { + pub fn new(data_schema: DataSchemaRef) -> Self { + Self { + data_schema, + chunks: RwLock::new(vec![]), + } + } + + pub fn push(&self, data_block: DataBlock, hash_values: HashVector) -> Result<()> { + let row_count = data_block.num_rows(); + let chunk = Chunk { + data_block, + hash_values, + next_ptr: vec![None; row_count], + }; + + { + // Acquire write lock in current scope + let mut chunks = self.chunks.write().unwrap(); + chunks.push(chunk); + } + + Ok(()) + } + + pub fn num_rows(&self) -> usize { + self.chunks + .read() + .unwrap() + .iter() + .fold(0, |acc, v| acc + v.num_rows()) + } + + pub fn get_next(&self, row_ptr: &RowPtr) -> Option { + self.chunks.read().unwrap()[row_ptr.chunk_index as usize].next_ptr + [row_ptr.row_index as usize] + } + + pub fn gather(&self, row_ptrs: &[RowPtr]) -> Result { + let _row_count = row_ptrs.len(); + let mut data_blocks = vec![]; + + { + // Acquire read lock in current scope + let chunks = self.chunks.read().unwrap(); + for row_ptr in row_ptrs.iter() { + assert!((row_ptr.chunk_index as usize) < chunks.len()); + let block = self.gather_single_chunk(&chunks[row_ptr.chunk_index as usize], &[ + row_ptr.row_index, + ])?; + if !block.is_empty() { + data_blocks.push(block); + } + } + } + + if !data_blocks.is_empty() { + let data_block = DataBlock::concat_blocks(&data_blocks)?; + Ok(data_block) + } else { + Ok(DataBlock::empty_with_schema(self.data_schema.clone())) + } + } + + fn gather_single_chunk(&self, chunk: &Chunk, indices: &[u32]) -> Result { + DataBlock::block_take_by_indices(&chunk.data_block, indices) + } +} + +pub fn compare_and_combine( + probe_input: DataBlock, + probe_result: DataBlock, + build_keys: &[ColumnRef], + probe_keys: &[ColumnRef], + ctx: Arc, +) -> Result { + assert_eq!(build_keys.len(), probe_keys.len()); + let mut compare_exprs: Vec = Vec::with_capacity(build_keys.len()); + let mut data_fields: Vec = Vec::with_capacity(build_keys.len() + probe_keys.len()); + let mut columns: Vec = Vec::with_capacity(build_keys.len() + probe_keys.len()); + for (idx, (build_key, probe_key)) in build_keys.iter().zip(probe_keys).enumerate() { + let build_key_name = format!("build_key_{idx}"); + let probe_key_name = format!("probe_key_{idx}"); + + let build_key_data_type = build_key.data_type(); + let probe_key_data_type = probe_key.data_type(); + + columns.push(build_key.clone()); + columns.push(probe_key.clone()); + + data_fields.push(DataField::new(build_key_name.as_str(), build_key_data_type)); + data_fields.push(DataField::new(probe_key_name.as_str(), probe_key_data_type)); + + let compare_expr = Expression::BinaryExpression { + left: Box::new(Expression::Column(build_key_name.clone())), + right: Box::new(Expression::Column(probe_key_name.clone())), + op: "=".to_string(), + }; + compare_exprs.push(compare_expr); + } + + let predicate = compare_exprs + .into_iter() + .reduce(|prev, next| Expression::BinaryExpression { + left: Box::new(prev), + op: "and".to_string(), + right: Box::new(next), + }) + .unwrap(); + + let data_block = DataBlock::create(Arc::new(DataSchema::new(data_fields)), columns); + + let filter = + ExpressionEvaluator::eval(ctx.try_get_function_context()?, &predicate, &data_block)?; + + let mut produce_block = probe_input; + for (col, field) in probe_result + .columns() + .iter() + .zip(probe_result.schema().fields().iter()) + { + produce_block = produce_block.add_column(col.clone(), field.clone())?; + } + produce_block = DataBlock::filter_block(&produce_block, &filter)?; + + Ok(produce_block) +} diff --git a/query/src/pipelines/new/processors/transforms/mod.rs b/query/src/pipelines/new/processors/transforms/mod.rs index 08c2b33ba756..035708e242f2 100644 --- a/query/src/pipelines/new/processors/transforms/mod.rs +++ b/query/src/pipelines/new/processors/transforms/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. mod aggregator; +mod hash_join; mod transform; mod transform_addon; mod transform_aggregator; @@ -23,6 +24,7 @@ mod transform_create_sets; mod transform_dummy; mod transform_expression; mod transform_filter; +mod transform_hash_join; mod transform_limit; mod transform_limit_by; mod transform_sort_merge; @@ -30,6 +32,8 @@ mod transform_sort_partial; pub use aggregator::AggregatorParams; pub use aggregator::AggregatorTransformParams; +pub use hash_join::ChainHashTable; +pub use hash_join::HashJoinState; pub use transform_addon::TransformAddOn; pub use transform_aggregator::TransformAggregator; pub use transform_block_compact::BlockCompactor; @@ -44,6 +48,8 @@ pub use transform_expression::ExpressionTransform; pub use transform_expression::ProjectionTransform; pub use transform_filter::TransformFilter; pub use transform_filter::TransformHaving; +pub use transform_hash_join::SinkBuildHashTable; +pub use transform_hash_join::TransformHashJoinProbe; pub use transform_limit::TransformLimit; pub use transform_limit_by::TransformLimitBy; pub use transform_sort_merge::SortMergeCompactor; diff --git a/query/src/pipelines/new/processors/transforms/transform_hash_join.rs b/query/src/pipelines/new/processors/transforms/transform_hash_join.rs new file mode 100644 index 000000000000..16798f8922a9 --- /dev/null +++ b/query/src/pipelines/new/processors/transforms/transform_hash_join.rs @@ -0,0 +1,151 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_datablocks::DataBlock; +use common_datavalues::DataSchemaRef; +use common_exception::Result; + +use crate::pipelines::new::processors::port::InputPort; +use crate::pipelines::new::processors::port::OutputPort; +use crate::pipelines::new::processors::processor::Event; +use crate::pipelines::new::processors::processor::ProcessorPtr; +use crate::pipelines::new::processors::transforms::hash_join::HashJoinState; +use crate::pipelines::new::processors::Processor; +use crate::pipelines::new::processors::Sink; +use crate::sessions::QueryContext; + +pub struct SinkBuildHashTable { + join_state: Arc, +} + +impl SinkBuildHashTable { + pub fn try_create(join_state: Arc) -> Result { + join_state.attach()?; + Ok(Self { join_state }) + } +} + +impl Sink for SinkBuildHashTable { + const NAME: &'static str = "BuildHashTable"; + + fn on_finish(&mut self) -> Result<()> { + self.join_state.detach() + } + + fn consume(&mut self, data_block: DataBlock) -> Result<()> { + self.join_state.build(data_block) + } +} + +enum HashJoinStep { + Build, + Probe, +} + +pub struct TransformHashJoinProbe { + input_data: Option, + output_data_blocks: Vec, + + input_port: Arc, + output_port: Arc, + step: HashJoinStep, + join_state: Arc, +} + +impl TransformHashJoinProbe { + pub fn create( + _ctx: Arc, + input_port: Arc, + output_port: Arc, + join_state: Arc, + _output_schema: DataSchemaRef, + ) -> ProcessorPtr { + ProcessorPtr::create(Box::new(TransformHashJoinProbe { + input_data: None, + output_data_blocks: vec![], + input_port, + output_port, + step: HashJoinStep::Build, + join_state, + })) + } + + fn probe(&mut self, block: &DataBlock) -> Result<()> { + self.output_data_blocks + .append(&mut self.join_state.probe(block)?); + Ok(()) + } +} + +impl Processor for TransformHashJoinProbe { + fn name(&self) -> &'static str { + static NAME: &str = "TransformHashJoin"; + NAME + } + + fn event(&mut self) -> Result { + match self.step { + HashJoinStep::Build => { + if self.join_state.is_finished()? { + self.step = HashJoinStep::Probe; + Ok(Event::Sync) + } else { + Ok(Event::NeedData) + } + } + HashJoinStep::Probe => { + if self.output_port.is_finished() { + self.input_port.finish(); + return Ok(Event::Finished); + } + + if !self.output_port.can_push() { + self.input_port.set_not_need_data(); + return Ok(Event::NeedConsume); + } + + if !self.output_data_blocks.is_empty() { + self.output_port + .push_data(Ok(self.output_data_blocks.remove(0))); + } + + if self.input_data.is_some() { + return Ok(Event::Sync); + } + + if let Some(data) = self.input_port.pull_data() { + self.input_data = Some(data?); + return Ok(Event::Sync); + } + + self.input_port.set_need_data(); + Ok(Event::NeedData) + } + } + } + + fn process(&mut self) -> Result<()> { + match self.step { + HashJoinStep::Build => Ok(()), + HashJoinStep::Probe => { + if let Some(data) = self.input_data.take() { + self.probe(&data)?; + } + Ok(()) + } + } + } +} diff --git a/query/src/servers/mysql/mysql_interactive_worker.rs b/query/src/servers/mysql/mysql_interactive_worker.rs index 4ce4af49c760..53079372d0af 100644 --- a/query/src/servers/mysql/mysql_interactive_worker.rs +++ b/query/src/servers/mysql/mysql_interactive_worker.rs @@ -22,7 +22,6 @@ use common_exception::ErrorCode; use common_exception::Result; use common_exception::ToErrorCode; use common_io::prelude::*; -use common_planners::PlanNode; use common_tracing::tracing; use common_tracing::tracing::Instrument; use metrics::histogram; @@ -45,6 +44,8 @@ use crate::servers::mysql::MySQLFederated; use crate::servers::mysql::MYSQL_VERSION; use crate::sessions::QueryContext; use crate::sessions::SessionRef; +use crate::sql::DfParser; +use crate::sql::DfStatement; use crate::sql::PlanParser; use crate::users::CertifiedInfo; @@ -282,40 +283,42 @@ impl InteractiveWorkerBase { let context = self.session.create_query_context().await?; context.attach_query_str(query); - let (plan, hints) = PlanParser::parse_with_hint(query, context.clone()).await; - if let (Some(hint_error_code), Err(error_code)) = ( - hints - .iter() - .find(|v| v.error_code.is_some()) - .and_then(|x| x.error_code), - &plan, - ) { - // Pre-check if parsing error can be ignored - if hint_error_code == error_code.code() { - return Ok((vec![DataBlock::empty()], String::from(""))); - } - } - - let plan = match plan { - Ok(p) => p, - Err(e) => { - InterpreterQueryLog::fail_to_start(context, e.clone()).await; - return Err(e); - } - }; - tracing::debug!("Get logic plan:\n{:?}", plan); - let settings = context.get_settings(); + let (stmts, hints) = + DfParser::parse_sql(query, context.get_current_session().get_type())?; + let interpreter: Arc = if settings.get_enable_new_processor_framework()? != 0 && context.get_cluster().is_empty() && settings.get_enable_planner_v2()? != 0 - && matches!(plan, PlanNode::Select(..)) + && matches!(stmts.get(0), Some(DfStatement::Query(_))) { // New planner is enabled, and the statement is ensured to be `SELECT` statement. SelectInterpreterV2::try_create(context.clone(), query)? } else { + let (plan, _) = PlanParser::parse_with_hint(query, context.clone()).await; + if let (Some(hint_error_code), Err(error_code)) = ( + hints + .iter() + .find(|v| v.error_code.is_some()) + .and_then(|x| x.error_code), + &plan, + ) { + // Pre-check if parsing error can be ignored + if hint_error_code == error_code.code() { + return Ok((vec![DataBlock::empty()], String::from(""))); + } + } + + let plan = match plan { + Ok(p) => p, + Err(e) => { + InterpreterQueryLog::fail_to_start(context, e.clone()).await; + return Err(e); + } + }; + tracing::debug!("Get logic plan:\n{:?}", plan); InterpreterFactory::get(context.clone(), plan)? }; diff --git a/query/src/sql/exec/data_schema_builder.rs b/query/src/sql/exec/data_schema_builder.rs index 5a28e45b2ad9..b0a5fca5d0e0 100644 --- a/query/src/sql/exec/data_schema_builder.rs +++ b/query/src/sql/exec/data_schema_builder.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use common_datavalues::DataField; -use common_datavalues::DataSchema; use common_datavalues::DataSchemaRef; +use common_datavalues::DataSchemaRefExt; use common_datavalues::DataTypeImpl; use common_exception::ErrorCode; use common_exception::Result; @@ -55,7 +53,7 @@ impl<'a> DataSchemaBuilder<'a> { }; new_data_fields.push(field); } - Ok(Arc::new(DataSchema::new(new_data_fields))) + Ok(DataSchemaRefExt::create(new_data_fields)) } pub fn build_project( @@ -72,7 +70,7 @@ impl<'a> DataSchemaBuilder<'a> { fields.push(field); } - Ok(Arc::new(DataSchema::new(fields))) + Ok(DataSchemaRefExt::create(fields)) } pub fn build_physical_scan(&self, plan: &PhysicalScan) -> Result { @@ -89,7 +87,7 @@ impl<'a> DataSchemaBuilder<'a> { fields.push(field); } - Ok(Arc::new(DataSchema::new(fields))) + Ok(DataSchemaRefExt::create(fields)) } pub fn build_canonical_schema(&self, columns: &[IndexType]) -> DataSchemaRef { @@ -106,7 +104,7 @@ impl<'a> DataSchemaBuilder<'a> { fields.push(field); } - Arc::new(DataSchema::new(fields)) + DataSchemaRefExt::create(fields) } pub fn build_group_by( @@ -133,7 +131,7 @@ impl<'a> DataSchemaBuilder<'a> { }; fields.push(field); } - Ok(Arc::new(DataSchema::new(fields))) + Ok(DataSchemaRefExt::create(fields)) } pub fn build_agg_func( @@ -174,10 +172,23 @@ impl<'a> DataSchemaBuilder<'a> { _ => { return Err(ErrorCode::LogicalError( "Expression must be aggregated function", - )) + )); } } } - Ok((Arc::new(DataSchema::new(fields)), agg_inner_expressions)) + Ok((DataSchemaRefExt::create(fields), agg_inner_expressions)) + } + + pub fn build_join(&self, left: DataSchemaRef, right: DataSchemaRef) -> DataSchemaRef { + // TODO: NATURAL JOIN and USING + let mut fields = Vec::with_capacity(left.num_fields() + right.num_fields()); + for field in left.fields().iter() { + fields.push(field.clone()); + } + for field in right.fields().iter() { + fields.push(field.clone()); + } + + DataSchemaRefExt::create(fields) } } diff --git a/query/src/sql/exec/mod.rs b/query/src/sql/exec/mod.rs index 5199e7a1bffd..d8276af53303 100644 --- a/query/src/sql/exec/mod.rs +++ b/query/src/sql/exec/mod.rs @@ -31,20 +31,29 @@ pub use util::decode_field_name; pub use util::format_field_name; use super::plans::BasePlan; +use crate::pipelines::new::processors::port::InputPort; use crate::pipelines::new::processors::AggregatorParams; use crate::pipelines::new::processors::AggregatorTransformParams; +use crate::pipelines::new::processors::ChainHashTable; use crate::pipelines::new::processors::ExpressionTransform; +use crate::pipelines::new::processors::HashJoinState; use crate::pipelines::new::processors::ProjectionTransform; +use crate::pipelines::new::processors::SinkBuildHashTable; +use crate::pipelines::new::processors::Sinker; use crate::pipelines::new::processors::TransformAggregator; use crate::pipelines::new::processors::TransformFilter; +use crate::pipelines::new::processors::TransformHashJoinProbe; use crate::pipelines::new::NewPipeline; +use crate::pipelines::new::SinkPipeBuilder; use crate::sessions::QueryContext; use crate::sql::exec::data_schema_builder::DataSchemaBuilder; use crate::sql::exec::expression_builder::ExpressionBuilder; use crate::sql::exec::util::check_physical; use crate::sql::optimizer::SExpr; use crate::sql::plans::AggregatePlan; +use crate::sql::plans::AndExpr; use crate::sql::plans::FilterPlan; +use crate::sql::plans::PhysicalHashJoin; use crate::sql::plans::PhysicalScan; use crate::sql::plans::PlanType; use crate::sql::plans::ProjectPlan; @@ -57,7 +66,8 @@ pub struct PipelineBuilder { metadata: Metadata, result_columns: Vec<(IndexType, String)>, expression: SExpr, - pipeline: NewPipeline, + + pipelines: Vec, } impl PipelineBuilder { @@ -72,21 +82,29 @@ impl PipelineBuilder { metadata, result_columns, expression, - pipeline: NewPipeline::create(), + + pipelines: vec![], } } - pub fn spawn(mut self) -> Result { + pub fn spawn(mut self) -> Result<(NewPipeline, Vec)> { let expr = self.expression.clone(); - let schema = self.build_pipeline(&expr)?; - self.align_data_schema(schema)?; + let mut pipeline = NewPipeline::create(); + let schema = self.build_pipeline(self.ctx.clone(), &expr, &mut pipeline)?; + self.align_data_schema(schema, &mut pipeline)?; let settings = self.ctx.get_settings(); - self.pipeline - .set_max_threads(settings.get_max_threads()? as usize); - Ok(self.pipeline) + pipeline.set_max_threads(settings.get_max_threads()? as usize); + for pipeline in self.pipelines.iter_mut() { + pipeline.set_max_threads(settings.get_max_threads()? as usize); + } + Ok((pipeline, self.pipelines)) } - fn align_data_schema(&mut self, input_schema: DataSchemaRef) -> Result<()> { + fn align_data_schema( + &mut self, + input_schema: DataSchemaRef, + pipeline: &mut NewPipeline, + ) -> Result<()> { let mut projections = Vec::with_capacity(self.result_columns.len()); let mut output_fields = Vec::with_capacity(self.result_columns.len()); for (index, name) in self.result_columns.iter() { @@ -104,21 +122,25 @@ impl PipelineBuilder { } let output_schema = Arc::new(DataSchema::new(output_fields)); - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - ProjectionTransform::try_create( - transform_input_port, - transform_output_port, - input_schema.clone(), - output_schema.clone(), - projections.clone(), - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + ProjectionTransform::try_create( + transform_input_port, + transform_output_port, + input_schema.clone(), + output_schema.clone(), + projections.clone(), + self.ctx.clone(), + ) + })?; Ok(()) } - fn build_pipeline(&mut self, expression: &SExpr) -> Result { + fn build_pipeline( + &mut self, + context: Arc, + expression: &SExpr, + pipeline: &mut NewPipeline, + ) -> Result { if !check_physical(expression) { return Err(ErrorCode::LogicalError("Invalid physical plan")); } @@ -128,22 +150,43 @@ impl PipelineBuilder { match plan.plan_type() { PlanType::PhysicalScan => { let physical_scan: PhysicalScan = plan.try_into()?; - self.build_physical_scan(&physical_scan) + self.build_physical_scan(&physical_scan, pipeline) } PlanType::Project => { let project: ProjectPlan = plan.try_into()?; - let input_schema = self.build_pipeline(&expression.children()[0])?; - self.build_project(&project, input_schema) + let input_schema = + self.build_pipeline(context, &expression.children()[0], pipeline)?; + self.build_project(&project, input_schema, pipeline) } PlanType::Filter => { let filter: FilterPlan = plan.try_into()?; - let input_schema = self.build_pipeline(&expression.children()[0])?; - self.build_filter(&filter, input_schema) + let input_schema = + self.build_pipeline(context, &expression.children()[0], pipeline)?; + self.build_filter(&filter, input_schema, pipeline) } PlanType::Aggregate => { let aggregate: AggregatePlan = plan.try_into()?; - let input_schema = self.build_pipeline(&expression.children()[0])?; - self.build_aggregate(&aggregate, input_schema) + let input_schema = + self.build_pipeline(context, &expression.children()[0], pipeline)?; + self.build_aggregate(&aggregate, input_schema, pipeline) + } + PlanType::PhysicalHashJoin => { + let hash_join: PhysicalHashJoin = plan.try_into()?; + let probe_schema = + self.build_pipeline(context.clone(), &expression.children()[0], pipeline)?; + let mut child_pipeline = NewPipeline::create(); + let build_schema = self.build_pipeline( + QueryContext::create_from(context), + &expression.children()[1], + &mut child_pipeline, + )?; + self.build_hash_join( + &hash_join, + build_schema, + probe_schema, + child_pipeline, + pipeline, + ) } _ => Err(ErrorCode::LogicalError("Invalid physical plan")), } @@ -153,6 +196,7 @@ impl PipelineBuilder { &mut self, project: &ProjectPlan, input_schema: DataSchemaRef, + pipeline: &mut NewPipeline, ) -> Result { let schema_builder = DataSchemaBuilder::new(&self.metadata); let output_schema = schema_builder.build_project(project, input_schema.clone())?; @@ -163,17 +207,16 @@ impl PipelineBuilder { let expression = expr_builder.build_and_rename(scalar, item.index)?; expressions.push(expression); } - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - ProjectionTransform::try_create( - transform_input_port, - transform_output_port, - input_schema.clone(), - output_schema.clone(), - expressions.clone(), - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + ProjectionTransform::try_create( + transform_input_port, + transform_output_port, + input_schema.clone(), + output_schema.clone(), + expressions.clone(), + self.ctx.clone(), + ) + })?; Ok(output_schema) } @@ -182,11 +225,19 @@ impl PipelineBuilder { &mut self, filter: &FilterPlan, input_schema: DataSchemaRef, + pipeline: &mut NewPipeline, ) -> Result { let output_schema = input_schema.clone(); let eb = ExpressionBuilder::create(&self.metadata); - let scalar = &filter.predicate; - let mut pred = eb.build(scalar)?; + let scalars = &filter.predicates; + let pred = scalars.iter().cloned().reduce(|acc, v| { + AndExpr { + left: Box::new(acc), + right: Box::new(v), + } + .into() + }); + let mut pred = eb.build(&pred.unwrap())?; let no_agg_expression = find_aggregate_exprs_in_expr(&pred).is_empty(); if !no_agg_expression && !filter.is_having { return Err(ErrorCode::SyntaxException( @@ -196,27 +247,29 @@ impl PipelineBuilder { if !no_agg_expression && filter.is_having { pred = eb.normalize_aggr_to_col(pred.clone())?; } - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - TransformFilter::try_create( - input_schema.clone(), - pred.clone(), - transform_input_port, - transform_output_port, - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + TransformFilter::try_create( + input_schema.clone(), + pred.clone(), + transform_input_port, + transform_output_port, + self.ctx.clone(), + ) + })?; Ok(output_schema) } - fn build_physical_scan(&mut self, scan: &PhysicalScan) -> Result { + fn build_physical_scan( + &mut self, + scan: &PhysicalScan, + pipeline: &mut NewPipeline, + ) -> Result { let table_entry = self.metadata.table(scan.table_index); let plan = table_entry.source.clone(); let table = self.ctx.build_table_from_source_plan(&plan)?; self.ctx.try_set_partitions(plan.parts.clone())?; - table.read2(self.ctx.clone(), &plan, &mut self.pipeline)?; - + table.read2(self.ctx.clone(), &plan, pipeline)?; let columns: Vec = scan.columns.iter().cloned().collect(); let projections: Vec = columns .iter() @@ -232,17 +285,16 @@ impl PipelineBuilder { let input_schema = schema_builder.build_canonical_schema(&columns); let output_schema = schema_builder.build_physical_scan(scan)?; - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - ProjectionTransform::try_create( - transform_input_port, - transform_output_port, - input_schema.clone(), - output_schema.clone(), - projections.clone(), - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + ProjectionTransform::try_create( + transform_input_port, + transform_output_port, + input_schema.clone(), + output_schema.clone(), + projections.clone(), + self.ctx.clone(), + ) + })?; Ok(output_schema) } @@ -251,6 +303,7 @@ impl PipelineBuilder { &mut self, aggregate: &AggregatePlan, input_schema: DataSchemaRef, + pipeline: &mut NewPipeline, ) -> Result { let mut agg_expressions = Vec::with_capacity(aggregate.agg_expr.len()); let expr_builder = ExpressionBuilder::create(&self.metadata); @@ -277,34 +330,32 @@ impl PipelineBuilder { let pre_input_schema = input_schema.clone(); let input_schema = schema_builder.build_group_by(input_schema, group_expressions.as_slice())?; - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - group_expressions.clone(), - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + group_expressions.clone(), + self.ctx.clone(), + ) + })?; // Process aggregation function with non-column expression, such as sum(3) let pre_input_schema = input_schema.clone(); let res = schema_builder.build_agg_func(pre_input_schema.clone(), agg_expressions.as_slice())?; let input_schema = res.0; - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - res.1.clone(), - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + res.1.clone(), + self.ctx.clone(), + ) + })?; // Get partial schema from agg_expressions let partial_data_fields = @@ -324,41 +375,111 @@ impl PipelineBuilder { &input_schema, &partial_schema, )?; - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - TransformAggregator::try_create_partial( - transform_input_port.clone(), - transform_output_port.clone(), - AggregatorTransformParams::try_create( - transform_input_port, - transform_output_port, - &partial_aggr_params, - )?, - self.ctx.clone(), - ) - })?; + pipeline.add_transform(|transform_input_port, transform_output_port| { + TransformAggregator::try_create_partial( + transform_input_port.clone(), + transform_output_port.clone(), + AggregatorTransformParams::try_create( + transform_input_port, + transform_output_port, + &partial_aggr_params, + )?, + self.ctx.clone(), + ) + })?; - self.pipeline.resize(1)?; + pipeline.resize(1)?; let final_aggr_params = AggregatorParams::try_create_v2( &agg_expressions, &group_expressions, &input_schema, &final_schema, )?; - self.pipeline - .add_transform(|transform_input_port, transform_output_port| { - TransformAggregator::try_create_final( - transform_input_port.clone(), - transform_output_port.clone(), - AggregatorTransformParams::try_create( - transform_input_port, - transform_output_port, - &final_aggr_params, - )?, - self.ctx.clone(), - ) - })?; + + pipeline.add_transform(|transform_input_port, transform_output_port| { + TransformAggregator::try_create_final( + transform_input_port.clone(), + transform_output_port.clone(), + AggregatorTransformParams::try_create( + transform_input_port, + transform_output_port, + &final_aggr_params, + )?, + self.ctx.clone(), + ) + })?; Ok(final_schema) } + + fn build_hash_join( + &mut self, + hash_join: &PhysicalHashJoin, + build_schema: DataSchemaRef, + probe_schema: DataSchemaRef, + mut child_pipeline: NewPipeline, + pipeline: &mut NewPipeline, + ) -> Result { + let builder = DataSchemaBuilder::new(&self.metadata); + let output_schema = builder.build_join(probe_schema.clone(), build_schema.clone()); + + let eb = ExpressionBuilder::create(&self.metadata); + let build_expressions = hash_join + .build_keys + .iter() + .map(|scalar| eb.build(scalar)) + .collect::>>()?; + let probe_expressions = hash_join + .probe_keys + .iter() + .map(|scalar| eb.build(scalar)) + .collect::>>()?; + + let hash_join_state = Arc::new(ChainHashTable::try_create( + build_expressions, + probe_expressions, + build_schema, + probe_schema, + self.ctx.clone(), + )?); + + // Build side + self.build_sink_hash_table(hash_join_state.clone(), &mut child_pipeline)?; + + // Probe side + pipeline.add_transform(|input, output| { + Ok(TransformHashJoinProbe::create( + self.ctx.clone(), + input, + output, + hash_join_state.clone(), + output_schema.clone(), + )) + })?; + + self.pipelines.push(child_pipeline); + + Ok(output_schema) + } + + fn build_sink_hash_table( + &mut self, + state: Arc, + pipeline: &mut NewPipeline, + ) -> Result<()> { + let mut sink_pipeline_builder = SinkPipeBuilder::create(); + for _ in 0..pipeline.output_len() { + let input_port = InputPort::create(); + sink_pipeline_builder.add_sink( + input_port.clone(), + Sinker::::create( + input_port, + SinkBuildHashTable::try_create(state.clone())?, + ), + ); + } + + pipeline.add_pipe(sink_pipeline_builder.finalize()); + Ok(()) + } } diff --git a/query/src/sql/optimizer/heuristic/implement.rs b/query/src/sql/optimizer/heuristic/implement.rs index c65783184e1b..6453fcad87e3 100644 --- a/query/src/sql/optimizer/heuristic/implement.rs +++ b/query/src/sql/optimizer/heuristic/implement.rs @@ -21,7 +21,8 @@ use crate::sql::optimizer::rule::TransformState; use crate::sql::optimizer::SExpr; lazy_static! { - static ref DEFAULT_IMPLEMENT_RULES: Vec = vec![RuleID::ImplementGet]; + static ref DEFAULT_IMPLEMENT_RULES: Vec = + vec![RuleID::ImplementGet, RuleID::ImplementHashJoin]; } pub struct HeuristicImplementor { diff --git a/query/src/sql/optimizer/rule/factory.rs b/query/src/sql/optimizer/rule/factory.rs index e886ba6a339b..b6bcc03d2928 100644 --- a/query/src/sql/optimizer/rule/factory.rs +++ b/query/src/sql/optimizer/rule/factory.rs @@ -15,6 +15,7 @@ use common_exception::Result; use crate::sql::optimizer::rule::rule_implement_get::RuleImplementGet; +use crate::sql::optimizer::rule::rule_implement_hash_join::RuleImplementHashJoin; use crate::sql::optimizer::rule::RuleID; use crate::sql::optimizer::rule::RulePtr; @@ -28,6 +29,7 @@ impl RuleFactory { pub fn create_rule(&self, id: RuleID) -> Result { match id { RuleID::ImplementGet => Ok(Box::new(RuleImplementGet::create())), + RuleID::ImplementHashJoin => Ok(Box::new(RuleImplementHashJoin::create())), } } } diff --git a/query/src/sql/optimizer/rule/mod.rs b/query/src/sql/optimizer/rule/mod.rs index 2ffe8afe0133..05dbb83b8105 100644 --- a/query/src/sql/optimizer/rule/mod.rs +++ b/query/src/sql/optimizer/rule/mod.rs @@ -18,6 +18,7 @@ use crate::sql::optimizer::SExpr; mod factory; mod rule_implement_get; +mod rule_implement_hash_join; mod rule_set; mod transform_state; @@ -38,12 +39,14 @@ pub trait Rule { #[derive(Copy, Clone, Eq, PartialEq, Hash)] pub enum RuleID { ImplementGet, + ImplementHashJoin, } impl RuleID { pub fn name(&self) -> &'static str { match self { RuleID::ImplementGet => "ImplementGet", + RuleID::ImplementHashJoin => "ImplementHashJoin", } } @@ -53,6 +56,7 @@ impl RuleID { pub fn uid(&self) -> u32 { match self { RuleID::ImplementGet => 0, + RuleID::ImplementHashJoin => 1, } } } diff --git a/query/src/sql/optimizer/rule/rule_implement_hash_join.rs b/query/src/sql/optimizer/rule/rule_implement_hash_join.rs new file mode 100644 index 000000000000..08cb784cf848 --- /dev/null +++ b/query/src/sql/optimizer/rule/rule_implement_hash_join.rs @@ -0,0 +1,83 @@ +// Copyright 2021 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_exception::Result; + +use crate::sql::optimizer::rule::transform_state::TransformState; +use crate::sql::optimizer::rule::Rule; +use crate::sql::optimizer::rule::RuleID; +use crate::sql::optimizer::SExpr; +use crate::sql::plans::LogicalInnerJoin; +use crate::sql::plans::PatternPlan; +use crate::sql::plans::PhysicalHashJoin; +use crate::sql::plans::PlanType; + +pub struct RuleImplementHashJoin { + id: RuleID, + pattern: SExpr, +} + +impl RuleImplementHashJoin { + pub fn create() -> Self { + RuleImplementHashJoin { + id: RuleID::ImplementHashJoin, + pattern: SExpr::create_binary( + PatternPlan { + plan_type: PlanType::LogicalInnerJoin, + } + .into(), + SExpr::create_leaf( + PatternPlan { + plan_type: PlanType::Pattern, + } + .into(), + ), + SExpr::create_leaf( + PatternPlan { + plan_type: PlanType::Pattern, + } + .into(), + ), + ), + } + } +} + +impl Rule for RuleImplementHashJoin { + fn id(&self) -> RuleID { + self.id + } + + fn apply(&self, expression: &SExpr, state: &mut TransformState) -> Result<()> { + let plan = expression.plan(); + let logical_inner_join: LogicalInnerJoin = plan.try_into()?; + + let result = SExpr::create( + PhysicalHashJoin { + build_keys: logical_inner_join.right_conditions, + probe_keys: logical_inner_join.left_conditions, + } + .into(), + expression.children().to_vec(), + expression.original_group(), + ); + state.add_result(result); + + Ok(()) + } + + fn pattern(&self) -> &SExpr { + &self.pattern + } +} diff --git a/query/src/sql/planner/binder/bind_context.rs b/query/src/sql/planner/binder/bind_context.rs index b484c4768446..5b147d3976d1 100644 --- a/query/src/sql/planner/binder/bind_context.rs +++ b/query/src/sql/planner/binder/bind_context.rs @@ -53,7 +53,7 @@ pub struct BindContext { } impl BindContext { - pub fn create() -> Self { + pub fn new() -> Self { Self::default() } diff --git a/query/src/sql/planner/binder/join.rs b/query/src/sql/planner/binder/join.rs new file mode 100644 index 000000000000..071924384b96 --- /dev/null +++ b/query/src/sql/planner/binder/join.rs @@ -0,0 +1,249 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use async_recursion::async_recursion; +use common_ast::ast::Expr; +use common_ast::ast::Join; +use common_ast::ast::JoinCondition; +use common_ast::ast::JoinOperator; +use common_exception::ErrorCode; +use common_exception::Result; + +use crate::sql::binder::scalar_common::split_conjunctions; +use crate::sql::binder::scalar_common::split_equivalent_predicate; +use crate::sql::optimizer::ColumnSet; +use crate::sql::optimizer::SExpr; +use crate::sql::planner::binder::scalar::ScalarBinder; +use crate::sql::planner::binder::Binder; +use crate::sql::plans::FilterPlan; +use crate::sql::plans::LogicalInnerJoin; +use crate::sql::plans::Scalar; +use crate::sql::plans::ScalarExpr; +use crate::sql::BindContext; + +impl Binder { + #[async_recursion] + pub(super) async fn bind_join( + &mut self, + bind_context: &BindContext, + join: &Join, + ) -> Result { + let left_child = self.bind_table_reference(&join.left, bind_context).await?; + let right_child = self.bind_table_reference(&join.right, bind_context).await?; + + let mut bind_context = BindContext::new(); + for column in left_child.all_column_bindings() { + bind_context.add_column_binding(column.clone()); + } + for column in right_child.all_column_bindings() { + bind_context.add_column_binding(column.clone()); + } + + let mut left_join_conditions: Vec = vec![]; + let mut right_join_conditions: Vec = vec![]; + let mut other_conditions: Vec = vec![]; + let join_condition_resolver = + JoinConditionResolver::new(&left_child, &right_child, &bind_context, &join.condition); + join_condition_resolver.resolve( + &mut left_join_conditions, + &mut right_join_conditions, + &mut other_conditions, + )?; + + match &join.op { + JoinOperator::Inner => { + bind_context = self.bind_inner_join( + left_join_conditions, + right_join_conditions, + bind_context, + left_child.expression.unwrap(), + right_child.expression.unwrap(), + )?; + } + JoinOperator::LeftOuter => { + return Err(ErrorCode::UnImplement( + "Unsupported join type: LEFT OUTER JOIN", + )); + } + JoinOperator::RightOuter => { + return Err(ErrorCode::UnImplement( + "Unsupported join type: RIGHT OUTER JOIN", + )); + } + JoinOperator::FullOuter => { + return Err(ErrorCode::UnImplement( + "Unsupported join type: FULL OUTER JOIN", + )); + } + JoinOperator::CrossJoin => { + return Err(ErrorCode::UnImplement("Unsupported join type: CROSS JOIN")); + } + } + + if !other_conditions.is_empty() { + let filter_plan = FilterPlan { + predicates: other_conditions, + is_having: false, + }; + let new_expr = + SExpr::create_unary(filter_plan.into(), bind_context.expression.clone().unwrap()); + bind_context.expression = Some(new_expr); + } + + Ok(bind_context) + } + + fn bind_inner_join( + &mut self, + left_conditions: Vec, + right_conditions: Vec, + mut bind_context: BindContext, + left_child: SExpr, + right_child: SExpr, + ) -> Result { + let inner_join = LogicalInnerJoin { + left_conditions, + right_conditions, + }; + let expr = SExpr::create_binary(inner_join.into(), left_child, right_child); + bind_context.expression = Some(expr); + + Ok(bind_context) + } +} + +struct JoinConditionResolver<'a> { + left_context: &'a BindContext, + right_context: &'a BindContext, + join_context: &'a BindContext, + join_condition: &'a JoinCondition, +} + +impl<'a> JoinConditionResolver<'a> { + pub fn new( + left_context: &'a BindContext, + right_context: &'a BindContext, + join_context: &'a BindContext, + join_condition: &'a JoinCondition, + ) -> Self { + Self { + left_context, + right_context, + join_context, + join_condition, + } + } + + pub fn resolve( + &self, + left_join_conditions: &mut Vec, + right_join_conditions: &mut Vec, + other_join_conditions: &mut Vec, + ) -> Result<()> { + match &self.join_condition { + JoinCondition::On(cond) => { + self.resolve_on( + cond, + left_join_conditions, + right_join_conditions, + other_join_conditions, + )?; + } + JoinCondition::Using(_) => { + return Err(ErrorCode::UnImplement("USING clause is not supported yet. Please specify join condition with ON clause.")); + } + JoinCondition::Natural => { + return Err(ErrorCode::UnImplement("NATURAL JOIN is not supported yet. Please specify join condition with ON clause.")); + } + JoinCondition::None => { + return Err(ErrorCode::UnImplement("JOIN without condition is not supported yet. Please specify join condition with ON clause.")); + } + } + Ok(()) + } + + fn resolve_on( + &self, + condition: &Expr, + left_join_conditions: &mut Vec, + right_join_conditions: &mut Vec, + other_join_conditions: &mut Vec, + ) -> Result<()> { + let scalar_binder = ScalarBinder::new(self.join_context); + let (scalar, _) = scalar_binder.bind_expr(condition)?; + let conjunctions = split_conjunctions(&scalar); + + for expr in conjunctions.iter() { + self.resolve_predicate( + expr, + left_join_conditions, + right_join_conditions, + other_join_conditions, + )?; + } + Ok(()) + } + + fn resolve_predicate( + &self, + predicate: &Scalar, + left_join_conditions: &mut Vec, + right_join_conditions: &mut Vec, + other_join_conditions: &mut Vec, + ) -> Result<()> { + // Given two tables: t1(a, b), t2(a, b) + // A predicate can be regarded as an equi-predicate iff: + // + // - The predicate is literally an equivalence expression, e.g. `t1.a = t2.a` + // - Each side of `=` only contains columns from one table and the both sides are disjoint. + // For example, `t1.a + t1.b = t2.a` is a valid one while `t1.a + t2.a = t2.b` isn't. + // + // Only equi-predicate can be exploited by common join algorithms(e.g. sort-merge join, hash join). + // For the predicates that aren't equi-predicate, we will lift them as a `Filter` operator. + if let Some((left, right)) = split_equivalent_predicate(predicate) { + let left_used_columns = left.used_columns(); + let right_used_columns = right.used_columns(); + let left_columns: ColumnSet = self.left_context.all_column_bindings().iter().fold( + ColumnSet::new(), + |mut acc, v| { + acc.insert(v.index); + acc + }, + ); + let right_columns: ColumnSet = self.right_context.all_column_bindings().iter().fold( + ColumnSet::new(), + |mut acc, v| { + acc.insert(v.index); + acc + }, + ); + + // TODO(leiysky): bump types of left conditions and right conditions + if left_used_columns.is_subset(&left_columns) + && right_used_columns.is_subset(&right_columns) + { + left_join_conditions.push(left); + right_join_conditions.push(right); + } else if left_used_columns.is_subset(&right_columns) + && right_used_columns.is_subset(&left_columns) + { + left_join_conditions.push(right); + right_join_conditions.push(left); + } + } else { + other_join_conditions.push(predicate.clone()); + } + Ok(()) + } +} diff --git a/query/src/sql/planner/binder/mod.rs b/query/src/sql/planner/binder/mod.rs index b372dfb5d372..47f62cf2b85f 100644 --- a/query/src/sql/planner/binder/mod.rs +++ b/query/src/sql/planner/binder/mod.rs @@ -27,6 +27,7 @@ use crate::storages::Table; mod aggregate; mod bind_context; +mod join; mod project; mod scalar; mod scalar_common; @@ -56,7 +57,7 @@ impl Binder { } pub async fn bind<'a>(mut self, stmt: &Statement<'a>) -> Result { - let init_bind_context = BindContext::create(); + let init_bind_context = BindContext::new(); let bind_context = self.bind_statement(stmt, &init_bind_context).await?; Ok(BindResult::create(bind_context, self.metadata)) } diff --git a/query/src/sql/planner/binder/project.rs b/query/src/sql/planner/binder/project.rs index 11effbc2e838..e00af52e6ca4 100644 --- a/query/src/sql/planner/binder/project.rs +++ b/query/src/sql/planner/binder/project.rs @@ -73,7 +73,7 @@ impl Binder { select_list: &[SelectTarget], input_context: &BindContext, ) -> Result { - let mut output_context = BindContext::create(); + let mut output_context = BindContext::new(); output_context.expression = input_context.expression.clone(); for select_target in select_list { match select_target { diff --git a/query/src/sql/planner/binder/scalar_common.rs b/query/src/sql/planner/binder/scalar_common.rs index a9918bde0082..b04c4d8538b5 100644 --- a/query/src/sql/planner/binder/scalar_common.rs +++ b/query/src/sql/planner/binder/scalar_common.rs @@ -16,6 +16,9 @@ use common_exception::Result; use crate::sql::binder::scalar_visitor::Recursion; use crate::sql::binder::scalar_visitor::ScalarVisitor; +use crate::sql::plans::AndExpr; +use crate::sql::plans::ComparisonExpr; +use crate::sql::plans::ComparisonOp; use crate::sql::plans::Scalar; use crate::sql::BindContext; @@ -90,3 +93,25 @@ pub fn find_aggregate_scalars_from_bind_context(bind_context: &BindContext) -> R .collect::>(); Ok(find_aggregate_scalars(&scalars)) } + +pub fn split_conjunctions(scalar: &Scalar) -> Vec { + match scalar { + Scalar::AndExpr(AndExpr { left, right }) => { + vec![split_conjunctions(left), split_conjunctions(right)].concat() + } + _ => { + vec![scalar.clone()] + } + } +} + +pub fn split_equivalent_predicate(scalar: &Scalar) -> Option<(Scalar, Scalar)> { + match scalar { + Scalar::ComparisonExpr(ComparisonExpr { op, left, right }) + if *op == ComparisonOp::Equal => + { + Some((*left.clone(), *right.clone())) + } + _ => None, + } +} diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index cdd431690e3b..537c6d8acfb9 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -25,6 +25,7 @@ use common_exception::ErrorCode; use common_exception::Result; use common_planners::Expression; +use crate::sql::binder::scalar_common::split_conjunctions; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; use crate::sql::planner::binder::BindContext; @@ -72,7 +73,7 @@ impl Binder { let mut input_context = if let Some(from) = &stmt.from { self.bind_table_reference(from, bind_context).await? } else { - BindContext::create() + BindContext::new() }; if let Some(expr) = &stmt.selection { @@ -100,7 +101,7 @@ impl Binder { Ok(output_context) } - async fn bind_table_reference( + pub(super) async fn bind_table_reference( &mut self, stmt: &TableReference, bind_context: &BindContext, @@ -176,12 +177,13 @@ impl Binder { } Ok(result) } + TableReference::Join(join) => self.bind_join(bind_context, join).await, _ => Err(ErrorCode::UnImplement("Unsupported table reference type")), } } async fn bind_base_table(&mut self, table_index: IndexType) -> Result { - let mut bind_context = BindContext::create(); + let mut bind_context = BindContext::new(); let columns = self.metadata.columns_by_table_index(table_index); let table = self.metadata.table(table_index); for column in columns.iter() { @@ -214,7 +216,7 @@ impl Binder { let scalar_binder = ScalarBinder::new(bind_context); let (scalar, _) = scalar_binder.bind_expr(expr)?; let filter_plan = FilterPlan { - predicate: scalar, + predicates: split_conjunctions(&scalar), is_having, }; let new_expr = diff --git a/query/src/sql/planner/mod.rs b/query/src/sql/planner/mod.rs index 25be3e8a2a3e..127eda6380bd 100644 --- a/query/src/sql/planner/mod.rs +++ b/query/src/sql/planner/mod.rs @@ -47,7 +47,7 @@ impl Planner { Planner { ctx } } - pub async fn plan_sql<'a>(&mut self, sql: &'a str) -> Result { + pub async fn plan_sql<'a>(&mut self, sql: &'a str) -> Result<(NewPipeline, Vec)> { // Step 1: parse SQL text into AST let tokens = tokenize_sql(sql)?; let stmts = parse_sql(&tokens)?; @@ -71,8 +71,8 @@ impl Planner { bind_result.metadata, optimized_expr, ); - let pipeline = pb.spawn()?; + let pipelines = pb.spawn()?; - Ok(pipeline) + Ok(pipelines) } } diff --git a/query/src/sql/planner/plans/filter.rs b/query/src/sql/planner/plans/filter.rs index 19f24e416a37..a080f38402b0 100644 --- a/query/src/sql/planner/plans/filter.rs +++ b/query/src/sql/planner/plans/filter.rs @@ -25,8 +25,7 @@ use crate::sql::plans::Scalar; #[derive(Clone)] pub struct FilterPlan { - // TODO: split predicate into conjunctions - pub predicate: Scalar, + pub predicates: Vec, // True if the plan represents having, else the plan represents where pub is_having: bool, } diff --git a/query/src/sql/planner/plans/hash_join.rs b/query/src/sql/planner/plans/hash_join.rs new file mode 100644 index 000000000000..9ec54fb047e0 --- /dev/null +++ b/query/src/sql/planner/plans/hash_join.rs @@ -0,0 +1,61 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; + +use crate::sql::optimizer::PhysicalProperty; +use crate::sql::optimizer::SExpr; +use crate::sql::plans::BasePlan; +use crate::sql::plans::LogicalPlan; +use crate::sql::plans::PhysicalPlan; +use crate::sql::plans::PlanType; +use crate::sql::plans::Scalar; + +#[derive(Clone)] +pub struct PhysicalHashJoin { + pub build_keys: Vec, + pub probe_keys: Vec, +} + +impl BasePlan for PhysicalHashJoin { + fn plan_type(&self) -> PlanType { + PlanType::PhysicalHashJoin + } + + fn is_physical(&self) -> bool { + true + } + + fn is_logical(&self) -> bool { + false + } + + fn as_physical(&self) -> Option<&dyn PhysicalPlan> { + Some(self) + } + + fn as_logical(&self) -> Option<&dyn LogicalPlan> { + None + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl PhysicalPlan for PhysicalHashJoin { + fn compute_physical_prop(&self, _expression: &SExpr) -> PhysicalProperty { + todo!() + } +} diff --git a/query/src/sql/planner/plans/logical_join.rs b/query/src/sql/planner/plans/logical_join.rs new file mode 100644 index 000000000000..be106c3542b4 --- /dev/null +++ b/query/src/sql/planner/plans/logical_join.rs @@ -0,0 +1,61 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; + +use crate::sql::optimizer::RelationalProperty; +use crate::sql::optimizer::SExpr; +use crate::sql::plans::BasePlan; +use crate::sql::plans::LogicalPlan; +use crate::sql::plans::PhysicalPlan; +use crate::sql::plans::PlanType; +use crate::sql::plans::Scalar; + +#[derive(Clone)] +pub struct LogicalInnerJoin { + pub left_conditions: Vec, + pub right_conditions: Vec, +} + +impl BasePlan for LogicalInnerJoin { + fn plan_type(&self) -> PlanType { + PlanType::LogicalInnerJoin + } + + fn is_physical(&self) -> bool { + false + } + + fn is_logical(&self) -> bool { + true + } + + fn as_physical(&self) -> Option<&dyn PhysicalPlan> { + None + } + + fn as_logical(&self) -> Option<&dyn LogicalPlan> { + Some(self) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl LogicalPlan for LogicalInnerJoin { + fn compute_relational_prop(&self, _expression: &SExpr) -> RelationalProperty { + todo!() + } +} diff --git a/query/src/sql/planner/plans/mod.rs b/query/src/sql/planner/plans/mod.rs index fba715db9d78..4e95f27eefb1 100644 --- a/query/src/sql/planner/plans/mod.rs +++ b/query/src/sql/planner/plans/mod.rs @@ -14,7 +14,9 @@ mod aggregate; mod filter; +mod hash_join; mod logical_get; +mod logical_join; mod pattern; mod physical_scan; mod project; @@ -25,7 +27,9 @@ use std::any::Any; pub use aggregate::AggregatePlan; use enum_dispatch::enum_dispatch; pub use filter::FilterPlan; +pub use hash_join::PhysicalHashJoin; pub use logical_get::LogicalGet; +pub use logical_join::LogicalInnerJoin; pub use pattern::PatternPlan; pub use physical_scan::PhysicalScan; pub use project::ProjectItem; @@ -68,9 +72,11 @@ pub trait PhysicalPlan { pub enum PlanType { // Logical operators LogicalGet, + LogicalInnerJoin, // Physical operators PhysicalScan, + PhysicalHashJoin, // Operators that are both logical and physical Project, @@ -85,9 +91,14 @@ pub enum PlanType { #[derive(Clone)] pub enum BasePlanImpl { LogicalGet(LogicalGet), + LogicalInnerJoin(LogicalInnerJoin), + PhysicalScan(PhysicalScan), + PhysicalHashJoin(PhysicalHashJoin), + Project(ProjectPlan), Filter(FilterPlan), Aggregate(AggregatePlan), + Pattern(PatternPlan), } diff --git a/query/src/sql/planner/plans/scalar.rs b/query/src/sql/planner/plans/scalar.rs index 7ee5b281dce8..eba675b0e443 100644 --- a/query/src/sql/planner/plans/scalar.rs +++ b/query/src/sql/planner/plans/scalar.rs @@ -21,14 +21,14 @@ use common_exception::Result; use enum_dispatch::enum_dispatch; use crate::sql::binder::ColumnBinding; +use crate::sql::optimizer::ColumnSet; #[enum_dispatch] pub trait ScalarExpr { /// Get return type and nullability fn data_type(&self) -> DataTypeImpl; - // TODO: implement this in the future - // fn used_columns(&self) -> ColumnSet; + fn used_columns(&self) -> ColumnSet; // TODO: implement this in the future // fn outer_columns(&self) -> ColumnSet; @@ -60,6 +60,10 @@ impl ScalarExpr for BoundColumnRef { fn data_type(&self) -> DataTypeImpl { self.column.data_type.clone() } + + fn used_columns(&self) -> ColumnSet { + ColumnSet::from([self.column.index]) + } } #[derive(Clone, PartialEq, Debug)] @@ -71,6 +75,10 @@ impl ScalarExpr for ConstantExpr { fn data_type(&self) -> DataTypeImpl { self.value.data_type() } + + fn used_columns(&self) -> ColumnSet { + ColumnSet::new() + } } #[derive(Clone, PartialEq, Debug)] @@ -83,6 +91,12 @@ impl ScalarExpr for AndExpr { fn data_type(&self) -> DataTypeImpl { BooleanType::new_impl() } + + fn used_columns(&self) -> ColumnSet { + let left: ColumnSet = self.left.used_columns(); + let right: ColumnSet = self.right.used_columns(); + left.union(&right).cloned().collect() + } } #[derive(Clone, PartialEq, Debug)] @@ -95,6 +109,12 @@ impl ScalarExpr for OrExpr { fn data_type(&self) -> DataTypeImpl { BooleanType::new_impl() } + + fn used_columns(&self) -> ColumnSet { + let left: ColumnSet = self.left.used_columns(); + let right: ColumnSet = self.right.used_columns(); + left.union(&right).cloned().collect() + } } #[derive(Clone, PartialEq, Debug)] @@ -159,6 +179,12 @@ impl ScalarExpr for ComparisonExpr { fn data_type(&self) -> DataTypeImpl { BooleanType::new_impl() } + + fn used_columns(&self) -> ColumnSet { + let left: ColumnSet = self.left.used_columns(); + let right: ColumnSet = self.right.used_columns(); + left.union(&right).cloned().collect() + } } #[derive(Clone, PartialEq, Debug)] @@ -174,13 +200,20 @@ impl ScalarExpr for AggregateFunction { fn data_type(&self) -> DataTypeImpl { self.return_type.clone() } + + fn used_columns(&self) -> ColumnSet { + let mut result = ColumnSet::new(); + for scalar in self.args.iter() { + result = result.union(&scalar.used_columns()).cloned().collect(); + } + result + } } #[derive(Clone, PartialEq, Debug)] pub struct FunctionCall { pub arguments: Vec, - // pub function: Box, pub func_name: String, pub arg_types: Vec, pub return_type: DataTypeImpl, @@ -190,6 +223,14 @@ impl ScalarExpr for FunctionCall { fn data_type(&self) -> DataTypeImpl { self.return_type.clone() } + + fn used_columns(&self) -> ColumnSet { + let mut result = ColumnSet::new(); + for scalar in self.arguments.iter() { + result = result.union(&scalar.used_columns()).cloned().collect(); + } + result + } } #[derive(Clone, PartialEq, Debug)] @@ -203,4 +244,8 @@ impl ScalarExpr for CastExpr { fn data_type(&self) -> DataTypeImpl { self.target_type.clone() } + + fn used_columns(&self) -> ColumnSet { + self.argument.used_columns() + } } diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index 7136730a4f20..dde6d6e3fae2 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -45,3 +45,10 @@ 0 0 9 +1 1 +2 2 +3 3 +1 1 +2 2 +2 1 +3 2 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index dae9cac80467..c62be55e7f6a 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -43,4 +43,22 @@ select count(null) from numbers(1000); SELECT max(number) FROM numbers_mt (10) where number > 99999999998; SELECT max(number) FROM numbers_mt (10) where number > 2; -set enable_planner_v2 = 0; \ No newline at end of file + +-- Inner join +create table t(a int); +insert into t values(1),(2),(3); +create table t1(b float); +insert into t1 values(1.0),(2.0),(3.0); +create table t2(c uint32 null); +insert into t2 values(1),(2),(null); + +select * from t inner join t1 on cast(t.a as float) = t1.b; +select * from t inner join t2 on t.a = t2.c; +select * from t inner join t2 on t.a = t2.c + 1; +select * from t inner join t2 on t.a = t2.c + 1 and t.a - 1 = t2.c; + +drop table t; +drop table t1; +drop table t2; + +set enable_planner_v2 = 0;