Skip to content

Commit

Permalink
join be aware of cancel signal (#9450) (#9452)
Browse files Browse the repository at this point in the history
close #9430

Co-authored-by: xufei <xufei@pingcap.com>
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent 08595f4 commit a2a4c88
Show file tree
Hide file tree
Showing 18 changed files with 110 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/AggregatingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Block AggregatingBlockInputStream::readImpl()
executed = true;
AggregatedDataVariantsPtr data_variants = std::make_shared<AggregatedDataVariants>();

Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
const BlockInputStreamPtr & probe_stream,
size_t max_block_size);

using CancellationHook = std::function<bool()>;

HashJoinProbeExec(
const String & req_id,
const JoinPtr & join_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Block ParallelAggregatingBlockInputStream::readImpl()
{
if (!executed)
{
Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
10 changes: 7 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ void MPPTask::runImpl()
GET_METRIC(tiflash_coprocessor_request_duration_seconds, type_run_mpp_task).Observe(stopwatch.elapsedSeconds());
});

// set cancellation hook
context->setCancellationHook([this] { return is_cancelled.load(); });

String err_msg;
try
{
Expand Down Expand Up @@ -746,7 +749,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
if (previous_status == FINISHED || previous_status == CANCELLED || previous_status == FAILED)
{
LOG_WARNING(log, "task already in {} state", magic_enum::enum_name(previous_status));
return;
break;
}
else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, next_task_status))
{
Expand All @@ -755,7 +758,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
/// so just close all tunnels here
abortTunnels("", false);
LOG_WARNING(log, "Finish abort task from uninitialized");
return;
break;
}
else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status))
{
Expand All @@ -769,9 +772,10 @@ void MPPTask::abort(const String & message, AbortType abort_type)
scheduleThisTask(ScheduleState::FAILED);
/// runImpl is running, leave remaining work to runImpl
LOG_WARNING(log, "Finish abort task from running");
return;
break;
}
}
is_cancelled = true;
}

bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class MPPTask

MPPTaskManager * manager;
std::atomic<bool> is_registered{false};
std::atomic<bool> is_cancelled{false};

MPPTaskScheduleEntry schedule_entry;

Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ void PhysicalJoin::probeSideTransform(DAGPipeline & probe_pipeline, Context & co
settings.max_block_size);
stream->setExtraInfo(join_probe_extra_info);
}
join_ptr->setCancellationHook([&] { return context.isCancelled(); });
}

void PhysicalJoin::buildSideTransform(DAGPipeline & build_pipeline, Context & context, size_t max_streams)
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoinProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <Flash/Coprocessor/InterpreterUtils.h>
#include <Flash/Executor/PipelineExecutorContext.h>
#include <Flash/Pipeline/Exec/PipelineExecBuilder.h>
#include <Flash/Planner/Plans/PhysicalJoinProbe.h>
#include <Interpreters/Context.h>
Expand Down Expand Up @@ -56,6 +57,7 @@ void PhysicalJoinProbe::buildPipelineExecGroupImpl(
max_block_size,
input_header));
});
join_ptr->setCancellationHook([&]() { return exec_context.isCancelled(); });
join_ptr.reset();
}
} // namespace DB
3 changes: 1 addition & 2 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <Interpreters/AggSpillContext.h>
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <TiDB/Collation/Collator.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
Expand Down Expand Up @@ -1199,8 +1200,6 @@ class Aggregator
*/
Blocks convertBlockToTwoLevel(const Block & block);

using CancellationHook = std::function<bool()>;

/** Set a function that checks whether the current task can be aborted.
*/
void setCancellationHook(CancellationHook cancellation_hook);
Expand Down
22 changes: 22 additions & 0 deletions dbms/src/Interpreters/CancellationHook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>

namespace DB
{
using CancellationHook = std::function<bool()>;
} // namespace DB
15 changes: 11 additions & 4 deletions dbms/src/Interpreters/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Core/Types.h>
#include <Debug/MockServerInfo.h>
#include <IO/FileProvider/FileProvider_fwd.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/Settings.h>
Expand Down Expand Up @@ -178,6 +179,9 @@ class Context
TimezoneInfo timezone_info;

DAGContext * dag_context = nullptr;
CancellationHook is_cancelled{[]() {
return false;
}};
using DatabasePtr = std::shared_ptr<IDatabase>;
using Databases = std::map<String, std::shared_ptr<IDatabase>>;
/// Use copy constructor or createGlobal() instead
Expand Down Expand Up @@ -235,8 +239,8 @@ class Context
/// Compute and set actual user settings, client_info.current_user should be set
void calculateUserSettings();

ClientInfo & getClientInfo() { return client_info; };
const ClientInfo & getClientInfo() const { return client_info; };
ClientInfo & getClientInfo() { return client_info; }
const ClientInfo & getClientInfo() const { return client_info; }

void setQuota(
const String & name,
Expand Down Expand Up @@ -373,6 +377,9 @@ class Context
void setDAGContext(DAGContext * dag_context);
DAGContext * getDAGContext() const;

bool isCancelled() const { return is_cancelled(); }
void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

/// List all queries.
ProcessList & getProcessList();
const ProcessList & getProcessList() const;
Expand Down Expand Up @@ -499,8 +506,8 @@ class Context

SharedQueriesPtr getSharedQueries();

const TimezoneInfo & getTimezoneInfo() const { return timezone_info; };
TimezoneInfo & getTimezoneInfo() { return timezone_info; };
const TimezoneInfo & getTimezoneInfo() const { return timezone_info; }
TimezoneInfo & getTimezoneInfo() { return timezone_info; }

/// User name and session identifier. Named sessions are local to users.
using SessionKey = std::pair<String, String>;
Expand Down
34 changes: 31 additions & 3 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,9 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui

Block Join::removeUselessColumn(Block & block) const
{
// cancelled
if (!block)
return block;
Block projected_block;
for (const auto & name_and_type : output_columns_after_finalize)
{
Expand Down Expand Up @@ -1415,6 +1418,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const
restore_config.restore_round);
while (true)
{
if (is_cancelled())
return {};
auto block = doJoinBlockHash(probe_process_info, join_build_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1519,6 +1524,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const

while (true)
{
if (is_cancelled())
return {};
Block block = doJoinBlockCross(probe_process_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1600,6 +1607,9 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);

if (is_cancelled())
return {};

Block block{};
for (size_t i = 0; i < probe_process_info.block.columns(); ++i)
{
Expand Down Expand Up @@ -1636,13 +1646,17 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in
blocks,
null_rows,
max_block_size,
non_equal_conditions);
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "NASemiJoinResult list must be empty after calculating join result");
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -1787,6 +1801,8 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
probe_process_info);

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);
if (is_cancelled())
return {};

const NameSet & probe_output_name_set = has_other_condition
? output_columns_names_set_for_other_condition_after_finalize
Expand Down Expand Up @@ -1821,15 +1837,23 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
{
if (!res_list.empty())
{
SemiJoinHelper<KIND, typename Maps::MappedType>
helper(block, left_columns, right_column_indices_to_add, max_block_size, non_equal_conditions);
SemiJoinHelper<KIND, typename Maps::MappedType> helper(
block,
left_columns,
right_column_indices_to_add,
max_block_size,
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "SemiJoinResult list must be empty after calculating join result");
}
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -2216,6 +2240,9 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const
else
block = joinBlockHash(probe_process_info);

// if cancelled, just return empty block
if (!block)
return block;
/// for (cartesian)antiLeftSemi join, the meaning of "match-helper" is `non-matched` instead of `matched`.
if (kind == Cross_LeftOuterAnti)
{
Expand Down Expand Up @@ -2488,6 +2515,7 @@ std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
restore_join->initBuild(build_sample_block, restore_join_build_concurrency);
restore_join->setInitActiveBuildThreads();
restore_join->initProbe(probe_sample_block, restore_join_build_concurrency);
restore_join->setCancellationHook(is_cancelled);
BlockInputStreams restore_scan_hash_map_streams;
restore_scan_hash_map_streams.resize(restore_join_build_concurrency, nullptr);
if (needScanHashMapAfterProbe(kind))
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Flash/Coprocessor/RuntimeFilterMgr.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/HashJoinSpillContext.h>
#include <Interpreters/JoinHashMap.h>
Expand Down Expand Up @@ -311,6 +312,8 @@ class Join
void flushProbeSideMarkedSpillData(size_t stream_index);
size_t getProbeCacheColumnThreshold() const { return probe_cache_column_threshold; }

void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

static const String match_helper_prefix;
static const DataTypePtr match_helper_type;
static const String flag_mapped_entry_helper_prefix;
Expand Down Expand Up @@ -446,6 +449,9 @@ class Join
// the index of vector is the stream_index.
std::vector<MarkedSpillData> build_side_marked_spilled_data;
std::vector<MarkedSpillData> probe_side_marked_spilled_data;
CancellationHook is_cancelled{[]() {
return false;
}};

private:
/** Set information about structure of right hand of JOIN (joined data).
Expand Down
14 changes: 10 additions & 4 deletions dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,15 @@ NASemiJoinHelper<KIND, STRICTNESS, Mapped>::NASemiJoinHelper(
const BlocksList & right_blocks_,
const std::vector<RowsNotInsertToMap *> & null_rows_,
size_t max_block_size_,
const JoinNonEqualConditions & non_equal_conditions_)
const JoinNonEqualConditions & non_equal_conditions_,
CancellationHook is_cancelled_)
: block(block_)
, left_columns(left_columns_)
, right_column_indices_to_add(right_column_indices_to_add_)
, right_blocks(right_blocks_)
, null_rows(null_rows_)
, max_block_size(max_block_size_)
, is_cancelled(is_cancelled_)
, non_equal_conditions(non_equal_conditions_)
{
static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi);
Expand Down Expand Up @@ -280,17 +282,17 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::joinResult(std::list<NASemiJoin
res_list.swap(next_step_res_list);
}

if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NOT_NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStepAllBlocks(res_list);
Expand Down Expand Up @@ -324,6 +326,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStep(

while (!res_list.empty())
{
if (is_cancelled())
return;
MutableColumns columns(block_columns);
for (size_t i = 0; i < block_columns; ++i)
{
Expand Down Expand Up @@ -384,6 +388,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStepAllBlocks(std::list<NASe
NASemiJoinHelper::Result * res = *res_list.begin();
for (const auto & right_block : right_blocks)
{
if (is_cancelled())
return;
if (res->getStep() == NASemiJoinStep::DONE)
break;

Expand Down
Loading

0 comments on commit a2a4c88

Please sign in to comment.