diff --git a/oneflow/core/embedding/persistent_table.cpp b/oneflow/core/embedding/persistent_table.cpp
index dd5a9ea39a6..091c5a7b01b 100644
--- a/oneflow/core/embedding/persistent_table.cpp
+++ b/oneflow/core/embedding/persistent_table.cpp
@@ -298,19 +298,10 @@ class AioEngine final {
 constexpr size_t kCacheLineSize = 64;
 
 template<typename Engine>
-using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
+using IoTask = std::function<void(Engine* engine)>;
 
 template<typename Engine>
-struct ParallelForTask {
-  ParallelForTask(size_t num_workers, size_t total, const ForRange<Engine>* for_range)
-      : counter(0), total(total), for_range(for_range), bc(num_workers) {}
-  union alignas(kCacheLineSize) {
-    std::atomic<size_t> counter;
-  };
-  size_t total;
-  const ForRange<Engine>* for_range;
-  BlockingCounter bc;
-};
+using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
 
 template<typename Engine>
 class Worker final {
@@ -322,29 +313,21 @@ class Worker final {
     thread_.join();
   }
 
-  void Schedule(ParallelForTask<Engine>* task) { tasks_.Send(task); }
+  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }
 
   void Shutdown() { tasks_.Close(); }
 
  private:
  void PullTask() {
     while (true) {
-      ParallelForTask<Engine>* task = nullptr;
+      IoTask<Engine> task;
       const ChannelStatus status = tasks_.Receive(&task);
       if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }
       CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);
-      while (true) {
-        const size_t start = task->counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
-        if (start >= task->total) { break; }
-        const size_t next_start = start + kParallelForStride;
-        const size_t end = std::min(next_start, task->total);
-        (*task->for_range)(&engine_, start, end);
-      }
-      engine_.WaitUntilDone();
-      task->bc.Decrease();
+      task(&engine_);
     }
   }
-  Channel<ParallelForTask<Engine>*> tasks_;
+  Channel<IoTask<Engine>> tasks_;
   Engine engine_;
   std::thread thread_;
 };
@@ -538,45 +521,50 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
   physical_table_size_ += num_padded_keys;
   CHECK_EQ(start_index % num_values_per_block_, 0);
   const uint64_t start_block_id = start_index / num_values_per_block_;
-  for (uint64_t i = 0; i < num_keys; ++i) {
-    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
-  }
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
-  while (written_blocks < num_blocks) {
-    const uint64_t batch_start_block_id = start_block_id + written_blocks;
-    const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
-    if (batch_chunk_id == value_files_.size()) {
-      value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
-    } else {
-      CHECK_LE(batch_chunk_id, value_files_.size());
-    }
-    if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
-      writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+  BlockingCounter bc(1);
+  workers_.at(0)->Schedule([&](Engine*) {
+    while (written_blocks < num_blocks) {
+      const uint64_t batch_start_block_id = start_block_id + written_blocks;
+      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
+      if (batch_chunk_id == value_files_.size()) {
+        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
+      } else {
+        CHECK_LE(batch_chunk_id, value_files_.size());
+      }
+      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
+        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+      }
+      PosixFile& value_file = value_files_.at(batch_chunk_id);
+      const uint64_t block_id_in_chunk =
+          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
+      const uint64_t blocks_to_write =
+          std::min(num_blocks - written_blocks,
+                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
+      const uint64_t values_bytes = blocks_to_write * logical_block_size_;
+      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
+      CHECK_LE(value_file.Size(), values_offset_in_file);
+      value_file.Truncate(values_offset_in_file + values_bytes);
+      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
+                    values_bytes, values_offset_in_file)
+             == values_bytes);
+      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
+      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
+      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
+                                           blocks_to_write * num_values_per_block_)
+                                  * sizeof(Key);
+      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
+                    keys_bytes, keys_offset_in_file)
+             == keys_bytes);
+      written_blocks += blocks_to_write;
     }
-    PosixFile& value_file = value_files_.at(batch_chunk_id);
-    const uint64_t block_id_in_chunk =
-        batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
-    const uint64_t blocks_to_write =
-        std::min(num_blocks - written_blocks,
-                 (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
-    const uint64_t values_bytes = blocks_to_write * logical_block_size_;
-    const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
-    CHECK_LE(value_file.Size(), values_offset_in_file);
-    value_file.Truncate(values_offset_in_file + values_bytes);
-    PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
-                  values_bytes, values_offset_in_file)
-           == values_bytes);
-    const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
-    writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
-    const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
-                                         blocks_to_write * num_values_per_block_)
-                                * sizeof(Key);
-    PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
-                  keys_bytes, keys_offset_in_file)
-           == keys_bytes);
-    written_blocks += blocks_to_write;
+    bc.Decrease();
+  });
+  for (uint64_t i = 0; i < num_keys; ++i) {
+    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
   }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Key, typename Engine>
@@ -747,9 +735,22 @@ void PersistentTableImpl<Key, Engine>::SaveSnapshot(const std::string& name) {
 template<typename Key, typename Engine>
 void PersistentTableImpl<Key, Engine>::ParallelFor(size_t total,
                                                    const ForRange<Engine>& for_range) {
-  ParallelForTask<Engine> task(workers_.size(), total, &for_range);
-  for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule(&task); }
-  task.bc.WaitForeverUntilCntEqualZero();
+  BlockingCounter bc(workers_.size());
+  std::atomic<size_t> counter(0);
+  for (size_t i = 0; i < workers_.size(); ++i) {
+    workers_.at(i)->Schedule([&](Engine* engine) {
+      while (true) {
+        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
+        if (start >= total) { break; }
+        const size_t next_start = start + kParallelForStride;
+        const size_t end = std::min(next_start, total);
+        for_range(engine, start, end);
+      }
+      engine->WaitUntilDone();
+      bc.Decrease();
+    });
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Engine>
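In the new PutBlocks, the pwrite loop runs as an IoTask on workers_.at(0) while the calling thread fills row_id_mapping_ concurrently; the BlockingCounter(1) acts as a one-shot latch that the worker releases after the last block is written. A minimal sketch of that shape, with a plain std::thread standing in for the Worker, a hand-rolled latch standing in for oneflow's BlockingCounter, and hypothetical names (PutBlocksShape, write_blocks_to_files) for the parts elided here:

```cpp
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

// Hand-rolled latch with the same interface as oneflow's BlockingCounter.
class BlockingCounter {
 public:
  explicit BlockingCounter(int64_t cnt) : cnt_(cnt) {}
  void Decrease() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (--cnt_ == 0) { cv_.notify_all(); }
  }
  void WaitForeverUntilCntEqualZero() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return cnt_ == 0; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  int64_t cnt_;
};

// Shape of the new PutBlocks: file writes run on a worker thread while the
// caller fills the in-memory index; the caller then blocks on the latch.
void PutBlocksShape(const std::vector<uint64_t>& keys, uint64_t start_index,
                    std::unordered_map<uint64_t, uint64_t>* row_id_mapping,
                    const std::function<void()>& write_blocks_to_files) {
  BlockingCounter bc(1);
  std::thread worker([&] {
    write_blocks_to_files();  // the pwrite loop in the real code
    bc.Decrease();            // one-shot: releases the waiting caller
  });
  // Runs concurrently with the worker; the two touch disjoint state.
  for (size_t i = 0; i < keys.size(); ++i) { (*row_id_mapping)[keys[i]] = start_index + i; }
  bc.WaitForeverUntilCntEqualZero();
  worker.join();
}
```

The overlap is safe without extra locking because the caller-side loop touches only row_id_mapping_ while the scheduled task touches only the files and the write cursor.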
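The rewritten ParallelFor keeps the previous work-claiming scheme: every worker advances a shared atomic cursor in kParallelForStride-sized steps and processes the half-open range it claimed, so slower workers simply claim fewer ranges. What changes is that the loop now travels to the workers as a capturing lambda rather than a shared ParallelForTask object (whose cache-line-aligned counter union presumably existed to avoid false sharing). A sketch of the same dynamic partitioning with plain threads in place of the Worker/Engine pair, where joining the threads plays the role of BlockingCounter and Engine::WaitUntilDone:

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

constexpr size_t kParallelForStride = 256;  // stand-in value; the real constant lives in the file

// Each worker claims the half-open range [start, end) by bumping the shared
// cursor, processes it, and loops until the cursor passes `total`.
void ParallelFor(size_t total, size_t num_workers,
                 const std::function<void(size_t start, size_t end)>& for_range) {
  std::atomic<size_t> counter(0);
  std::vector<std::thread> workers;
  for (size_t i = 0; i < num_workers; ++i) {
    workers.emplace_back([&] {
      while (true) {
        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
        if (start >= total) { break; }
        const size_t end = std::min(start + kParallelForStride, total);
        for_range(start, end);
      }
    });
  }
  for (std::thread& t : workers) { t.join(); }  // stands in for bc.WaitForeverUntilCntEqualZero()
}
```

fetch_add with memory_order_relaxed suffices here because claiming a range needs only atomicity; the join (the BlockingCounter in the real code) provides the final synchronization with the caller.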