Refactor: Use Executor instead of std::thread
mgevor committed Jul 25, 2023
1 parent ed3f845 commit c3a3693
Showing 1 changed file with 38 additions and 39 deletions.
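The commit swaps the hand-rolled std::thread fan-out inside view() for a pluggable executor_at template parameter defaulting to dummy_executor_t. The contract implied by the diff is a single execute_bulk(tasks_count, callback) member that invokes callback(thread_idx, task_idx) once per task, possibly concurrently. Below is a minimal sketch of that contract, assuming only what the diff shows; serial_executor_t and pooled_executor_t are illustrative names, not executors shipped with usearch:

#include <cstddef>
#include <thread>
#include <vector>

// Serial fallback mirroring the role of `dummy_executor_t`:
// runs every task on the calling thread, reporting itself as thread 0.
struct serial_executor_t {
    template <typename callback_at>
    void execute_bulk(std::size_t tasks_count, callback_at&& callback) {
        for (std::size_t task_idx = 0; task_idx != tasks_count; ++task_idx)
            callback(0, task_idx);
    }
};

// Hypothetical pooled variant: splits tasks across N workers in strides,
// passing each worker its `thread_idx`, so callers can throttle progress
// reporting to a single thread, as the new `view` does.
struct pooled_executor_t {
    std::size_t threads_count = std::thread::hardware_concurrency();

    template <typename callback_at>
    void execute_bulk(std::size_t tasks_count, callback_at&& callback) {
        std::size_t n = threads_count ? threads_count : 1;
        std::vector<std::thread> workers;
        for (std::size_t thread_idx = 0; thread_idx != n; ++thread_idx)
            workers.emplace_back([&callback, tasks_count, n, thread_idx] {
                // Thread k handles tasks k, k+n, k+2n, ... so the split
                // needs no remainder bookkeeping, unlike the removed code.
                for (std::size_t task_idx = thread_idx; task_idx < tasks_count; task_idx += n)
                    callback(thread_idx, task_idx);
            });
        for (std::thread& worker : workers)
            worker.join();
    }
};

Besides removing the per-thread chunk arithmetic, this design moves the threading policy out of the index entirely: the same view() body now runs serially or in parallel depending on the executor the caller injects.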
77 changes: 38 additions & 39 deletions include/usearch/index.hpp
@@ -2386,6 +2386,8 @@ class index_gt {
template <typename progress_at = dummy_progress_t>
serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept {

+ using file_offset_t = typename node_offsets_t::offset_t;

// Make sure we have the right to write to that file
serialization_result_t result;
std::FILE* file = std::fopen(file_path, "wb");
@@ -2443,7 +2445,7 @@
return result;

// Calculate total node sizes
- typename node_offsets_t::offset_t total_node_bytes = 0;
+ file_offset_t total_node_bytes = 0;
for (std::size_t i = 0; i != state.size; ++i) {
node_t node = node_with_id_(i);
total_node_bytes += node_head_bytes_() + node_neighbors_bytes_(node);
@@ -2452,13 +2454,12 @@
// Firstly, calculate and serialize node offsets
node_offsets_buffer_t offsets_buffer{};
node_offsets_t offsets{offsets_buffer};
- typename node_offsets_t::offset_t head_offset = 0;
+ file_offset_t head_offset = 0;
// Align vectors offset
- typename node_offsets_t::offset_t vectors_base_offset =
+ file_offset_t vectors_base_offset =
sizeof(file_header_t) + size_ * sizeof(node_offsets_buffer_t) + total_node_bytes;
- typename node_offsets_t::offset_t vectors_offset_shift =
- config_.vector_alignment - vectors_base_offset % config_.vector_alignment;
- typename node_offsets_t::offset_t vector_offset = total_node_bytes + vectors_offset_shift;
+ file_offset_t vectors_offset_shift = config_.vector_alignment - vectors_base_offset % config_.vector_alignment;
+ file_offset_t vector_offset = total_node_bytes + vectors_offset_shift;
for (std::size_t i = 0; i != state.size; ++i) {
offsets.head = head_offset;
offsets.vector = vector_offset;
@@ -2501,6 +2502,8 @@
template <typename progress_at = dummy_progress_t>
serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept {

+ using file_offset_t = typename node_offsets_t::offset_t;

// Remove previously stored objects
reset();

@@ -2579,7 +2582,7 @@
}

// Then, load vectors from aligned address
- typename node_offsets_t::offset_t offset = std::ftell(file);
+ file_offset_t offset = std::ftell(file);
std::fseek(file, config_.vector_alignment - offset % config_.vector_alignment, SEEK_CUR);
for (std::size_t i = 0; i != size_; ++i) {
read_chunk(nodes_[i].vector(), node_vector_bytes_(nodes_[i].dim()));
@@ -2597,9 +2600,21 @@
* @brief Memory-maps the serialized binary index representation from disk,
* @b without copying the vectors and neighbors lists into RAM.
* Available on Linux, MacOS, but @b not on Windows.
+ *
+ * @param[in] file_path File where the index is saved.
+ * @param[in] executor Thread-pool to execute the job in parallel.
+ * @param[in] progress Callback to report the execution progress.
*/
- template <typename progress_at = dummy_progress_t>
- serialization_result_t view(char const* file_path, progress_at&& progress = {}) noexcept {
+ template <                                    //
+     typename executor_at = dummy_executor_t,  //
+     typename progress_at = dummy_progress_t   //
+     >
+ serialization_result_t view(                  //
+     char const* file_path,                    //
+     executor_at&& executor = executor_at{},   //
+     progress_at&& progress = {}) noexcept {

+ using file_offset_t = typename node_offsets_t::offset_t;

// Remove previously stored objects
reset();
Expand All @@ -2614,7 +2629,7 @@ class index_gt {
if (file_handle == INVALID_HANDLE_VALUE)
return result.failed("Opening file failed!");

- typename node_offsets_t::offset_t file_length = GetFileSize(file_handle, 0);
+ file_offset_t file_length = GetFileSize(file_handle, 0);
HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0);
if (mapping_handle == 0) {
CloseHandle(file_handle);
@@ -2685,36 +2700,20 @@
entry_id_ = static_cast<id_t>(state.entry_idx);
}

- // Read nodes and vectors
- // Divide tasks between threads
- std::size_t threads_count = (std::min)(limits_.threads(), size_ / 1'000); // Use optimal thread count
- std::size_t thread_tasks_count = size_ / (threads_count + 1); // + main thread
- std::size_t main_thread_tasks_count = thread_tasks_count + (threads_count ? size_ % (threads_count + 1) : 0);
-
- // Task
- typename node_offsets_t::offset_t base_offset = sizeof(file_header_t);
- typename node_offsets_t::offset_t nodes_base_offset = base_offset + size_ * sizeof(node_offsets_buffer_t);
- auto task = [&](std::size_t start, std::size_t count) {
-     for (std::size_t i = start; i != start + count; ++i) {
-         node_offsets_t offsets{file + base_offset + i * sizeof(node_offsets_buffer_t)};
-         byte_t* tape = file + nodes_base_offset + offsets.head;
-         byte_t* vector = file + nodes_base_offset + offsets.vector;
-         nodes_[i] = node_t{tape, (scalar_t*)vector};
-     }
+ // Concurrently locate all nodes and vectors
+ std::atomic<std::uint64_t> done_tasks{0};
+ file_offset_t base_offset = sizeof(file_header_t);
+ file_offset_t nodes_base_offset = base_offset + size_ * sizeof(node_offsets_buffer_t);
+ auto task = [&](std::size_t thread_idx, std::size_t task_idx) {
+     node_offsets_t offsets{file + base_offset + task_idx * sizeof(node_offsets_buffer_t)};
+     byte_t* tape = file + nodes_base_offset + offsets.head;
+     byte_t* vector = file + nodes_base_offset + offsets.vector;
+     nodes_[task_idx] = node_t{tape, (scalar_t*)vector};
+     ++done_tasks;
+     if (thread_idx == 0)
+         progress(done_tasks, size_);
};

- // Run threads
- std::vector<std::thread> threads;
- for (std::size_t i = 0; i < threads_count; ++i)
-     threads.push_back(std::thread(task, i * thread_tasks_count, thread_tasks_count));
- task(threads_count * thread_tasks_count, main_thread_tasks_count);
-
- // Wait to finish
- progress(main_thread_tasks_count, size_);
- for (std::size_t i = 0; i < threads_count; ++i) {
-     threads[i].join();
-     progress(main_thread_tasks_count + i * threads_count, size_);
- }
+ executor.execute_bulk(size_, task);

return {};
}
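With the extended signature, a call site can supply its own thread-pool and progress callback. A hedged usage sketch follows, reusing the hypothetical pooled_executor_t from the sketch above; the progress lambda's signature mirrors the progress(done_tasks, size_) call in the new code, and the defaults shown in the diff mean existing single-argument callers keep compiling:

#include <cstdio>

// Hypothetical call site; `index` is an already-constructed index_gt instance.
// Omitting both trailing arguments falls back to `dummy_executor_t` and
// `dummy_progress_t`, preserving the old single-argument behavior.
serialization_result_t result = index.view(
    "index.usearch", pooled_executor_t{},
    [](std::uint64_t done, std::size_t total) {
        // Invoked from thread 0 only, per the `task` lambda above.
        std::printf("\rviewed %llu of %zu nodes", (unsigned long long)done, total);
    });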