diff --git a/dbms/src/DataStreams/AggregatingBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp index 8ca25b24e43..8aecc68c2aa 100644 --- a/dbms/src/DataStreams/AggregatingBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp @@ -32,7 +32,7 @@ Block AggregatingBlockInputStream::readImpl() executed = true; AggregatedDataVariantsPtr data_variants = std::make_shared(); - Aggregator::CancellationHook hook = [&]() { + CancellationHook hook = [&]() { return this->isCancelled(); }; aggregator.setCancellationHook(hook); diff --git a/dbms/src/DataStreams/HashJoinProbeExec.h b/dbms/src/DataStreams/HashJoinProbeExec.h index f92ea8702f5..02fd82c8e08 100644 --- a/dbms/src/DataStreams/HashJoinProbeExec.h +++ b/dbms/src/DataStreams/HashJoinProbeExec.h @@ -36,8 +36,6 @@ class HashJoinProbeExec : public std::enable_shared_from_this const BlockInputStreamPtr & probe_stream, size_t max_block_size); - using CancellationHook = std::function; - HashJoinProbeExec( const String & req_id, const JoinPtr & join_, diff --git a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp index 75092276fbd..9e5c4b57c5b 100644 --- a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp +++ b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp @@ -71,7 +71,7 @@ Block ParallelAggregatingBlockInputStream::readImpl() { if (!executed) { - Aggregator::CancellationHook hook = [&]() { + CancellationHook hook = [&]() { return this->isCancelled(); }; aggregator.setCancellationHook(hook); diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 75c5c221ed4..061144d0984 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -517,6 +517,9 @@ void MPPTask::runImpl() GET_METRIC(tiflash_coprocessor_request_duration_seconds, type_run_mpp_task).Observe(stopwatch.elapsedSeconds()); }); + // set cancellation hook + context->setCancellationHook([this] { return is_cancelled.load(); }); + String err_msg; try { @@ -746,7 +749,7 @@ void MPPTask::abort(const String & message, AbortType abort_type) if (previous_status == FINISHED || previous_status == CANCELLED || previous_status == FAILED) { LOG_WARNING(log, "task already in {} state", magic_enum::enum_name(previous_status)); - return; + break; } else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, next_task_status)) { @@ -755,7 +758,7 @@ void MPPTask::abort(const String & message, AbortType abort_type) /// so just close all tunnels here abortTunnels("", false); LOG_WARNING(log, "Finish abort task from uninitialized"); - return; + break; } else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status)) { @@ -769,9 +772,10 @@ void MPPTask::abort(const String & message, AbortType abort_type) scheduleThisTask(ScheduleState::FAILED); /// runImpl is running, leave remaining work to runImpl LOG_WARNING(log, "Finish abort task from running"); - return; + break; } } + is_cancelled = true; } bool MPPTask::switchStatus(TaskStatus from, TaskStatus to) diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index 1c997e2bb5e..477b7c6e7be 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -177,6 +177,7 @@ class MPPTask MPPTaskManager * manager; std::atomic is_registered{false}; + std::atomic is_cancelled{false}; MPPTaskScheduleEntry schedule_entry; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp index e912ecfc8c7..0abae77c7d7 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp @@ -238,6 +238,7 @@ void PhysicalJoin::probeSideTransform(DAGPipeline & probe_pipeline, Context & co settings.max_block_size); stream->setExtraInfo(join_probe_extra_info); } + join_ptr->setCancellationHook([&] { return context.isCancelled(); }); } void PhysicalJoin::buildSideTransform(DAGPipeline & build_pipeline, Context & context, size_t max_streams) diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp index a41ee82869d..4872b4f744e 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -56,6 +57,7 @@ void PhysicalJoinProbe::buildPipelineExecGroupImpl( max_block_size, input_header)); }); + join_ptr->setCancellationHook([&]() { return exec_context.isCancelled(); }); join_ptr.reset(); } } // namespace DB diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 33a1ea1bd4c..52b311a6ba9 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -1199,8 +1200,6 @@ class Aggregator */ Blocks convertBlockToTwoLevel(const Block & block); - using CancellationHook = std::function; - /** Set a function that checks whether the current task can be aborted. */ void setCancellationHook(CancellationHook cancellation_hook); diff --git a/dbms/src/Interpreters/CancellationHook.h b/dbms/src/Interpreters/CancellationHook.h new file mode 100644 index 00000000000..a054c256c30 --- /dev/null +++ b/dbms/src/Interpreters/CancellationHook.h @@ -0,0 +1,22 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB +{ +using CancellationHook = std::function; +} // namespace DB diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 610ef892a10..5af18eafca3 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -178,6 +179,9 @@ class Context TimezoneInfo timezone_info; DAGContext * dag_context = nullptr; + CancellationHook is_cancelled{[]() { + return false; + }}; using DatabasePtr = std::shared_ptr; using Databases = std::map>; /// Use copy constructor or createGlobal() instead @@ -235,8 +239,8 @@ class Context /// Compute and set actual user settings, client_info.current_user should be set void calculateUserSettings(); - ClientInfo & getClientInfo() { return client_info; }; - const ClientInfo & getClientInfo() const { return client_info; }; + ClientInfo & getClientInfo() { return client_info; } + const ClientInfo & getClientInfo() const { return client_info; } void setQuota( const String & name, @@ -373,6 +377,9 @@ class Context void setDAGContext(DAGContext * dag_context); DAGContext * getDAGContext() const; + bool isCancelled() const { return is_cancelled(); } + void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; } + /// List all queries. ProcessList & getProcessList(); const ProcessList & getProcessList() const; @@ -499,8 +506,8 @@ class Context SharedQueriesPtr getSharedQueries(); - const TimezoneInfo & getTimezoneInfo() const { return timezone_info; }; - TimezoneInfo & getTimezoneInfo() { return timezone_info; }; + const TimezoneInfo & getTimezoneInfo() const { return timezone_info; } + TimezoneInfo & getTimezoneInfo() { return timezone_info; } /// User name and session identifier. Named sessions are local to users. using SessionKey = std::pair; diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index e080c9f84c1..905199c5d8b 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -1385,6 +1385,9 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui Block Join::removeUselessColumn(Block & block) const { + // cancelled + if (!block) + return block; Block projected_block; for (const auto & name_and_type : output_columns_after_finalize) { @@ -1415,6 +1418,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const restore_config.restore_round); while (true) { + if (is_cancelled()) + return {}; auto block = doJoinBlockHash(probe_process_info, join_build_info); assert(block); block = removeUselessColumn(block); @@ -1519,6 +1524,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const while (true) { + if (is_cancelled()) + return {}; Block block = doJoinBlockCross(probe_process_info); assert(block); block = removeUselessColumn(block); @@ -1600,6 +1607,9 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows); + if (is_cancelled()) + return {}; + Block block{}; for (size_t i = 0; i < probe_process_info.block.columns(); ++i) { @@ -1636,13 +1646,17 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in blocks, null_rows, max_block_size, - non_equal_conditions); + non_equal_conditions, + is_cancelled); helper.joinResult(res_list); RUNTIME_CHECK_MSG(res_list.empty(), "NASemiJoinResult list must be empty after calculating join result"); } + if (is_cancelled()) + return {}; + /// Now all results are known. std::unique_ptr filter; @@ -1787,6 +1801,8 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe probe_process_info); RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows); + if (is_cancelled()) + return {}; const NameSet & probe_output_name_set = has_other_condition ? output_columns_names_set_for_other_condition_after_finalize @@ -1821,8 +1837,13 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe { if (!res_list.empty()) { - SemiJoinHelper - helper(block, left_columns, right_column_indices_to_add, max_block_size, non_equal_conditions); + SemiJoinHelper helper( + block, + left_columns, + right_column_indices_to_add, + max_block_size, + non_equal_conditions, + is_cancelled); helper.joinResult(res_list); @@ -1830,6 +1851,9 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe } } + if (is_cancelled()) + return {}; + /// Now all results are known. std::unique_ptr filter; @@ -2216,6 +2240,9 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const else block = joinBlockHash(probe_process_info); + // if cancelled, just return empty block + if (!block) + return block; /// for (cartesian)antiLeftSemi join, the meaning of "match-helper" is `non-matched` instead of `matched`. if (kind == Cross_LeftOuterAnti) { @@ -2488,6 +2515,7 @@ std::optional Join::getOneRestoreStream(size_t max_block_size_) restore_join->initBuild(build_sample_block, restore_join_build_concurrency); restore_join->setInitActiveBuildThreads(); restore_join->initProbe(probe_sample_block, restore_join_build_concurrency); + restore_join->setCancellationHook(is_cancelled); BlockInputStreams restore_scan_hash_map_streams; restore_scan_hash_map_streams.resize(restore_join_build_concurrency, nullptr); if (needScanHashMapAfterProbe(kind)) diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 4b68e448305..a60e50eb7ab 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -311,6 +312,8 @@ class Join void flushProbeSideMarkedSpillData(size_t stream_index); size_t getProbeCacheColumnThreshold() const { return probe_cache_column_threshold; } + void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; } + static const String match_helper_prefix; static const DataTypePtr match_helper_type; static const String flag_mapped_entry_helper_prefix; @@ -446,6 +449,9 @@ class Join // the index of vector is the stream_index. std::vector build_side_marked_spilled_data; std::vector probe_side_marked_spilled_data; + CancellationHook is_cancelled{[]() { + return false; + }}; private: /** Set information about structure of right hand of JOIN (joined data). diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp index 80e66b205b2..27e3e949c01 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp @@ -244,13 +244,15 @@ NASemiJoinHelper::NASemiJoinHelper( const BlocksList & right_blocks_, const std::vector & null_rows_, size_t max_block_size_, - const JoinNonEqualConditions & non_equal_conditions_) + const JoinNonEqualConditions & non_equal_conditions_, + CancellationHook is_cancelled_) : block(block_) , left_columns(left_columns_) , right_column_indices_to_add(right_column_indices_to_add_) , right_blocks(right_blocks_) , null_rows(null_rows_) , max_block_size(max_block_size_) + , is_cancelled(is_cancelled_) , non_equal_conditions(non_equal_conditions_) { static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi); @@ -280,17 +282,17 @@ void NASemiJoinHelper::joinResult(std::list(res_list, next_step_res_list); res_list.swap(next_step_res_list); - if (res_list.empty()) + if (is_cancelled() || res_list.empty()) return; runStep(res_list, next_step_res_list); res_list.swap(next_step_res_list); - if (res_list.empty()) + if (is_cancelled() || res_list.empty()) return; runStepAllBlocks(res_list); @@ -324,6 +326,8 @@ void NASemiJoinHelper::runStep( while (!res_list.empty()) { + if (is_cancelled()) + return; MutableColumns columns(block_columns); for (size_t i = 0; i < block_columns; ++i) { @@ -384,6 +388,8 @@ void NASemiJoinHelper::runStepAllBlocks(std::listgetStep() == NASemiJoinStep::DONE) break; diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h index 83ab56d332b..63d5a170513 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -167,7 +168,8 @@ class NASemiJoinHelper const BlocksList & right_blocks, const std::vector & null_rows, size_t max_block_size, - const JoinNonEqualConditions & non_equal_conditions); + const JoinNonEqualConditions & non_equal_conditions, + CancellationHook is_cancelled_); void joinResult(std::list & res_list); @@ -192,6 +194,7 @@ class NASemiJoinHelper const BlocksList & right_blocks; const std::vector & null_rows; size_t max_block_size; + CancellationHook is_cancelled; const JoinNonEqualConditions & non_equal_conditions; }; diff --git a/dbms/src/Interpreters/SemiJoinHelper.cpp b/dbms/src/Interpreters/SemiJoinHelper.cpp index 2956919f9b7..6951cdd2450 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.cpp +++ b/dbms/src/Interpreters/SemiJoinHelper.cpp @@ -147,11 +147,13 @@ SemiJoinHelper::SemiJoinHelper( size_t left_columns_, const std::vector & right_column_indices_to_added_, size_t max_block_size_, - const JoinNonEqualConditions & non_equal_conditions_) + const JoinNonEqualConditions & non_equal_conditions_, + CancellationHook is_cancelled_) : block(block_) , left_columns(left_columns_) , right_column_indices_to_add(right_column_indices_to_added_) , max_block_size(max_block_size_) + , is_cancelled(is_cancelled_) , non_equal_conditions(non_equal_conditions_) { static_assert(KIND == Semi || KIND == Anti || KIND == LeftOuterAnti || KIND == LeftOuterSemi); @@ -178,6 +180,8 @@ void SemiJoinHelper::joinResult(std::list & res_list) while (!res_list.empty()) { + if (is_cancelled()) + return; MutableColumns columns(block_columns); for (size_t i = 0; i < block_columns; ++i) { diff --git a/dbms/src/Interpreters/SemiJoinHelper.h b/dbms/src/Interpreters/SemiJoinHelper.h index 3759da9c391..90dc39c81ce 100644 --- a/dbms/src/Interpreters/SemiJoinHelper.h +++ b/dbms/src/Interpreters/SemiJoinHelper.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include namespace DB @@ -137,7 +138,8 @@ class SemiJoinHelper size_t left_columns, const std::vector & right_column_indices_to_add, size_t max_block_size, - const JoinNonEqualConditions & non_equal_conditions); + const JoinNonEqualConditions & non_equal_conditions, + CancellationHook is_cancelled_); void joinResult(std::list & res_list); @@ -156,6 +158,7 @@ class SemiJoinHelper size_t right_columns; std::vector right_column_indices_to_add; size_t max_block_size; + CancellationHook is_cancelled; const JoinNonEqualConditions & non_equal_conditions; }; diff --git a/dbms/src/Operators/AggregateContext.cpp b/dbms/src/Operators/AggregateContext.cpp index 391b1acce57..c46be959937 100644 --- a/dbms/src/Operators/AggregateContext.cpp +++ b/dbms/src/Operators/AggregateContext.cpp @@ -19,7 +19,7 @@ namespace DB void AggregateContext::initBuild( const Aggregator::Params & params, size_t max_threads_, - Aggregator::CancellationHook && hook, + CancellationHook && hook, const RegisterOperatorSpillContext & register_operator_spill_context) { assert(status.load() == AggStatus::init); diff --git a/dbms/src/Operators/AggregateContext.h b/dbms/src/Operators/AggregateContext.h index 5c1351a1cd1..60f176ad299 100644 --- a/dbms/src/Operators/AggregateContext.h +++ b/dbms/src/Operators/AggregateContext.h @@ -44,7 +44,7 @@ class AggregateContext void initBuild( const Aggregator::Params & params, size_t max_threads_, - Aggregator::CancellationHook && hook, + CancellationHook && hook, const RegisterOperatorSpillContext & register_operator_spill_context); size_t getBuildConcurrency() const { return max_threads; } @@ -104,7 +104,7 @@ class AggregateContext }; std::atomic status{AggStatus::init}; - Aggregator::CancellationHook is_cancelled{[]() { + CancellationHook is_cancelled{[]() { return false; }};