diff --git a/oneflow/core/embedding/persistent_table.cpp b/oneflow/core/embedding/persistent_table.cpp
index dd5a9ea39a6..091c5a7b01b 100644
--- a/oneflow/core/embedding/persistent_table.cpp
+++ b/oneflow/core/embedding/persistent_table.cpp
@@ -298,19 +298,10 @@ class AioEngine final {
 constexpr size_t kCacheLineSize = 64;
 
 template<typename Engine>
-using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
+using IoTask = std::function<void(Engine* engine)>;
 
 template<typename Engine>
-struct ParallelForTask {
-  ParallelForTask(size_t num_workers, size_t total, const ForRange<Engine>* for_range)
-      : counter(0), total(total), for_range(for_range), bc(num_workers) {}
-  union alignas(kCacheLineSize) {
-    std::atomic<size_t> counter;
-  };
-  size_t total;
-  const ForRange<Engine>* for_range;
-  BlockingCounter bc;
-};
+using ForRange = std::function<void(Engine* engine, size_t start, size_t end)>;
 
 template<typename Engine>
 class Worker final {
@@ -322,29 +313,21 @@ class Worker final {
     thread_.join();
   }
 
-  void Schedule(ParallelForTask<Engine>* task) { tasks_.Send(task); }
+  void Schedule(IoTask<Engine> task) { tasks_.Send(std::move(task)); }
 
   void Shutdown() { tasks_.Close(); }
 
  private:
   void PullTask() {
     while (true) {
-      ParallelForTask<Engine>* task = nullptr;
+      IoTask<Engine> task;
       const ChannelStatus status = tasks_.Receive(&task);
       if (status == ChannelStatus::kChannelStatusErrorClosed) { break; }
       CHECK_EQ(status, ChannelStatus::kChannelStatusSuccess);
-      while (true) {
-        const size_t start = task->counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
-        if (start >= task->total) { break; }
-        const size_t next_start = start + kParallelForStride;
-        const size_t end = std::min(next_start, task->total);
-        (*task->for_range)(&engine_, start, end);
-      }
-      engine_.WaitUntilDone();
-      task->bc.Decrease();
+      task(&engine_);
     }
   }
-  Channel<ParallelForTask<Engine>*> tasks_;
+  Channel<IoTask<Engine>> tasks_;
   Engine engine_;
   std::thread thread_;
 };
@@ -538,45 +521,52 @@ void PersistentTableImpl<Key, Engine>::PutBlocks(uint32_t num_keys, const void*
   physical_table_size_ += num_padded_keys;
   CHECK_EQ(start_index % num_values_per_block_, 0);
   const uint64_t start_block_id = start_index / num_values_per_block_;
-  for (uint64_t i = 0; i < num_keys; ++i) {
-    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
-  }
   uint64_t written_blocks = 0;
   const uint64_t block_keys_size = num_values_per_block_ * sizeof(Key);
-  while (written_blocks < num_blocks) {
-    const uint64_t batch_start_block_id = start_block_id + written_blocks;
-    const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
-    if (batch_chunk_id == value_files_.size()) {
-      value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
-    } else {
-      CHECK_LE(batch_chunk_id, value_files_.size());
-    }
-    if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
-      writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
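+  // Offload the file writes to a worker so row_id_mapping_ can be updated concurrently.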
+  BlockingCounter bc(1);
+  workers_.at(0)->Schedule([&](Engine*) {
+    while (written_blocks < num_blocks) {
+      const uint64_t batch_start_block_id = start_block_id + written_blocks;
+      const uint64_t batch_chunk_id = batch_start_block_id / num_logical_blocks_per_chunk_;
+      if (batch_chunk_id == value_files_.size()) {
+        value_files_.emplace_back(ValueFilePath(batch_chunk_id), O_CREAT | O_RDWR | O_DIRECT, 0644);
+      } else {
+        CHECK_LE(batch_chunk_id, value_files_.size());
+      }
+      if ((!writable_key_file_.IsOpen()) || writable_key_file_chunk_id_ != batch_chunk_id) {
+        writable_key_file_ = PosixFile(KeyFilePath(batch_chunk_id), O_CREAT | O_RDWR, 0644);
+      }
+      PosixFile& value_file = value_files_.at(batch_chunk_id);
+      const uint64_t block_id_in_chunk =
+          batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
+      const uint64_t blocks_to_write =
+          std::min(num_blocks - written_blocks,
+                   (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
+      const uint64_t values_bytes = blocks_to_write * logical_block_size_;
+      const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
+      CHECK_LE(value_file.Size(), values_offset_in_file);
+      value_file.Truncate(values_offset_in_file + values_bytes);
+      PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
+                    values_bytes, values_offset_in_file)
+             == values_bytes);
+      const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
+      writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
+      const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
+                                           blocks_to_write * num_values_per_block_)
+                                  * sizeof(Key);
+      PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
+                    keys_bytes, keys_offset_in_file)
+             == keys_bytes);
+      written_blocks += blocks_to_write;
     }
-    PosixFile& value_file = value_files_.at(batch_chunk_id);
-    const uint64_t block_id_in_chunk =
-        batch_start_block_id - batch_chunk_id * num_logical_blocks_per_chunk_;
-    const uint64_t blocks_to_write =
-        std::min(num_blocks - written_blocks,
-                 (batch_chunk_id + 1) * num_logical_blocks_per_chunk_ - batch_start_block_id);
-    const uint64_t values_bytes = blocks_to_write * logical_block_size_;
-    const uint64_t values_offset_in_file = block_id_in_chunk * logical_block_size_;
-    CHECK_LE(value_file.Size(), values_offset_in_file);
-    value_file.Truncate(values_offset_in_file + values_bytes);
-    PCHECK(pwrite(value_file.fd(), BytesOffset(blocks, written_blocks * logical_block_size_),
-                  values_bytes, values_offset_in_file)
-           == values_bytes);
-    const uint64_t keys_offset_in_file = block_id_in_chunk * block_keys_size;
-    writable_key_file_.Truncate(keys_offset_in_file + blocks_to_write * block_keys_size);
-    const uint64_t keys_bytes = std::min(num_keys - written_blocks * num_values_per_block_,
-                                         blocks_to_write * num_values_per_block_)
-                                * sizeof(Key);
-    PCHECK(pwrite(writable_key_file_.fd(), BytesOffset(keys, written_blocks * block_keys_size),
-                  keys_bytes, keys_offset_in_file)
-           == keys_bytes);
-    written_blocks += blocks_to_write;
+    bc.Decrease();
+  });
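+  // Build the key -> row id mapping on the calling thread while the worker writes to disk.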
+  for (uint64_t i = 0; i < num_keys; ++i) {
+    row_id_mapping_[static_cast<const Key*>(keys)[i]] = start_index + i;
   }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Key, typename Engine>
@@ -747,9 +737,24 @@ void PersistentTableImpl<Key, Engine>::SaveSnapshot(const std::string& name) {
 template<typename Key, typename Engine>
 void PersistentTableImpl<Key, Engine>::ParallelFor(size_t total,
                                                    const ForRange<Engine>& for_range) {
-  ParallelForTask<Engine> task(workers_.size(), total, &for_range);
-  for (size_t i = 0; i < workers_.size(); ++i) { workers_.at(i)->Schedule(&task); }
-  task.bc.WaitForeverUntilCntEqualZero();
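+  // Workers repeatedly claim strided ranges from a shared counter until [0, total) is covered.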
+  BlockingCounter bc(workers_.size());
+  std::atomic<size_t> counter(0);
+  for (size_t i = 0; i < workers_.size(); ++i) {
+    workers_.at(i)->Schedule([&](Engine* engine) {
+      while (true) {
+        const size_t start = counter.fetch_add(kParallelForStride, std::memory_order_relaxed);
+        if (start >= total) { break; }
+        const size_t next_start = start + kParallelForStride;
+        const size_t end = std::min(next_start, total);
+        for_range(engine, start, end);
+      }
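+      // Wait for I/O submitted by for_range on this engine before signaling completion.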
+      engine->WaitUntilDone();
+      bc.Decrease();
+    });
+  }
+  bc.WaitForeverUntilCntEqualZero();
 }
 
 template<typename Engine>