Skip to content

Commit

Permalink
Optimization of SegmentReadTaskPool::getTask (#6097)
Browse files Browse the repository at this point in the history
ref #6092
  • Loading branch information
JinheLin authored Oct 14, 2022
1 parent c67db38 commit ab8ff88
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 40 deletions.
6 changes: 4 additions & 2 deletions dbms/src/Storages/DeltaMerge/DeltaMergeStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,8 @@ BlockInputStreams DeltaMergeStore::readRaw(const Context & db_context,
/* read_mode */ ReadMode::Raw,
std::move(tasks),
after_segment_read,
req_info);
req_info,
enable_read_thread);

BlockInputStreams res;
for (size_t i = 0; i < final_num_stream; ++i)
Expand Down Expand Up @@ -993,7 +994,8 @@ BlockInputStreams DeltaMergeStore::read(const Context & db_context,
/* read_mode = */ is_fast_scan ? ReadMode::Fast : ReadMode::Normal,
std::move(tasks),
after_segment_read,
log_tracing_id);
log_tracing_id,
enable_read_thread);

BlockInputStreams res;
for (size_t i = 0; i < final_num_stream; ++i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,18 @@ void SegmentReadTaskScheduler::add(const SegmentReadTaskPoolPtr & pool)
Stopwatch sw_do_add;
read_pools.add(pool);

std::unordered_set<uint64_t> seg_ids;
for (const auto & task : pool->getTasks())
const auto & tasks = pool->getTasks();
for (const auto & pa : tasks)
{
auto seg_id = task->segment->segmentId();
auto seg_id = pa.first;
merging_segments[pool->tableId()][seg_id].push_back(pool->poolId());
if (!seg_ids.insert(seg_id).second)
{
throw DB::Exception(fmt::format("Not support split segment task. segment_ids={} => segment_id={} already exist.", seg_ids, seg_id));
}
}
auto block_slots = pool->getFreeBlockSlots();
LOG_DEBUG(log, "Added, pool_id={} table_id={} block_slots={} segment_count={} segments={} pool_count={} cost={}ns do_add_cost={}ns", //
LOG_DEBUG(log, "Added, pool_id={} table_id={} block_slots={} segment_count={} pool_count={} cost={}ns do_add_cost={}ns", //
pool->poolId(),
pool->tableId(),
block_slots,
seg_ids.size(),
seg_ids,
tasks.size(),
read_pools.size(),
sw_add.elapsed(),
sw_do_add.elapsed());
Expand Down
88 changes: 76 additions & 12 deletions dbms/src/Storages/DeltaMerge/SegmentReadTaskPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,63 @@ SegmentReadTasks SegmentReadTask::trySplitReadTasks(const SegmentReadTasks & tas
return result_tasks;
}


// Builds the wrapper in one of two mutually exclusive modes:
//  - `enable_read_thread_` == false: the incoming list is taken over as-is and
//    later consumed front-to-back through `nextTask`.
//  - `enable_read_thread_` == true: every task is indexed by its segment id so
//    that `getTask` can look it up directly. A segment id that appears more
//    than once in `ordered_tasks_` is rejected with a `DB::Exception`.
SegmentReadTasksWrapper::SegmentReadTasksWrapper(bool enable_read_thread_, SegmentReadTasks && ordered_tasks_)
    : enable_read_thread(enable_read_thread_)
{
    if (!enable_read_thread)
    {
        ordered_tasks = std::move(ordered_tasks_);
        return;
    }

    for (const auto & task : ordered_tasks_)
    {
        const auto seg_id = task->segment->segmentId();
        const bool newly_indexed = unordered_tasks.emplace(seg_id, task).second;
        if (!newly_indexed)
        {
            throw DB::Exception(fmt::format("segment_id={} already exist.", seg_id));
        }
    }
}

// Pops and returns the task at the front of the ordered list, or a null
// pointer once the list is exhausted. Only valid in ordered mode; the
// `RUNTIME_CHECK` rejects calls made when `enable_read_thread` is true.
SegmentReadTaskPtr SegmentReadTasksWrapper::nextTask()
{
    RUNTIME_CHECK(!enable_read_thread);
    SegmentReadTaskPtr next_in_line = nullptr;
    if (!ordered_tasks.empty())
    {
        next_in_line = ordered_tasks.front();
        ordered_tasks.pop_front();
    }
    return next_in_line;
}

// Removes the task registered under `seg_id` from the index and returns it.
// Returns a null pointer when no task with that id is (still) present.
// Only valid in indexed mode; the `RUNTIME_CHECK` rejects calls made when
// `enable_read_thread` is false.
SegmentReadTaskPtr SegmentReadTasksWrapper::getTask(UInt64 seg_id)
{
    RUNTIME_CHECK(enable_read_thread);
    if (auto found = unordered_tasks.find(seg_id); found != unordered_tasks.end())
    {
        // Keep the task alive in a local before its map entry is erased.
        auto picked = found->second;
        unordered_tasks.erase(found);
        return picked;
    }
    return nullptr;
}

const std::unordered_map<UInt64, SegmentReadTaskPtr> & SegmentReadTasksWrapper::getTasks() const
{
RUNTIME_CHECK(enable_read_thread);
return unordered_tasks;
}

// True when no task is left in either container. The constructor fills only
// one of the two, so this works regardless of the mode the wrapper is in.
bool SegmentReadTasksWrapper::empty() const
{
    const bool no_ordered_left = ordered_tasks.empty();
    const bool no_indexed_left = unordered_tasks.empty();
    return no_ordered_left && no_indexed_left;
}

BlockInputStreamPtr SegmentReadTaskPool::buildInputStream(SegmentReadTaskPtr & t)
{
MemoryTrackerSetter setter(true, mem_tracker.get());
Expand All @@ -112,7 +169,7 @@ void SegmentReadTaskPool::finishSegment(const SegmentPtr & seg)
{
std::lock_guard lock(mutex);
active_segment_ids.erase(seg->segmentId());
pool_finished = active_segment_ids.empty() && tasks.empty();
pool_finished = active_segment_ids.empty() && tasks_wrapper.empty();
}
LOG_DEBUG(log, "finishSegment pool_id={} segment_id={} pool_finished={}", pool_id, seg->segmentId(), pool_finished);
if (pool_finished)
Expand All @@ -121,21 +178,27 @@ void SegmentReadTaskPool::finishSegment(const SegmentPtr & seg)
}
}

SegmentReadTaskPtr SegmentReadTaskPool::getTask(uint64_t seg_id)
SegmentReadTaskPtr SegmentReadTaskPool::nextTask()
{
std::lock_guard lock(mutex);
// TODO(jinhelin): use unordered_map
auto itr = std::find_if(tasks.begin(), tasks.end(), [seg_id](const SegmentReadTaskPtr & task) { return task->segment->segmentId() == seg_id; });
if (itr == tasks.end())
{
throw Exception(fmt::format("{} pool_id={} segment_id={} not found", __PRETTY_FUNCTION__, pool_id, seg_id));
}
auto t = *(itr);
tasks.erase(itr);
return tasks_wrapper.nextTask();
}

// Takes the task for `seg_id` out of this pool and records the segment as active.
// The whole operation runs under `mutex`.
// `tasks_wrapper.getTask` returns nullptr for an unknown segment id; that case is
// turned into a failed `RUNTIME_CHECK`, with `pool_id` and `seg_id` passed along as context.
SegmentReadTaskPtr SegmentReadTaskPool::getTask(UInt64 seg_id)
{
std::lock_guard lock(mutex);
auto t = tasks_wrapper.getTask(seg_id);
RUNTIME_CHECK(t != nullptr, pool_id, seg_id);
// The id stays in `active_segment_ids` until `finishSegment` erases it.
active_segment_ids.insert(seg_id);
return t;
}

// Returns a reference to the pool's remaining tasks, keyed by segment id.
// Delegates to `tasks_wrapper.getTasks()`, which is only valid when the pool was
// created with `enable_read_thread_` set to true.
// NOTE(review): `mutex` is held only while the reference is obtained; the caller then
// reads the map without the lock, while `getTask` erases from the same map under the
// lock. Presumably callers guarantee there is no concurrent `getTask` -- confirm.
const std::unordered_map<UInt64, SegmentReadTaskPtr> & SegmentReadTaskPool::getTasks()
{
std::lock_guard lock(mutex);
return tasks_wrapper.getTasks();
}

// Choose a segment to read.
// Returns <segment_id, pool_ids>.
std::unordered_map<uint64_t, std::vector<uint64_t>>::const_iterator SegmentReadTaskPool::scheduleSegment(const std::unordered_map<uint64_t, std::vector<uint64_t>> & segments, uint64_t expected_merge_count)
Expand All @@ -148,12 +211,13 @@ std::unordered_map<uint64_t, std::vector<uint64_t>>::const_iterator SegmentReadT
}
static constexpr int max_iter_count = 32;
int iter_count = 0;
const auto & tasks = tasks_wrapper.getTasks();
for (const auto & task : tasks)
{
auto itr = segments.find(task->segment->segmentId());
auto itr = segments.find(task.first);
if (itr == segments.end())
{
throw DB::Exception(fmt::format("segment_id {} not found from merging segments", task->segment->segmentId()));
throw DB::Exception(fmt::format("segment_id {} not found from merging segments", task.first));
}
if (std::find(itr->second.begin(), itr->second.end(), poolId()) == itr->second.end())
{
Expand Down
47 changes: 31 additions & 16 deletions dbms/src/Storages/DeltaMerge/SegmentReadTaskPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,33 @@ enum class ReadMode
Raw,
};

// Holds the pending `SegmentReadTask`s of a pool in one of two mutually exclusive layouts,
// selected once at construction:
//  - `enable_read_thread_` is true: tasks are indexed by segment id in a `std::unordered_map`.
//  - `enable_read_thread_` is false: tasks are kept as the original `SegmentReadTasks`,
//    a `std::list` of `SegmentReadTask`.
// `SegmentReadTasksWrapper` is not thread-safe; callers must provide their own locking.
class SegmentReadTasksWrapper
{
public:
// Takes ownership of `ordered_tasks_`. When `enable_read_thread_` is true, a segment id
// that occurs more than once makes the constructor throw.
SegmentReadTasksWrapper(bool enable_read_thread_, SegmentReadTasks && ordered_tasks_);

// `nextTask` pops tasks sequentially and returns nullptr once none is left.
// This function is used when `enable_read_thread` is false.
SegmentReadTaskPtr nextTask();

// `getTask` and `getTasks` are used when `enable_read_thread` is true.
// `getTask` removes and returns the task of `seg_id`, or nullptr if it is not present.
SegmentReadTaskPtr getTask(UInt64 seg_id);
const std::unordered_map<UInt64, SegmentReadTaskPtr> & getTasks() const;

// Returns true when no task is left; usable in both modes.
bool empty() const;

private:
bool enable_read_thread;
// Populated only when `enable_read_thread` is false.
SegmentReadTasks ordered_tasks;
// Populated only when `enable_read_thread` is true; key is the segment id.
std::unordered_map<UInt64, SegmentReadTaskPtr> unordered_tasks;
};

class SegmentReadTaskPool : private boost::noncopyable
{
public:
explicit SegmentReadTaskPool(
SegmentReadTaskPool(
int64_t table_id_,
const DMContextPtr & dm_context_,
const ColumnDefines & columns_to_read_,
Expand All @@ -144,7 +167,8 @@ class SegmentReadTaskPool : private boost::noncopyable
ReadMode read_mode_,
SegmentReadTasks && tasks_,
AfterSegmentRead after_segment_read_,
const String & tracing_id)
const String & tracing_id,
bool enable_read_thread_)
: pool_id(nextPoolId())
, table_id(table_id_)
, dm_context(dm_context_)
Expand All @@ -153,7 +177,7 @@ class SegmentReadTaskPool : private boost::noncopyable
, max_version(max_version_)
, expected_block_size(expected_block_size_)
, read_mode(read_mode_)
, tasks(std::move(tasks_))
, tasks_wrapper(enable_read_thread_, std::move(tasks_))
, after_segment_read(after_segment_read_)
, log(Logger::get("SegmentReadTaskPool", tracing_id))
, unordered_input_stream_ref_count(0)
Expand Down Expand Up @@ -182,22 +206,14 @@ class SegmentReadTaskPool : private boost::noncopyable
total_bytes / 1024.0 / 1024.0);
}

SegmentReadTaskPtr nextTask()
{
std::lock_guard lock(mutex);
if (tasks.empty())
return {};
auto task = tasks.front();
tasks.pop_front();
return task;
}
SegmentReadTaskPtr nextTask();
const std::unordered_map<UInt64, SegmentReadTaskPtr> & getTasks();
SegmentReadTaskPtr getTask(UInt64 seg_id);

uint64_t poolId() const { return pool_id; }

int64_t tableId() const { return table_id; }

const SegmentReadTasks & getTasks() const { return tasks; }

BlockInputStreamPtr buildInputStream(SegmentReadTaskPtr & t);

bool readOneBlock(BlockInputStreamPtr & stream, const SegmentPtr & seg);
Expand All @@ -212,7 +228,6 @@ class SegmentReadTaskPool : private boost::noncopyable
int64_t getFreeBlockSlots() const;
bool valid() const;
void setException(const DB::Exception & e);
SegmentReadTaskPtr getTask(uint64_t seg_id);

std::once_flag & addToSchedulerFlag()
{
Expand All @@ -233,7 +248,7 @@ class SegmentReadTaskPool : private boost::noncopyable
const uint64_t max_version;
const size_t expected_block_size;
const ReadMode read_mode;
SegmentReadTasks tasks;
SegmentReadTasksWrapper tasks_wrapper;
AfterSegmentRead after_segment_read;
std::mutex mutex;
std::unordered_set<uint64_t> active_segment_ids;
Expand Down
102 changes: 102 additions & 0 deletions dbms/src/Storages/DeltaMerge/tests/gtest_segment_read_task_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Storages/DeltaMerge/Segment.h>
#include <Storages/DeltaMerge/SegmentReadTaskPool.h>
#include <TestUtils/TiFlashTestBasic.h>

namespace DB::DM::tests
{

// Test helper: builds a bare `Segment` whose id is `seg_id`. All other
// constructor arguments are placeholders (empty string, 0, default
// `RowKeyRange`, `seg_id + 1`, and two null pointers).
SegmentPtr createSegment(PageId seg_id)
{
    auto segment = std::make_shared<Segment>("", 0, RowKeyRange{}, seg_id, seg_id + 1, nullptr, nullptr);
    return segment;
}

// Test helper: wraps a freshly created segment with id `seg_id` into a
// `SegmentReadTask` that has a null second argument and no key ranges.
SegmentReadTaskPtr createSegmentReadTask(PageId seg_id)
{
    auto read_task = std::make_shared<SegmentReadTask>(createSegment(seg_id), nullptr, RowKeyRanges{});
    return read_task;
}

// Test helper: creates one read task per entry of `seg_ids`, preserving the
// order of the input vector.
SegmentReadTasks createSegmentReadTasks(const std::vector<PageId> & seg_ids)
{
    SegmentReadTasks created;
    for (auto id_itr = seg_ids.begin(); id_itr != seg_ids.end(); ++id_itr)
    {
        created.push_back(createSegmentReadTask(*id_itr));
    }
    return created;
}

// Distinct segment ids shared by the tests below.
static const std::vector<PageId> test_seg_ids{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

// Exercises `SegmentReadTasksWrapper` in indexed mode (`enable_read_thread` == true):
// the ordered API must be rejected, every task must be retrievable exactly once
// by segment id in arbitrary order, and the wrapper must end up empty.
TEST(SegmentReadTasksWrapperTest, Unordered)
{
    SegmentReadTasksWrapper tasks_wrapper(true, createSegmentReadTasks(test_seg_ids));

    // `nextTask` belongs to ordered mode and must throw here.
    ASSERT_THROW(tasks_wrapper.nextTask(), Exception);

    ASSERT_FALSE(tasks_wrapper.empty());
    const auto & tasks = tasks_wrapper.getTasks();
    ASSERT_EQ(tasks.size(), test_seg_ids.size());

    // Fetch the tasks in a random order to show that lookup does not depend on
    // insertion order.
    std::random_device rd;
    std::mt19937 g(rd());
    std::vector<PageId> v = test_seg_ids;
    std::shuffle(v.begin(), v.end(), g);
    for (PageId seg_id : v)
    {
        auto task = tasks_wrapper.getTask(seg_id);
        // Fail the test cleanly instead of dereferencing a null task.
        ASSERT_NE(task, nullptr);
        ASSERT_EQ(task->segment->segmentId(), seg_id);
        // A task can be taken only once.
        task = tasks_wrapper.getTask(seg_id);
        ASSERT_EQ(task, nullptr);
    }
    ASSERT_TRUE(tasks_wrapper.empty());
}

// Exercises `SegmentReadTasksWrapper` in ordered mode (`enable_read_thread` == false):
// the indexed API must be rejected, tasks must come out in insertion order, and
// `nextTask` must return nullptr once the wrapper is drained.
TEST(SegmentReadTasksWrapperTest, Ordered)
{
    SegmentReadTasksWrapper tasks_wrapper(false, createSegmentReadTasks(test_seg_ids));

    // `getTasks` and `getTask` belong to indexed mode and must throw here.
    ASSERT_THROW(tasks_wrapper.getTasks(), Exception);
    ASSERT_THROW(tasks_wrapper.getTask(test_seg_ids.front()), Exception);

    ASSERT_FALSE(tasks_wrapper.empty());

    for (PageId seg_id : test_seg_ids)
    {
        auto task = tasks_wrapper.nextTask();
        // Fail the test cleanly instead of dereferencing a null task.
        ASSERT_NE(task, nullptr);
        ASSERT_EQ(task->segment->segmentId(), seg_id);
    }
    ASSERT_TRUE(tasks_wrapper.empty());
    ASSERT_EQ(tasks_wrapper.nextTask(), nullptr);
}

} // namespace DB::DM::tests

0 comments on commit ab8ff88

Please sign in to comment.