[PersistentTable] Async write #7946

Merged 25 commits from dev_persistent_table_async_write into master on Apr 11, 2022.

Changes from 2 commits.

Commits (25):
6863ea5  [PersistentTable] Async write (liujuncheng, Apr 1, 2022)
30af1bc  fix (liujuncheng, Apr 1, 2022)
f656195  Merge branch 'master' into dev_persistent_table_async_write (liujuncheng, Apr 2, 2022)
fde83fa  Merge branch 'master' into dev_persistent_table_async_write (liujuncheng, Apr 6, 2022)
7b862c5  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 6, 2022)
998124b  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 6, 2022)
6937742  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 6, 2022)
047cb9c  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 7, 2022)
e7b7026  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 7, 2022)
77c83cc  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 7, 2022)
e9126fd  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 7, 2022)
4d0281c  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 7, 2022)
5e978d4  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 8, 2022)
6c3562e  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 8, 2022)
651f142  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 8, 2022)
8c21d62  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 8, 2022)
e9ef436  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 9, 2022)
dc24ddd  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 9, 2022)
41882e7  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 9, 2022)
ef854cd  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 9, 2022)
e316561  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 10, 2022)
e5fef57  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 10, 2022)
6f762ff  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 10, 2022)
c1472f1  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 10, 2022)
d3ae366  Merge branch 'master' into dev_persistent_table_async_write (mergify[bot], Apr 10, 2022)
oneflow/core/embedding/persistent_table.cpp: 123 changes (62 additions, 61 deletions)
@@ -298,19 +298,10 @@ class AioEngine final {
 constexpr size_t kCacheLineSize = 64;
 
 template<typename Engine>
-using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
+using IoTask = std::function<void(Engine* engine)>;
 
 template<typename Engine>
-struct ParallelForTask {
-  ParallelForTask(size_t num_workers, size_t total, const ForRange<Engine>* for_range)
-      : counter(0), total(total), for_range(for_range), bc(num_workers) {}
-  union alignas(kCacheLineSize) {
-    std::atomic<size_t> counter;
-  };
-  size_t total;
-  const ForRange<Engine>* for_range;
-  BlockingCounter bc;
-};
+using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
 
 template<typename Engine>
 class Worker final {
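
The hunk above replaces the special-purpose ParallelForTask struct with two generic aliases: IoTask<Engine>, an arbitrary closure executed against a worker's engine, and ForRange<Engine>, the body of a parallel loop. A minimal sketch of the two aliases in use; StubEngine and both lambdas are illustrative stand-ins, not code from this PR:

#include <cstddef>
#include <functional>
#include <iostream>

// Stand-in for oneflow's AioEngine (assumption: the real engine queues
// asynchronous I/O and exposes WaitUntilDone()).
struct StubEngine {
  void WaitUntilDone() { std::cout << "engine: pending I/O drained\n"; }
};

// The two aliases exactly as they appear after this hunk.
template<typename Engine>
using IoTask = std::function<void(Engine* engine)>;

template<typename Engine>
using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;

int main() {
  StubEngine engine;
  // An IoTask is any closure that drives an engine; it replaces the
  // special-purpose ParallelForTask as the unit of scheduled work.
  IoTask<StubEngine> task = [](StubEngine* e) { e->WaitUntilDone(); };
  task(&engine);
  // A ForRange is now wrapped into an IoTask by ParallelFor() (last hunk)
  // instead of travelling inside a task struct.
  ForRange<StubEngine> body = [](StubEngine*, size_t start, size_t end) {
    std::cout << "process [" << start << ", " << end << ")\n";
  };
  body(&engine, 0, 64);
  return 0;
}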
@@ -322,29 +313,21 @@ class Worker final {
     thread_.join();
   }
 
-  void Schedule(ParallelForTask<Engine>* task) { tasks_.Send(task); }
+  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }
 
   void Shutdown() { tasks_.Close(); }
 
  private:
   void PullTask() {
     while (true) {
-      ParallelForTask<Engine>* task = nullptr;
+      IoTask<Engine> task;
       const ChannelStatus status = tasks_.Receive(&task);
       if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }
       CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);
-      while (true) {
-        const size_t start = task->counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
-        if (start >= task->total) { break; }
-        const size_t next_start = start + kParallelForStride;
-        const size_t end = std::min(next_start, task->total);
-        (*task->for_range)(&engine_, start, end);
-      }
-      engine_.WaitUntilDone();
-      task->bc.Decrease();
+      task(&engine_);
     }
   }
-  Channel<ParallelForTask<Engine>*> tasks_;
+  Channel<IoTask<Engine>> tasks_;
   Engine engine_;
   std::thread thread_;
 };
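
With tasks generalized to closures, Worker::PullTask collapses to "receive a closure, invoke it": the stride-claiming loop and the engine_.WaitUntilDone() call move out of the worker and into the closures built by callers (see ParallelFor below). A hedged sketch of the same loop, with a mutex-guarded queue standing in for oneflow's Channel; WorkerSketch, NullEngine, and the queue semantics are assumptions, not the real Channel API:

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

struct NullEngine {};  // placeholder engine for the sketch

// Same shape as the refactored Worker, with a mutex-guarded queue in place
// of Channel<IoTask<Engine>> (assumption about Channel semantics: Receive
// blocks until a task arrives or the channel is closed).
template<typename Engine>
class WorkerSketch {
 public:
  WorkerSketch() : thread_([this] { PullTask(); }) {}
  ~WorkerSketch() {
    Shutdown();
    thread_.join();
  }

  void Schedule(std::function<void(Engine*)> task) {
    std::lock_guard<std::mutex> lock(mu_);
    tasks_.push(std::move(task));
    cv_.notify_one();
  }

  void Shutdown() {
    std::lock_guard<std::mutex> lock(mu_);
    closed_ = true;
    cv_.notify_one();
  }

 private:
  void PullTask() {
    while (true) {
      std::function<void(Engine*)> task;
      {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return closed_ || !tasks_.empty(); });
        if (tasks_.empty()) { return; }  // closed and fully drained
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      // The whole closure runs here; slicing the work and draining the
      // engine are now the closure's job, not the worker's.
      task(&engine_);
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  bool closed_ = false;
  std::queue<std::function<void(Engine*)>> tasks_;
  Engine engine_;
  std::thread thread_;
};

int main() {
  WorkerSketch<NullEngine> worker;
  worker.Schedule([](NullEngine*) { std::puts("task ran on the worker thread"); });
  return 0;  // ~WorkerSketch closes the queue and joins the thread
}

One consequence visible in the diff: since the worker no longer calls engine_.WaitUntilDone() after each task, any scheduled closure that issues asynchronous I/O is now responsible for waiting on the engine itself before signalling completion.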
@@ -538,45 +521,50 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
   physical_table_size_ += num_padded_keys;
   CHECK_EQ(start_index % num_values_per_block_, 0);
   const uint64_t start_block_id = start_index / num_values_per_block_;
-  for (uint64_t i = 0; i < num_keys; ++i) {
-    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
-  }
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
-  while (written_blocks < num_blocks) {
-    const uint64_t batch_start_block_id = start_block_id + written_blocks;
-    const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
-    if (batch_chunk_id == value_files_.size()) {
-      value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
-    } else {
-      CHECK_LE(batch_chunk_id, value_files_.size());
-    }
-    if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
-      writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
-    }
-    PosixFile& value_file = value_files_.at(batch_chunk_id);
-    const uint64_t block_id_in_chunk =
-        batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
-    const uint64_t blocks_to_write =
-        std::min(num_blocks - written_blocks,
-                 (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
-    const uint64_t values_bytes = blocks_to_write * logical_block_size_;
-    const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
-    CHECK_LE(value_file.Size(), values_offset_in_file);
-    value_file.Truncate(values_offset_in_file + values_bytes);
-    PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
-                  values_bytes, values_offset_in_file)
-           == values_bytes);
-    const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
-    writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
-    const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
-                                         blocks_to_write * num_values_per_block_)
-                                * sizeof(Key);
-    PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
-                  keys_bytes, keys_offset_in_file)
-           == keys_bytes);
-    written_blocks += blocks_to_write;
-  }
+  BlockingCounter bc(1);
+  workers_.at(0)->Schedule([&](Engine*) {
+    while (written_blocks < num_blocks) {
+      const uint64_t batch_start_block_id = start_block_id + written_blocks;
+      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
+      if (batch_chunk_id == value_files_.size()) {
+        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
+      } else {
+        CHECK_LE(batch_chunk_id, value_files_.size());
+      }
+      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
+        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+      }
+      PosixFile& value_file = value_files_.at(batch_chunk_id);
+      const uint64_t block_id_in_chunk =
+          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
+      const uint64_t blocks_to_write =
+          std::min(num_blocks - written_blocks,
+                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
+      const uint64_t values_bytes = blocks_to_write * logical_block_size_;
+      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
+      CHECK_LE(value_file.Size(), values_offset_in_file);
+      value_file.Truncate(values_offset_in_file + values_bytes);
+      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
+                    values_bytes, values_offset_in_file)
+             == values_bytes);
+      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
+      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
+      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
+                                           blocks_to_write * num_values_per_block_)
+                                  * sizeof(Key);
+      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
+                    keys_bytes, keys_offset_in_file)
+             == keys_bytes);
+      written_blocks += blocks_to_write;
+    }
+    bc.Decrease();
+  });
+  for (uint64_t i = 0; i < num_keys; ++i) {
+    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Key, typename Engine>
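
This hunk is the heart of the PR. The chunked pwrite() loop moves, essentially verbatim, into an IoTask scheduled on worker 0, and while it runs the calling thread fills row_id_mapping_; a BlockingCounter initialized to 1 then makes PutBlocks wait for the write task, so the method keeps its synchronous contract while the hash-map inserts and the file I/O overlap. A minimal sketch of that overlap pattern, with SimpleLatch standing in for oneflow's BlockingCounter and a plain std::thread for the worker (all names here are illustrative):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Stand-in for oneflow's BlockingCounter (assumption: same Decrease/Wait shape).
class SimpleLatch {
 public:
  explicit SimpleLatch(int count) : count_(count) {}
  void Decrease() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) { cv_.notify_all(); }
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  SimpleLatch latch(1);  // plays the role of the PR's BlockingCounter bc(1)
  // Stand-in for workers_.at(0)->Schedule(...): file writes run off-thread.
  std::thread io_worker([&latch] {
    std::cout << "worker: chunked pwrite() of key/value blocks\n";
    latch.Decrease();  // the task itself signals completion, as in the PR
  });
  // Meanwhile the caller does the CPU-side work (filling row_id_mapping_).
  std::cout << "caller: inserting keys into row_id_mapping_\n";
  latch.Wait();  // PutBlocks still returns only after the writes finish
  io_worker.join();
  return 0;
}

Note that the scheduled lambda captures everything by reference, including written_blocks and the file handles; that is safe only because bc.WaitForeverUntilCntEqualZero() keeps the enclosing stack frame alive until the task has run.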
@@ -747,9 +735,22 @@ void PersistentTableImpl<Key, Engine>::SaveSnapshot(const std::string& name) {
 template<typename Key, typename Engine>
 void PersistentTableImpl<Key, Engine>::ParallelFor(size_t total,
                                                    const ForRange<Engine>& for_range) {
-  ParallelForTask<Engine> task(workers_.size(), total, &for_range);
-  for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule(&task); }
-  task.bc.WaitForeverUntilCntEqualZero();
+  BlockingCounter bc(workers_.size());
+  std::atomic<size_t> counter(0);
+  for (size_t i = 0; i < workers_.size(); ++i) {
+    workers_.at(i)->Schedule([&](Engine* engine) {
+      while (true) {
+        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
+        if (start >= total) { break; }
+        const size_t next_start = start + kParallelForStride;
+        const size_t end = std::min(next_start, total);
+        for_range(engine, start, end);
+      }
+      engine->WaitUntilDone();
+      bc.Decrease();
+    });
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Engine>
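
ParallelFor now builds the slicing logic inline: one shared std::atomic<size_t> hands out disjoint [start, end) strides to the workers through fetch_add, each worker drains its engine with WaitUntilDone(), and a BlockingCounter sized to the worker count joins them all. (The old ParallelForTask padded its counter to a cache line with union alignas(kCacheLineSize); the new stack-local counter drops that.) The same claiming scheme in a self-contained sketch; the four plain threads and the stride value are illustrative, not the PR's worker pool:

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdio>
#include <thread>
#include <vector>

constexpr size_t kParallelForStride = 256;  // illustrative; the real constant is defined elsewhere in the file

int main() {
  const size_t total = 10000;
  const std::vector<int> data(total, 1);
  std::atomic<size_t> counter(0);  // the shared claim counter from the hunk above
  std::atomic<long long> sum(0);
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&] {
      long long local = 0;
      while (true) {
        // fetch_add hands out disjoint strides, so workers never overlap;
        // relaxed ordering suffices because the claim itself is atomic and
        // completion is synchronized by the joins below.
        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
        if (start >= total) { break; }
        const size_t end = std::min(start + kParallelForStride, total);
        for (size_t i = start; i < end; ++i) { local += data[i]; }
      }
      sum.fetch_add(local, std::memory_order_relaxed);
    });
  }
  for (auto& w : workers) { w.join(); }  // plays the role of bc.WaitForeverUntilCntEqualZero()
  std::printf("sum = %lld\n", sum.load());  // prints sum = 10000
  return 0;
}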