diff --git a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp index f983de91b37..cd9d6235f52 100644 --- a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp +++ b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp @@ -24,7 +24,7 @@ namespace DB { ParallelAggregatingBlockInputStream::ParallelAggregatingBlockInputStream( const BlockInputStreams & inputs, - const BlockInputStreamPtr & additional_input_at_end, + const BlockInputStreams & additional_inputs_at_end, const Aggregator::Params & params_, const FileProviderPtr & file_provider_, bool final_, @@ -41,11 +41,10 @@ ParallelAggregatingBlockInputStream::ParallelAggregatingBlockInputStream( , keys_size(params.keys_size) , aggregates_size(params.aggregates_size) , handler(*this) - , processor(inputs, additional_input_at_end, max_threads, handler, log) + , processor(inputs, additional_inputs_at_end, max_threads, handler, log) { children = inputs; - if (additional_input_at_end) - children.push_back(additional_input_at_end); + children.insert(children.end(), additional_inputs_at_end.begin(), additional_inputs_at_end.end()); } diff --git a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.h b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.h index 41e61786370..907622c8364 100644 --- a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.h +++ b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.h @@ -36,7 +36,7 @@ class ParallelAggregatingBlockInputStream : public IProfilingBlockInputStream */ ParallelAggregatingBlockInputStream( const BlockInputStreams & inputs, - const BlockInputStreamPtr & additional_input_at_end, + const BlockInputStreams & additional_inputs_at_end, const Aggregator::Params & params_, const FileProviderPtr & file_provider_, bool final_, diff --git a/dbms/src/DataStreams/ParallelInputsProcessor.h b/dbms/src/DataStreams/ParallelInputsProcessor.h index 34c70a7085e..57ab37e1756 100644 --- a/dbms/src/DataStreams/ParallelInputsProcessor.h +++ b/dbms/src/DataStreams/ParallelInputsProcessor.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -83,9 +84,8 @@ template class ParallelInputsProcessor { public: - /** additional_input_at_end - if not nullptr, - * then the blocks from this source will start to be processed only after all other sources are processed. - * This is done in the main thread. + /** additional_inputs_at_end - if not empty, + * then the blocks from the sources will start to be processed only after all other sources are processed. 
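+     * They are drained by the same pool of worker threads, in parallel,
+     * after the regular inputs run out (the old single additional_input_at_end
+     * was handled by only the last remaining thread).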
* * Intended for implementation of FULL and RIGHT JOIN * - where you must first make JOIN in parallel, while noting which keys are not found, @@ -93,19 +93,18 @@ class ParallelInputsProcessor */ ParallelInputsProcessor( const BlockInputStreams & inputs_, - const BlockInputStreamPtr & additional_input_at_end_, + const BlockInputStreams & additional_inputs_at_end_, size_t max_threads_, Handler & handler_, const LoggerPtr & log_) : inputs(inputs_) - , additional_input_at_end(additional_input_at_end_) - , max_threads(std::min(inputs_.size(), max_threads_)) + , additional_inputs_at_end(additional_inputs_at_end_) + , max_threads(std::min(std::max(inputs_.size(), additional_inputs_at_end_.size()), max_threads_)) , handler(handler_) + , working_inputs(inputs_) + , working_additional_inputs(additional_inputs_at_end_) , log(log_) - { - for (size_t i = 0; i < inputs_.size(); ++i) - unprepared_inputs.emplace(inputs_[i], i); - } + {} ~ParallelInputsProcessor() { @@ -132,36 +131,21 @@ class ParallelInputsProcessor /// Ask all sources to stop earlier than they run out. void cancel(bool kill) { - finish = true; + working_inputs.available_inputs.cancel(); + working_additional_inputs.available_inputs.cancel(); - for (auto & input : inputs) - { - if (IProfilingBlockInputStream * child = dynamic_cast(&*input)) - { - try - { - child->cancel(kill); - } - catch (...) - { - /** If you can not ask one or more sources to stop. - * (for example, the connection is broken for distributed query processing) - * - then do not care. - */ - LOG_FMT_ERROR(log, "Exception while cancelling {}", child->getName()); - } - } - } + cancelStreams(inputs, kill); + cancelStreams(additional_inputs_at_end, kill); } /// Wait until all threads are finished, before the destructor. void wait() { - if (joined_threads) - return; if (thread_manager) + { thread_manager->wait(); - joined_threads = true; + thread_manager.reset(); + } } size_t getNumActiveThreads() const @@ -181,13 +165,78 @@ class ParallelInputsProcessor BlockInputStreamPtr in; size_t i; /// The source number (for debugging). - InputData() {} + InputData() + : i(0) + {} InputData(const BlockInputStreamPtr & in_, size_t i_) : in(in_) , i(i_) {} }; + struct WorkingInputs + { + explicit WorkingInputs(const BlockInputStreams & inputs_) + : available_inputs(inputs_.size()) + , active_inputs(inputs_.size()) + , unprepared_inputs(inputs_.size()) + { + for (size_t i = 0; i < inputs_.size(); ++i) + unprepared_inputs.emplace(inputs_[i], i); + } + /** A set of available sources that are not currently processed by any thread. + * Each thread takes one source from this set, takes a block out of the source (at this moment the source does the calculations) + * and (if the source is not run out), puts it back into the set of available sources. + * + * The question arises what is better to use: + * - the queue (just processed source will be processed the next time later than the rest) + * - stack (just processed source will be processed as soon as possible). + * + * The stack is better than the queue when you need to do work on reading one source more consequentially, + * and theoretically, this allows you to achieve more consequent/consistent reads from the disk. + * + * But when using the stack, there is a problem with distributed query processing: + * data is read only from a part of the servers, and on the other servers + * a timeout occurs during send, and the request processing ends with an exception. + * + * Therefore, a queue is used. This can be improved in the future. 
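+     *
+     * With MPMCQueue the hand-off needs no external mutex: pop() blocks
+     * until a source is available, and returns false once the queue has
+     * been finished (all inputs exhausted) or cancelled, so a consumer is
+     * simply
+     *
+     *     while (work.available_inputs.pop(input)) { ... read one block ... }
+     *
+     * and the thread that exhausts the last active input calls finish() to
+     * release every consumer still blocked in pop().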
+ */ + using AvailableInputs = MPMCQueue; + AvailableInputs available_inputs; + + /// How many active input streams. + std::atomic active_inputs; + + /** For parallel preparing (readPrefix) child streams. + * First, streams are located here. + * After a stream was prepared, it is moved to "available_inputs" for reading. + */ + using UnpreparedInputs = MPMCQueue; + UnpreparedInputs unprepared_inputs; + }; + + void cancelStreams(const BlockInputStreams & streams, bool kill) + { + for (const auto & input : streams) + { + if (auto * p_child = dynamic_cast(&*input)) + { + try + { + p_child->cancel(kill); + } + catch (...) + { + /** If you can not ask one or more sources to stop. + * (for example, the connection is broken for distributed query processing) + * - then do not care. + */ + LOG_FMT_ERROR(log, "Exception while cancelling {}", p_child->getName()); + } + } + } + } + void publishPayload(BlockInputStreamPtr & stream, Block & block, size_t thread_num) { if constexpr (mode == StreamUnionMode::Basic) @@ -201,32 +250,24 @@ class ParallelInputsProcessor void thread(size_t thread_num) { - std::exception_ptr exception; + work(thread_num, working_inputs); + work(thread_num, working_additional_inputs); - try - { - while (!finish) - { - InputData unprepared_input; - { - std::lock_guard lock(unprepared_inputs_mutex); - - if (unprepared_inputs.empty()) - break; - - unprepared_input = unprepared_inputs.front(); - unprepared_inputs.pop(); - } + handler.onFinishThread(thread_num); - unprepared_input.in->readPrefix(); + if (0 == --active_threads) + { + handler.onFinish(); + } + } - { - std::lock_guard lock(available_inputs_mutex); - available_inputs.push(unprepared_input); - } - } + void work(size_t thread_num, WorkingInputs & work) + { + std::exception_ptr exception; - loop(thread_num); + try + { + loop(thread_num, work); } catch (...) { @@ -237,134 +278,63 @@ class ParallelInputsProcessor { handler.onException(exception, thread_num); } - - handler.onFinishThread(thread_num); - - /// The last thread on the output indicates that there is no more data. - if (0 == --active_threads) - { - /// And then it processes an additional source, if there is one. - if (additional_input_at_end) - { - try - { - additional_input_at_end->readPrefix(); - while (Block block = additional_input_at_end->read()) - publishPayload(additional_input_at_end, block, thread_num); - } - catch (...) - { - exception = std::current_exception(); - } - - if (exception) - { - handler.onException(exception, thread_num); - } - } - - handler.onFinish(); /// TODO If in `onFinish` or `onFinishThread` there is an exception, then std::terminate is called. - } } - void loop(size_t thread_num) + /// This function may be called in different threads. + /// If no exception occurs, we can ensure that the work is all done when the function + /// returns in any thread. + void loop(size_t thread_num, WorkingInputs & work) { - while (!finish) /// You may need to stop work earlier than all sources run out. + if (work.active_inputs == 0) { - InputData input; + return; + } - /// Select the next source. - { - std::lock_guard lock(available_inputs_mutex); + InputData input; - /// If there are no free sources, then this thread is no longer needed. (But other threads can work with their sources.) - if (available_inputs.empty()) - break; - - input = available_inputs.front(); + while (work.unprepared_inputs.tryPop(input)) + { + input.in->readPrefix(); - /// We remove the source from the queue of available sources. 
- available_inputs.pop(); - } + work.available_inputs.push(input); + } + // The condition is false when all input streams are exhausted or + // an exception occurred then the queue was cancelled. + while (work.available_inputs.pop(input)) + { /// The main work. Block block = input.in->read(); + if (block) { - if (finish) - break; - - /// If this source is not run out yet, then put the resulting block in the ready queue. + work.available_inputs.push(input); + publishPayload(input.in, block, thread_num); + } + else + { + if (0 == --work.active_inputs) { - std::lock_guard lock(available_inputs_mutex); - - if (block) - { - available_inputs.push(input); - } - else - { - if (available_inputs.empty()) - break; - } - } - - if (finish) + work.available_inputs.finish(); break; - - if (block) - publishPayload(input.in, block, thread_num); + } } } } - BlockInputStreams inputs; - BlockInputStreamPtr additional_input_at_end; + const BlockInputStreams inputs; + const BlockInputStreams additional_inputs_at_end; unsigned max_threads; Handler & handler; std::shared_ptr thread_manager; - /** A set of available sources that are not currently processed by any thread. - * Each thread takes one source from this set, takes a block out of the source (at this moment the source does the calculations) - * and (if the source is not run out), puts it back into the set of available sources. - * - * The question arises what is better to use: - * - the queue (just processed source will be processed the next time later than the rest) - * - stack (just processed source will be processed as soon as possible). - * - * The stack is better than the queue when you need to do work on reading one source more consequentially, - * and theoretically, this allows you to achieve more consequent/consistent reads from the disk. - * - * But when using the stack, there is a problem with distributed query processing: - * data is read only from a part of the servers, and on the other servers - * a timeout occurs during send, and the request processing ends with an exception. - * - * Therefore, a queue is used. This can be improved in the future. - */ - using AvailableInputs = std::queue; - AvailableInputs available_inputs; - - /** For parallel preparing (readPrefix) child streams. - * First, streams are located here. - * After a stream was prepared, it is moved to "available_inputs" for reading. - */ - using UnpreparedInputs = std::queue; - UnpreparedInputs unprepared_inputs; - - /// For operations with available_inputs. - std::mutex available_inputs_mutex; - - /// For operations with unprepared_inputs. - std::mutex unprepared_inputs_mutex; + WorkingInputs working_inputs; + WorkingInputs working_additional_inputs; /// How many sources ran out. std::atomic active_threads{0}; - /// Finish the threads work (before the sources run out). - std::atomic finish{false}; - /// Wait for the completion of all threads. 
- std::atomic joined_threads{false}; const LoggerPtr log; }; diff --git a/dbms/src/DataStreams/UnionBlockInputStream.h b/dbms/src/DataStreams/UnionBlockInputStream.h index a782c3dd087..ffcc8d77c10 100644 --- a/dbms/src/DataStreams/UnionBlockInputStream.h +++ b/dbms/src/DataStreams/UnionBlockInputStream.h @@ -94,20 +94,19 @@ class UnionBlockInputStream final : public IProfilingBlockInputStream public: UnionBlockInputStream( BlockInputStreams inputs, - BlockInputStreamPtr additional_input_at_end, + BlockInputStreams additional_inputs_at_end, size_t max_threads, const String & req_id, ExceptionCallback exception_callback_ = ExceptionCallback()) - : output_queue(std::min(inputs.size(), max_threads) * 5) // reduce contention + : output_queue(std::min(std::max(inputs.size(), additional_inputs_at_end.size()), max_threads) * 5) // reduce contention , log(Logger::get(NAME, req_id)) , handler(*this) - , processor(inputs, additional_input_at_end, max_threads, handler, log) + , processor(inputs, additional_inputs_at_end, max_threads, handler, log) , exception_callback(exception_callback_) { // TODO: assert capacity of output_queue is not less than processor.getMaxThreads() children = inputs; - if (additional_input_at_end) - children.push_back(additional_input_at_end); + children.insert(children.end(), additional_inputs_at_end.begin(), additional_inputs_at_end.end()); size_t num_children = children.size(); if (num_children > 1) diff --git a/dbms/src/DataStreams/tests/union_stream2.cpp b/dbms/src/DataStreams/tests/union_stream2.cpp index f939cda4e14..fb3f7238414 100644 --- a/dbms/src/DataStreams/tests/union_stream2.cpp +++ b/dbms/src/DataStreams/tests/union_stream2.cpp @@ -51,7 +51,7 @@ try for (size_t i = 0, size = streams.size(); i < size; ++i) streams[i] = std::make_shared(streams[i]); - BlockInputStreamPtr stream = std::make_shared>(streams, nullptr, settings.max_threads, /*req_id=*/""); + BlockInputStreamPtr stream = std::make_shared>(streams, BlockInputStreams{}, settings.max_threads, /*req_id=*/""); stream = std::make_shared(stream, 10, 0, ""); WriteBufferFromFileDescriptor wb(STDERR_FILENO); diff --git a/dbms/src/Debug/astToExecutor.cpp b/dbms/src/Debug/astToExecutor.cpp index edadc7a9940..6cbba6efe99 100644 --- a/dbms/src/Debug/astToExecutor.cpp +++ b/dbms/src/Debug/astToExecutor.cpp @@ -1554,7 +1554,7 @@ ExecutorPtr compileAggregation(ExecutorPtr input, size_t & executor_index, ASTPt ci.tp = TiDB::TypeLongLong; ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull; } - else if (func->name == "max" || func->name == "min" || func->name == "first_row") + else if (func->name == "max" || func->name == "min" || func->name == "first_row" || func->name == "sum") { ci = children_ci[0]; ci.flag &= ~TiDB::ColumnFlagNotNull; diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 30d033c870f..4fb98add2a9 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -379,34 +379,39 @@ void DAGQueryBlockInterpreter::executeAggregation( is_final_agg); /// If there are several sources, then we perform parallel aggregation - if (pipeline.streams.size() > 1) + if (pipeline.streams.size() > 1 || pipeline.streams_with_non_joined_data.size() > 1) { const Settings & settings = context.getSettingsRef(); - BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, log); - pipeline.firstStream() = std::make_shared( + 
BlockInputStreamPtr stream = std::make_shared( pipeline.streams, - stream_with_non_joined_data, + pipeline.streams_with_non_joined_data, params, context.getFileProvider(), true, max_streams, settings.aggregation_memory_efficient_merge_threads ? static_cast(settings.aggregation_memory_efficient_merge_threads) : static_cast(settings.max_threads), log->identifier()); + pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); + pipeline.firstStream() = std::move(stream); + // should record for agg before restore concurrency. See #3804. recordProfileStreams(pipeline, query_block.aggregation_name); restorePipelineConcurrency(pipeline); } else { - BlockInputStreamPtr stream_with_non_joined_data = combinedNonJoinedDataStream(pipeline, max_streams, log); BlockInputStreams inputs; if (!pipeline.streams.empty()) inputs.push_back(pipeline.firstStream()); - else - pipeline.streams.resize(1); - if (stream_with_non_joined_data) - inputs.push_back(stream_with_non_joined_data); + + if (!pipeline.streams_with_non_joined_data.empty()) + inputs.push_back(pipeline.streams_with_non_joined_data.at(0)); + + pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); + pipeline.firstStream() = std::make_shared( std::make_shared(inputs, log->identifier()), params, diff --git a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp index 9de5b83626f..c747823b69d 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp @@ -42,32 +42,6 @@ void restoreConcurrency( } } -BlockInputStreamPtr combinedNonJoinedDataStream( - DAGPipeline & pipeline, - size_t max_threads, - const LoggerPtr & log, - bool ignore_block) -{ - BlockInputStreamPtr ret = nullptr; - if (pipeline.streams_with_non_joined_data.size() == 1) - ret = pipeline.streams_with_non_joined_data.at(0); - else if (pipeline.streams_with_non_joined_data.size() > 1) - { - if (ignore_block) - { - ret = std::make_shared(pipeline.streams_with_non_joined_data, nullptr, max_threads, log->identifier()); - ret->setExtraInfo("combine non joined(ignore block)"); - } - else - { - ret = std::make_shared(pipeline.streams_with_non_joined_data, nullptr, max_threads, log->identifier()); - ret->setExtraInfo("combine non joined"); - } - } - pipeline.streams_with_non_joined_data.clear(); - return ret; -} - void executeUnion( DAGPipeline & pipeline, size_t max_streams, @@ -75,21 +49,33 @@ void executeUnion( bool ignore_block, const String & extra_info) { - if (pipeline.streams.size() == 1 && pipeline.streams_with_non_joined_data.empty()) - return; - auto non_joined_data_stream = combinedNonJoinedDataStream(pipeline, max_streams, log, ignore_block); - if (!pipeline.streams.empty()) + switch (pipeline.streams.size() + pipeline.streams_with_non_joined_data.size()) { + case 0: + break; + case 1: + { + if (pipeline.streams.size() == 1) + break; + // streams_with_non_joined_data's size is 1. 
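+        // With exactly one stream in total there is nothing to union:
+        // move the single non-joined stream into `streams` so that later
+        // stages always find the result via pipeline.firstStream().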
+ pipeline.streams.push_back(pipeline.streams_with_non_joined_data.at(0)); + pipeline.streams_with_non_joined_data.clear(); + break; + } + default: + { + BlockInputStreamPtr stream; if (ignore_block) - pipeline.firstStream() = std::make_shared(pipeline.streams, non_joined_data_stream, max_streams, log->identifier()); + stream = std::make_shared(pipeline.streams, pipeline.streams_with_non_joined_data, max_streams, log->identifier()); else - pipeline.firstStream() = std::make_shared(pipeline.streams, non_joined_data_stream, max_streams, log->identifier()); - pipeline.firstStream()->setExtraInfo(extra_info); + stream = std::make_shared(pipeline.streams, pipeline.streams_with_non_joined_data, max_streams, log->identifier()); + stream->setExtraInfo(extra_info); + pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); + pipeline.firstStream() = std::move(stream); + break; } - else if (non_joined_data_stream != nullptr) - { - pipeline.streams.push_back(non_joined_data_stream); } } diff --git a/dbms/src/Flash/tests/bench_exchange.cpp b/dbms/src/Flash/tests/bench_exchange.cpp index fbb53bfd4a4..cbbdf060580 100644 --- a/dbms/src/Flash/tests/bench_exchange.cpp +++ b/dbms/src/Flash/tests/bench_exchange.cpp @@ -215,7 +215,7 @@ std::vector ReceiverHelper::buildExchangeReceiverStream() BlockInputStreamPtr ReceiverHelper::buildUnionStream() { auto streams = buildExchangeReceiverStream(); - return std::make_shared>(streams, nullptr, concurrency, /*req_id=*/""); + return std::make_shared>(streams, BlockInputStreams{}, concurrency, /*req_id=*/""); } void ReceiverHelper::finish() @@ -290,7 +290,7 @@ BlockInputStreamPtr SenderHelper::buildUnionStream( send_streams.push_back(std::make_shared(stream, std::move(response_writer), /*req_id=*/"")); } - return std::make_shared>(send_streams, nullptr, concurrency, /*req_id=*/""); + return std::make_shared>(send_streams, BlockInputStreams{}, concurrency, /*req_id=*/""); } BlockInputStreamPtr SenderHelper::buildUnionStream(size_t total_rows, const std::vector & blocks) @@ -312,7 +312,7 @@ BlockInputStreamPtr SenderHelper::buildUnionStream(size_t total_rows, const std: send_streams.push_back(std::make_shared(stream, std::move(response_writer), /*req_id=*/"")); } - return std::make_shared>(send_streams, nullptr, concurrency, /*req_id=*/""); + return std::make_shared>(send_streams, BlockInputStreams{}, concurrency, /*req_id=*/""); } void SenderHelper::finish() diff --git a/dbms/src/Flash/tests/bench_window.cpp b/dbms/src/Flash/tests/bench_window.cpp index dfdb358c46c..356f544a836 100644 --- a/dbms/src/Flash/tests/bench_window.cpp +++ b/dbms/src/Flash/tests/bench_window.cpp @@ -71,7 +71,7 @@ class WindowFunctionBench : public ExchangeBench pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, 8192, 0, "mock_executor_id_squashing"); }); - receiver_stream = std::make_shared>(pipeline.streams, nullptr, concurrency, /*req_id=*/""); + receiver_stream = std::make_shared>(pipeline.streams, BlockInputStreams{}, concurrency, /*req_id=*/""); } tipb::Window window; diff --git a/dbms/src/Flash/tests/gtest_qb_interpreter.cpp b/dbms/src/Flash/tests/gtest_qb_interpreter.cpp index 5529fc358db..1228eaa6201 100644 --- a/dbms/src/Flash/tests/gtest_qb_interpreter.cpp +++ b/dbms/src/Flash/tests/gtest_qb_interpreter.cpp @@ -33,8 +33,8 @@ class QBInterpreterExecuteTest : public DB::tests::ExecutorTest context.addMockTable({"test_db", "r_table"}, {{"r_a", TiDB::TP::TypeLong}, {"r_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); 
context.addMockTable({"test_db", "l_table"}, {{"l_a", TiDB::TP::TypeLong}, {"l_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); context.addExchangeRelationSchema("sender_1", {{"s1", TiDB::TP::TypeString}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}); - context.addExchangeRelationSchema("sender_l", {{"l_a", TiDB::TP::TypeString}, {"l_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); - context.addExchangeRelationSchema("sender_r", {{"r_a", TiDB::TP::TypeString}, {"r_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); + context.addExchangeRelationSchema("sender_l", {{"l_a", TiDB::TP::TypeLong}, {"l_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); + context.addExchangeRelationSchema("sender_r", {{"r_a", TiDB::TP::TypeLong}, {"r_b", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}); } }; @@ -200,47 +200,6 @@ Union: ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); } - // Join Source. - DAGRequestBuilder table1 = context.scan("test_db", "r_table"); - DAGRequestBuilder table2 = context.scan("test_db", "l_table"); - DAGRequestBuilder table3 = context.scan("test_db", "r_table"); - DAGRequestBuilder table4 = context.scan("test_db", "l_table"); - - request = table1.join( - table2.join( - table3.join(table4, - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left) - .build(context); - { - String expected = R"( -CreatingSets - Union: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - MockTableScan - Union x 2: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - Expression: - HashJoinProbe: - Expression: - MockTableScan - Union: - Expression x 10: - Expression: - HashJoinProbe: - Expression: - MockTableScan)"; - ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); - } - request = context.receive("sender_1") .project({"s1", "s2", "s3"}) .project({"s1", "s2"}) @@ -280,90 +239,6 @@ Union: MockExchangeReceiver)"; ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); } - - // only join + ExchangeReceiver - DAGRequestBuilder receiver1 = context.receive("sender_l"); - DAGRequestBuilder receiver2 = context.receive("sender_r"); - DAGRequestBuilder receiver3 = context.receive("sender_l"); - DAGRequestBuilder receiver4 = context.receive("sender_r"); - - request = receiver1.join( - receiver2.join( - receiver3.join(receiver4, - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left) - .build(context); - { - String expected = R"( -CreatingSets - Union: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - MockExchangeReceiver - Union x 2: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - Expression: - HashJoinProbe: - Expression: - MockExchangeReceiver - Union: - Expression x 10: - Expression: - HashJoinProbe: - Expression: - MockExchangeReceiver)"; - ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); - } - - // join + receiver + sender - // TODO: Find a way to write the request easier. 
- DAGRequestBuilder receiver5 = context.receive("sender_l"); - DAGRequestBuilder receiver6 = context.receive("sender_r"); - DAGRequestBuilder receiver7 = context.receive("sender_l"); - DAGRequestBuilder receiver8 = context.receive("sender_r"); - request = receiver5.join( - receiver6.join( - receiver7.join(receiver8, - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left), - {col("join_c")}, - ASTTableJoin::Kind::Left) - .exchangeSender(tipb::PassThrough) - .build(context); - { - String expected = R"( -CreatingSets - Union: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - MockExchangeReceiver - Union x 2: - HashJoinBuildBlockInputStream x 10: , join_kind = Left - Expression: - Expression: - Expression: - HashJoinProbe: - Expression: - MockExchangeReceiver - Union: - MockExchangeSender x 10 - Expression: - Expression: - HashJoinProbe: - Expression: - MockExchangeReceiver)"; - ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); - } } CATCH @@ -447,5 +322,250 @@ Union: } CATCH +TEST_F(InterpreterExecuteTest, Join) +try +{ + // TODO: Find a way to write the request easier. + { + // Join Source. + DAGRequestBuilder table1 = context.scan("test_db", "r_table"); + DAGRequestBuilder table2 = context.scan("test_db", "l_table"); + DAGRequestBuilder table3 = context.scan("test_db", "r_table"); + DAGRequestBuilder table4 = context.scan("test_db", "l_table"); + + auto request = table1.join( + table2.join( + table3.join(table4, + {col("join_c")}, + ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + MockTableScan + Union x 2: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + Expression: + HashJoinProbe: + Expression: + MockTableScan + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockTableScan)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + // only join + ExchangeReceiver + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r"); + DAGRequestBuilder receiver3 = context.receive("sender_l"); + DAGRequestBuilder receiver4 = context.receive("sender_r"); + + auto request = receiver1.join( + receiver2.join( + receiver3.join(receiver4, + {col("join_c")}, + ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union x 2: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + // join + receiver + sender + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r"); + DAGRequestBuilder receiver3 = context.receive("sender_l"); + DAGRequestBuilder receiver4 = context.receive("sender_r"); + + auto request = receiver1.join( + receiver2.join( + receiver3.join(receiver4, + {col("join_c")}, + 
ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left), + {col("join_c")}, + ASTTableJoin::Kind::Left) + .exchangeSender(tipb::PassThrough) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union x 2: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver + Union: + MockExchangeSender x 10 + Expression: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + +TEST_F(InterpreterExecuteTest, JoinThenAgg) +try +{ + { + // Left Join. + DAGRequestBuilder table1 = context.scan("test_db", "r_table"); + DAGRequestBuilder table2 = context.scan("test_db", "l_table"); + + auto request = table1.join( + table2, + {col("join_c")}, + ASTTableJoin::Kind::Left) + .aggregation({Max(col("r_a"))}, {col("join_c")}) + .build(context); + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 10: , join_kind = Left + Expression: + Expression: + MockTableScan + Union: + Expression x 10: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockTableScan)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + // Right Join + DAGRequestBuilder table1 = context.scan("test_db", "r_table"); + DAGRequestBuilder table2 = context.scan("test_db", "l_table"); + + auto request = table1.join( + table2, + {col("join_c")}, + ASTTableJoin::Kind::Right) + .aggregation({Max(col("r_a"))}, {col("join_c")}) + .build(context); + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 10: , join_kind = Right + Expression: + Expression: + MockTableScan + Union: + Expression x 10: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + Expression x 10: + Expression: + HashJoinProbe: + Expression: + Expression: + MockTableScan + Expression x 10: + Expression: + NonJoined: )"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + // Right join + receiver + sender + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r"); + + auto request = receiver1.join( + receiver2, + {col("join_c")}, + ASTTableJoin::Kind::Right) + .aggregation({Sum(col("r_a"))}, {col("join_c")}) + .exchangeSender(tipb::PassThrough) + .limit(10) + .build(context); + String expected = R"( +CreatingSets + Union: + HashJoinBuildBlockInputStream x 20: , join_kind = Right + Expression: + Expression: + MockExchangeReceiver + Union: + MockExchangeSender x 20 + SharedQuery: + Limit, limit = 10 + Union: + Limit x 20, limit = 10 + Expression: + Expression: + SharedQuery: + ParallelAggregating, max_threads: 20, final: true + Expression x 20: + Expression: + HashJoinProbe: + Expression: + Expression: + MockExchangeReceiver + Expression x 20: + Expression: + NonJoined: )"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 20); + } +} +CATCH + } // namespace tests } // namespace DB diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index fe8f04427a0..3514f915626 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -512,13 +512,13 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const 
BlockInputSt { const auto & join = static_cast(*query.join()->table_join); if (join.kind == ASTTableJoin::Kind::Full || join.kind == ASTTableJoin::Kind::Right) - pipeline.stream_with_non_joined_data = expressions.before_join->createStreamWithNonJoinedDataIfFullOrRightJoin( + pipeline.streams_with_non_joined_data.push_back(expressions.before_join->createStreamWithNonJoinedDataIfFullOrRightJoin( pipeline.firstStream()->getHeader(), 0, 1, - settings.max_block_size); + settings.max_block_size)); - for (auto & stream : pipeline.streams) /// Applies to all sources except stream_with_non_joined_data. + for (auto & stream : pipeline.streams) /// Applies to all sources except streams_with_non_joined_data. stream = std::make_shared(stream, expressions.before_join, /*req_id=*/""); } @@ -603,7 +603,7 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt if (need_second_distinct_pass || query.limit_length || query.limit_by_expression_list - || pipeline.stream_with_non_joined_data) + || !pipeline.streams_with_non_joined_data.empty()) { need_merge_streams = true; } @@ -987,11 +987,11 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre Aggregator::Params params(header, keys, aggregates, overflow_row, settings.max_rows_to_group_by, settings.group_by_overflow_mode, allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold : SettingUInt64(0), allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold_bytes : SettingUInt64(0), settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set, context.getTemporaryPath()); /// If there are several sources, then we perform parallel aggregation - if (pipeline.streams.size() > 1) + if (pipeline.streams.size() > 1 || pipeline.streams_with_non_joined_data.size() > 1) { - pipeline.firstStream() = std::make_shared( + auto stream = std::make_shared( pipeline.streams, - pipeline.stream_with_non_joined_data, + pipeline.streams_with_non_joined_data, params, file_provider, final, @@ -1001,19 +1001,21 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre : static_cast(settings.max_threads), /*req_id=*/""); - pipeline.stream_with_non_joined_data = nullptr; pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); + pipeline.firstStream() = std::move(stream); } else { BlockInputStreams inputs; if (!pipeline.streams.empty()) inputs.push_back(pipeline.firstStream()); - else - pipeline.streams.resize(1); - if (pipeline.stream_with_non_joined_data) - inputs.push_back(pipeline.stream_with_non_joined_data); + if (!pipeline.streams_with_non_joined_data.empty()) + inputs.push_back(pipeline.streams_with_non_joined_data.at(0)); + + pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); pipeline.firstStream() = std::make_shared( std::make_shared(inputs, /*req_id=*/""), @@ -1021,8 +1023,6 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre file_provider, final, /*req_id=*/""); - - pipeline.stream_with_non_joined_data = nullptr; } } @@ -1244,21 +1244,33 @@ void InterpreterSelectQuery::executeDistinct(Pipeline & pipeline, bool before_or void InterpreterSelectQuery::executeUnion(Pipeline & pipeline) { - /// If there are still several streams, then we combine them into one - if (pipeline.hasMoreThanOneStream()) + switch (pipeline.streams.size() + pipeline.streams_with_non_joined_data.size()) { - pipeline.firstStream() = std::make_shared>( + case 0: + break; + case 1: + { + 
if (pipeline.streams.size() == 1) + break; + // streams_with_non_joined_data's size is 1. + pipeline.streams.push_back(pipeline.streams_with_non_joined_data.at(0)); + pipeline.streams_with_non_joined_data.clear(); + break; + } + default: + { + BlockInputStreamPtr stream = std::make_shared>( pipeline.streams, - pipeline.stream_with_non_joined_data, + pipeline.streams_with_non_joined_data, max_streams, /*req_id=*/""); - pipeline.stream_with_non_joined_data = nullptr; + ; + pipeline.streams.resize(1); + pipeline.streams_with_non_joined_data.clear(); + pipeline.firstStream() = std::move(stream); + break; } - else if (pipeline.stream_with_non_joined_data) - { - pipeline.streams.push_back(pipeline.stream_with_non_joined_data); - pipeline.stream_with_non_joined_data = nullptr; } } diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.h b/dbms/src/Interpreters/InterpreterSelectQuery.h index 474ace7ee84..d1bcec2a3dd 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.h +++ b/dbms/src/Interpreters/InterpreterSelectQuery.h @@ -95,7 +95,7 @@ class InterpreterSelectQuery : public IInterpreter * It has a special meaning, since reading from it should be done after reading from the main streams. * It is appended to the main streams in UnionBlockInputStream or ParallelAggregatingBlockInputStream. */ - BlockInputStreamPtr stream_with_non_joined_data; + BlockInputStreams streams_with_non_joined_data; BlockInputStreamPtr & firstStream() { return streams.at(0); } @@ -105,13 +105,13 @@ class InterpreterSelectQuery : public IInterpreter for (auto & stream : streams) transform(stream); - if (stream_with_non_joined_data) - transform(stream_with_non_joined_data); + for (auto & stream : streams_with_non_joined_data) + transform(stream); } bool hasMoreThanOneStream() const { - return streams.size() + (stream_with_non_joined_data ? 1 : 0) > 1; + return streams.size() + streams_with_non_joined_data.size() > 1; } }; diff --git a/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp b/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp index 5e73b1e5f3e..076c290cc9d 100644 --- a/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp @@ -224,7 +224,7 @@ BlockIO InterpreterSelectWithUnionQuery::execute() } else { - result_stream = std::make_shared>(nested_streams, nullptr, settings.max_threads, /*req_id=*/""); + result_stream = std::make_shared>(nested_streams, BlockInputStreams{}, settings.max_threads, /*req_id=*/""); nested_streams.clear(); } diff --git a/dbms/src/TestUtils/mockExecutor.h b/dbms/src/TestUtils/mockExecutor.h index 1f78a3278af..1f7bd2322fa 100644 --- a/dbms/src/TestUtils/mockExecutor.h +++ b/dbms/src/TestUtils/mockExecutor.h @@ -83,7 +83,7 @@ class DAGRequestBuilder DAGRequestBuilder & exchangeSender(tipb::ExchangeType exchange_type); - // Currentlt only support inner join, left join and right join. + // Currently only support inner join, left join and right join. // TODO support more types of join. 
     DAGRequestBuilder & join(const DAGRequestBuilder & right, MockAstVec exprs);
     DAGRequestBuilder & join(const DAGRequestBuilder & right, MockAstVec exprs, ASTTableJoin::Kind kind);

@@ -175,6 +175,7 @@ MockWindowFrame buildDefaultRowsFrame();
 #define Or(expr1, expr2) makeASTFunction("or", (expr1), (expr2))
 #define NOT(expr) makeASTFunction("not", (expr))
 #define Max(expr) makeASTFunction("max", (expr))
+#define Sum(expr) makeASTFunction("sum", (expr))
 /// Window functions
 #define RowNumber() makeASTFunction("RowNumber")
 #define Rank() makeASTFunction("Rank")
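The heart of this patch is the switch from a mutex-guarded std::queue to an MPMCQueue for distributing input streams among worker threads, with a "last active input calls finish()" termination protocol. The stand-alone sketch below illustrates that protocol. SimpleMPMCQueue, FakeInput, and worker are simplified stand-ins invented for this illustration — TiFlash's real MPMCQueue additionally supports cancel(), tryPop(), and a bounded capacity — but the worker loop mirrors ParallelInputsProcessor::loop() above.

    #include <atomic>
    #include <condition_variable>
    #include <cstdio>
    #include <deque>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Simplified stand-in for TiFlash's MPMCQueue (illustration only).
    // pop() blocks until an element is available; after finish() it drains
    // the remaining elements and then returns false.
    template <typename T>
    class SimpleMPMCQueue
    {
    public:
        void push(T v)
        {
            {
                std::lock_guard<std::mutex> lock(mu);
                queue.push_back(std::move(v));
            }
            cv.notify_one();
        }

        bool pop(T & out)
        {
            std::unique_lock<std::mutex> lock(mu);
            cv.wait(lock, [this] { return finished || !queue.empty(); });
            if (queue.empty())
                return false; // finished and fully drained
            out = std::move(queue.front());
            queue.pop_front();
            return true;
        }

        void finish()
        {
            {
                std::lock_guard<std::mutex> lock(mu);
                finished = true;
            }
            cv.notify_all();
        }

    private:
        std::mutex mu;
        std::condition_variable cv;
        std::deque<T> queue;
        bool finished = false;
    };

    // A fake input stream: read() yields `remaining` more blocks, then 0.
    struct FakeInput
    {
        int remaining = 0;
        int read() { return remaining > 0 ? remaining-- : 0; }
    };

    // Mirrors ParallelInputsProcessor::loop(): take a source, read one block,
    // hand the source back if it is not exhausted; the thread that exhausts
    // the last active source calls finish() to release the other consumers.
    static void worker(SimpleMPMCQueue<FakeInput *> & available,
                       std::atomic<int> & active_inputs,
                       std::atomic<int> & blocks_read)
    {
        FakeInput * in = nullptr;
        while (available.pop(in))
        {
            if (in->read() != 0)
            {
                available.push(in); // not exhausted: make it available again
                ++blocks_read;      // stands in for publishPayload()
            }
            else if (--active_inputs == 0)
            {
                available.finish(); // last input ran out: wake everyone up
                break;
            }
        }
    }

    int main()
    {
        std::vector<FakeInput> inputs{{3}, {5}, {2}}; // 10 blocks in total
        SimpleMPMCQueue<FakeInput *> available;
        std::atomic<int> active_inputs{static_cast<int>(inputs.size())};
        std::atomic<int> blocks_read{0};

        for (auto & in : inputs)
            available.push(&in);

        std::vector<std::thread> threads;
        for (int i = 0; i < 4; ++i) // more threads than inputs is fine
            threads.emplace_back(worker, std::ref(available), std::ref(active_inputs), std::ref(blocks_read));
        for (auto & t : threads)
            t.join();

        std::printf("blocks read: %d\n", blocks_read.load()); // prints 10
    }

Call sites of ParallelInputsProcessor and its wrappers migrate mechanically: wherever a caller previously passed nullptr for additional_input_at_end, it now passes an empty BlockInputStreams{}, as in the union_stream2.cpp and bench_exchange.cpp hunks above.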