[PersistentTable] Async write (#7946)
* [PersistentTable] Async write

* fix

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2 people authored and xiacijie committed Apr 24, 2022
1 parent 54d7b6b commit a77403d
Showing 1 changed file with 62 additions and 61 deletions.
123 changes: 62 additions & 61 deletions oneflow/core/embedding/persistent_table.cpp
@@ -298,19 +298,10 @@ class AioEngine final {
 constexpr size_t kCacheLineSize = 64;

 template<typename Engine>
-using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
+using IoTask = std::function<void(Engine* engine)>;

 template<typename Engine>
-struct ParallelForTask {
-  ParallelForTask(size_t num_workers, size_t total, const ForRange<Engine>* for_range)
-      : counter(0), total(total), for_range(for_range), bc(num_workers) {}
-  union alignas(kCacheLineSize) {
-    std::atomic<size_t> counter;
-  };
-  size_t total;
-  const ForRange<Engine>* for_range;
-  BlockingCounter bc;
-};
+using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;

 template<typename Engine>
 class Worker final {
@@ -322,29 +313,21 @@ class Worker final {
     thread_.join();
   }

-  void Schedule(ParallelForTask<Engine>* task) { tasks_.Send(task); }
+  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }

   void Shutdown() { tasks_.Close(); }

  private:
   void PullTask() {
     while (true) {
-      ParallelForTask<Engine>* task = nullptr;
+      IoTask<Engine> task;
       const ChannelStatus status = tasks_.Receive(&task);
       if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }
       CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);
-      while (true) {
-        const size_t start = task->counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
-        if (start >= task->total) { break; }
-        const size_t next_start = start + kParallelForStride;
-        const size_t end = std::min(next_start, task->total);
-        (*task->for_range)(&engine_, start, end);
-      }
-      engine_.WaitUntilDone();
-      task->bc.Decrease();
+      task(&engine_);
     }
   }
-  Channel<ParallelForTask<Engine>*> tasks_;
+  Channel<IoTask<Engine>> tasks_;
   Engine engine_;
   std::thread thread_;
 };
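
For readers outside the OneFlow codebase: after this change, Worker is simply a dedicated thread draining a blocking queue of type-erased callbacks (IoTask), each run against that worker's own Engine. Below is a minimal, self-contained sketch of the same pattern, substituting a std::queue guarded by a mutex and condition variable for OneFlow's Channel; the names and the Engine stand-in are illustrative, not the library's API.

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

struct Engine {};  // stand-in for the per-worker I/O engine (e.g. an AIO context)

using IoTask = std::function<void(Engine*)>;

class Worker {
 public:
  Worker() : thread_([this] { PullTask(); }) {}
  ~Worker() {
    Shutdown();
    thread_.join();
  }

  // Enqueue one unit of I/O work; the worker thread runs tasks in FIFO order.
  void Schedule(IoTask task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

  // Close the queue; PullTask exits once the remaining tasks are drained.
  void Shutdown() {
    std::lock_guard<std::mutex> lock(mu_);
    closed_ = true;
    cv_.notify_one();
  }

 private:
  void PullTask() {
    while (true) {
      IoTask task;
      {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return closed_ || !tasks_.empty(); });
        if (tasks_.empty()) { break; }  // queue closed and fully drained
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task(&engine_);  // run the callback against this worker's engine
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<IoTask> tasks_;
  bool closed_ = false;
  Engine engine_;
  std::thread thread_;  // declared last so it starts after the other members exist
};

The real implementation keeps OneFlow's Channel, which already provides the blocking Receive and the closed status used in PullTask above.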
@@ -538,45 +521,50 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
   physical_table_size_ += num_padded_keys;
   CHECK_EQ(start_index % num_values_per_block_, 0);
   const uint64_t start_block_id = start_index / num_values_per_block_;
-  for (uint64_t i = 0; i < num_keys; ++i) {
-    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
-  }
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
-  while (written_blocks < num_blocks) {
-    const uint64_t batch_start_block_id = start_block_id + written_blocks;
-    const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
-    if (batch_chunk_id == value_files_.size()) {
-      value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
-    } else {
-      CHECK_LE(batch_chunk_id, value_files_.size());
-    }
-    if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
-      writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+  BlockingCounter bc(1);
+  workers_.at(0)->Schedule([&](Engine*) {
+    while (written_blocks < num_blocks) {
+      const uint64_t batch_start_block_id = start_block_id + written_blocks;
+      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
+      if (batch_chunk_id == value_files_.size()) {
+        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
+      } else {
+        CHECK_LE(batch_chunk_id, value_files_.size());
+      }
+      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
+        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+      }
+      PosixFile& value_file = value_files_.at(batch_chunk_id);
+      const uint64_t block_id_in_chunk =
+          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
+      const uint64_t blocks_to_write =
+          std::min(num_blocks - written_blocks,
+                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
+      const uint64_t values_bytes = blocks_to_write * logical_block_size_;
+      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
+      CHECK_LE(value_file.Size(), values_offset_in_file);
+      value_file.Truncate(values_offset_in_file + values_bytes);
+      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
+                    values_bytes, values_offset_in_file)
+             == values_bytes);
+      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
+      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
+      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
+                                           blocks_to_write * num_values_per_block_)
+                                  * sizeof(Key);
+      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
+                    keys_bytes, keys_offset_in_file)
+             == keys_bytes);
+      written_blocks += blocks_to_write;
     }
-    PosixFile& value_file = value_files_.at(batch_chunk_id);
-    const uint64_t block_id_in_chunk =
-        batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
-    const uint64_t blocks_to_write =
-        std::min(num_blocks - written_blocks,
-                 (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
-    const uint64_t values_bytes = blocks_to_write * logical_block_size_;
-    const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
-    CHECK_LE(value_file.Size(), values_offset_in_file);
-    value_file.Truncate(values_offset_in_file + values_bytes);
-    PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
-                  values_bytes, values_offset_in_file)
-           == values_bytes);
-    const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
-    writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
-    const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
-                                         blocks_to_write * num_values_per_block_)
-                                * sizeof(Key);
-    PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
-                  keys_bytes, keys_offset_in_file)
-           == keys_bytes);
-    written_blocks += blocks_to_write;
+    bc.Decrease();
+  });
+  for (uint64_t i = 0; i < num_keys; ++i) {
+    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
   }
+  bc.WaitForeverUntilCntEqualZero();
 }

 template<typename Key, typename Engine>
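
The effect of this hunk: the pwrite loop now runs on workers_.at(0) while the calling thread fills the in-memory row_id_mapping_, and the BlockingCounter(1) keeps PutBlocks blocked until the scheduled write has finished. That final wait is also what makes the by-reference captures in the lambda (written_blocks, keys, blocks, the file members) safe, since the caller's stack frame outlives the task. A rough sketch of this "schedule, overlap, then join" shape, assuming BlockingCounter behaves like a simple countdown latch (the class below is an illustrative stand-in, not OneFlow's implementation):

#include <condition_variable>
#include <cstdint>
#include <mutex>

// Assumed semantics: Decrease() counts down once,
// WaitForeverUntilCntEqualZero() blocks until the count reaches zero.
class BlockingCounter {
 public:
  explicit BlockingCounter(int64_t cnt) : cnt_(cnt) {}
  void Decrease() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--cnt_ == 0) { cv_.notify_all(); }
  }
  void WaitForeverUntilCntEqualZero() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return cnt_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int64_t cnt_;
};

// Shape of the new PutBlocks, reusing the Worker sketch above:
//   BlockingCounter bc(1);
//   worker.Schedule([&](Engine*) {
//     /* pwrite() value and key blocks (worker thread) */
//     bc.Decrease();
//   });
//   /* fill the in-memory row_id_mapping_ (calling thread) */
//   bc.WaitForeverUntilCntEqualZero();  // still synchronous at the call boundary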
@@ -747,9 +735,22 @@ void PersistentTableImpl<Key, Engine>::SaveSnapshot(const std::string& name) {
 template<typename Key, typename Engine>
 void PersistentTableImpl<Key, Engine>::ParallelFor(size_t total,
                                                    const ForRange<Engine>& for_range) {
-  ParallelForTask<Engine> task(workers_.size(), total, &for_range);
-  for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule(&task); }
-  task.bc.WaitForeverUntilCntEqualZero();
+  BlockingCounter bc(workers_.size());
+  std::atomic<size_t> counter(0);
+  for (size_t i = 0; i < workers_.size(); ++i) {
+    workers_.at(i)->Schedule([&](Engine* engine) {
+      while (true) {
+        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
+        if (start >= total) { break; }
+        const size_t next_start = start + kParallelForStride;
+        const size_t end = std::min(next_start, total);
+        for_range(engine, start, end);
+      }
+      engine->WaitUntilDone();
+      bc.Decrease();
+    });
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }

 template<typename Engine>
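
ParallelFor keeps its old semantics but builds the per-worker loop inline as a lambda: each worker claims the next kParallelForStride indices from a shared atomic counter (so faster workers naturally take more chunks), runs for_range on each [start, end) slice, flushes its engine with WaitUntilDone(), and counts down the BlockingCounter. A standalone sketch of the same dynamic-striding idea using plain std::thread, with illustrative parameter names; the relaxed ordering suffices because the counter only partitions indices, and the joins here play the synchronization role the BlockingCounter plays in the real code:

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

void ParallelForSketch(size_t total, size_t stride, size_t num_threads,
                       const std::function<void(size_t, size_t)>& for_range) {
  std::atomic<size_t> counter(0);
  std::vector<std::thread> threads;
  threads.reserve(num_threads);
  for (size_t i = 0; i < num_threads; ++i) {
    threads.emplace_back([&] {
      while (true) {
        // Claim the next chunk of indices; relaxed is fine, the counter only partitions work.
        const size_t start = counter.fetch_add(stride, std::memory_order_relaxed);
        if (start >= total) { break; }
        const size_t end = std::min(start + stride, total);
        for_range(start, end);  // process one claimed slice
      }
    });
  }
  // join() stands in for BlockingCounter::WaitForeverUntilCntEqualZero().
  for (auto& t : threads) { t.join(); }
}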