Skip to content

Commit

Permalink
Support A Multiple Producer, Single Consumer Wait-Free Queue
Browse files Browse the repository at this point in the history
  • Loading branch information
chenBright committed Jan 3, 2024
1 parent 023fa14 commit dd1d2cb
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 3 deletions.
189 changes: 189 additions & 0 deletions src/butil/containers/mpsc_queue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// A Multiple Producer, Single Consumer Wait-Free Queue.
// It allows multiple threads to enqueue, and allows one thread
// (and only one thread) to dequeue.

#ifndef BUTIL_MPSC_QUEUE_H
#define BUTIL_MPSC_QUEUE_H

#include "butil/object_pool.h"
#include "butil/type_traits.h"

namespace butil {

template <typename T>
struct BAIDU_CACHELINE_ALIGNMENT MPSCQueueNode {
static MPSCQueueNode* const UNCONNECTED;

MPSCQueueNode* next{NULL};
char data_mem[sizeof(T)]{};

};

template <typename T>
MPSCQueueNode<T>* const MPSCQueueNode<T>::UNCONNECTED = (MPSCQueueNode<T>*)(intptr_t)-1;

// Default allocator for MPSCQueueNode.
template <typename T>
class DefaultAllocator {
public:
void* Alloc() { return malloc(sizeof(MPSCQueueNode<T>)); }
void Free(void* p) { free(p); }
};

// Allocator using ObjectPool for MPSCQueueNode.
template <typename T>
class ObjectPoolAllocator {
public:
void* Alloc() { return get_object<MPSCQueueNode<T>>(); }
void Free(void* p) { return_object(p); }
};


template <typename T, typename Alloc = DefaultAllocator<T>>
class MPSCQueue {
public:
MPSCQueue()
: _head(NULL)
, _cur_enqueue_node(NULL)
, _cur_dequeue_node(NULL) {}

~MPSCQueue();

// Enqueue data to the queue.
void Enqueue(typename add_const_reference<T>::type data);
void Enqueue(T&& data);

// Dequeue data from the queue.
bool Dequeue(T& data);

private:
// Reverse the list until old_head.
void ReverseList(MPSCQueueNode<T>* old_head);

void EnqueueImpl(MPSCQueueNode<T>* node);
bool DequeueImpl(T* data);

Alloc _alloc;
atomic<MPSCQueueNode<T>*> _head;
atomic<MPSCQueueNode<T>*> _cur_enqueue_node;
MPSCQueueNode<T>* _cur_dequeue_node;
};

template <typename T, typename Alloc>
MPSCQueue<T, Alloc>::~MPSCQueue() {
while (DequeueImpl(NULL));
}

template <typename T, typename Alloc>
void MPSCQueue<T, Alloc>::Enqueue(typename add_const_reference<T>::type data) {
auto node = (MPSCQueueNode<T>*)_alloc.Alloc();
node->next = MPSCQueueNode<T>::UNCONNECTED;
new ((void*)&node->data_mem) T(data);
EnqueueImpl(node);
}

template <typename T, typename Alloc>
void MPSCQueue<T, Alloc>::Enqueue(T&& data) {
auto node = (MPSCQueueNode<T>*)_alloc.Alloc();
node->next = MPSCQueueNode<T>::UNCONNECTED;
new ((void*)&node->data_mem) T(std::forward<T>(data));
EnqueueImpl(node);
}

template <typename T, typename Alloc>
void MPSCQueue<T, Alloc>::EnqueueImpl(MPSCQueueNode<T>* node) {
MPSCQueueNode<T>* prev = _head.exchange(node, memory_order_release);
if (prev) {
node->next = prev;
return;
}
node->next = NULL;
_cur_enqueue_node.store(node, memory_order_relaxed);
}

template <typename T, typename Alloc>
bool MPSCQueue<T, Alloc>::Dequeue(T& data) {
return DequeueImpl(&data);
}

template <typename T, typename Alloc>
bool MPSCQueue<T, Alloc>::DequeueImpl(T* data) {
MPSCQueueNode<T>* node;
if (_cur_dequeue_node) {
node = _cur_dequeue_node;
} else {
node = _cur_enqueue_node.load(memory_order_relaxed);
}
if (!node) {
return false;
}

_cur_enqueue_node.store(NULL, memory_order_relaxed);
if (data) {
auto mem = (T* const)node->data_mem;
*data = std::move(*mem);
}
MPSCQueueNode<T>* old_node = node;
if (!node->next) {
ReverseList(node);
}
_cur_dequeue_node = node->next;
return_object(old_node);

return true;
}

template <typename T, typename Alloc>
void MPSCQueue<T, Alloc>::ReverseList(MPSCQueueNode<T>* old_head) {
// Try to set _write_head to NULL to mark that it is done.
MPSCQueueNode<T>* new_head = old_head;
MPSCQueueNode<T>* desired = NULL;
if (_head.compare_exchange_strong(
new_head, desired, memory_order_acquire)) {
// No one added new requests.
return;
}
CHECK_NE(new_head, old_head);
// Above acquire fence pairs release fence of exchange in Enqueue() to make
// sure that we see all fields of requests set.

// Someone added new requests.
// Reverse the list until old_head.
MPSCQueueNode<T>* tail = NULL;
MPSCQueueNode<T>* p = new_head;
do {
while (p->next == MPSCQueueNode<T>::UNCONNECTED) {
// TODO(gejun): elaborate this
sched_yield();
}
MPSCQueueNode<T>* const saved_next = p->next;
p->next = tail;
tail = p;
p = saved_next;
CHECK(p);
} while (p != old_head);

// Link old list with new list.
old_head->next = tail;
}

}

#endif // BUTIL_MPSC_QUEUE_H
6 changes: 3 additions & 3 deletions src/butil/thread_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.

#ifndef BRPC_THREAD_KEY_H
#define BRPC_THREAD_KEY_H
#ifndef BUTIL_THREAD_KEY_H
#define BUTIL_THREAD_KEY_H

#include <limits>
#include <pthread.h>
Expand Down Expand Up @@ -199,4 +199,4 @@ void ThreadLocal<T>::reset(T* ptr) {
}


#endif //BRPC_THREAD_KEY_H
#endif // BUTIL_THREAD_KEY_H
1 change: 1 addition & 0 deletions test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ TEST_BUTIL_SOURCES = [
"mru_cache_unittest.cc",
"small_map_unittest.cc",
"stack_container_unittest.cc",
"mpsc_queue_unittest.cc",
"cpu_unittest.cc",
"crash_logging_unittest.cc",
"leak_tracker_unittest.cc",
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ SET(TEST_BUTIL_SOURCES
${PROJECT_SOURCE_DIR}/test/mru_cache_unittest.cc
${PROJECT_SOURCE_DIR}/test/small_map_unittest.cc
${PROJECT_SOURCE_DIR}/test/stack_container_unittest.cc
${PROJECT_SOURCE_DIR}/test/mpsc_queue_unittest.cc
${PROJECT_SOURCE_DIR}/test/cpu_unittest.cc
${PROJECT_SOURCE_DIR}/test/crash_logging_unittest.cc
${PROJECT_SOURCE_DIR}/test/leak_tracker_unittest.cc
Expand Down
1 change: 1 addition & 0 deletions test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ TEST_BUTIL_SOURCES = \
mru_cache_unittest.cc \
small_map_unittest.cc \
stack_container_unittest.cc \
mpsc_queue_unittest.cc \
cpu_unittest.cc \
crash_logging_unittest.cc \
leak_tracker_unittest.cc \
Expand Down
124 changes: 124 additions & 0 deletions test/mpsc_queue_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#include <gtest/gtest.h>
#include <pthread.h>
#include "butil/containers/mpsc_queue.h"

namespace {

const uint MAX_COUNT = 100000000;

void Consume(butil::MPSCQueue<uint>& q, bool allow_empty) {
uint i = 0;
uint empty_count = 0;
while (true) {
uint d;
if (!q.Dequeue(d)) {
ASSERT_TRUE(allow_empty);
ASSERT_LT(empty_count++, (const uint)10000);
::usleep(10 * 1000);
continue;
}
ASSERT_EQ(i++, d);
if (i == MAX_COUNT) {
break;
}
}
}

void* ProduceThread(void* arg) {
auto q = (butil::MPSCQueue<uint>*)arg;
for (uint i = 0; i < MAX_COUNT; ++i) {
q->Enqueue(i);
}
return NULL;
}

void* ConsumeThread1(void* arg) {
auto q = (butil::MPSCQueue<uint>*)arg;
Consume(*q, true);
return NULL;
}

TEST(MPSCQueueTest, spsc_single_thread) {
butil::MPSCQueue<uint> q;
for (uint i = 0; i < MAX_COUNT; ++i) {
q.Enqueue(i);
}
Consume(q, false);
}

TEST(MPSCQueueTest, spsc_multi_thread) {
butil::MPSCQueue<uint> q;
pthread_t produce_tid;
ASSERT_EQ(0, pthread_create(&produce_tid, NULL, ProduceThread, &q));
pthread_t consume_tid;
ASSERT_EQ(0, pthread_create(&consume_tid, NULL, ConsumeThread1, &q));

pthread_join(produce_tid, NULL);
pthread_join(consume_tid, NULL);

}

butil::atomic<uint> g_index(0);
void* MultiProduceThread(void* arg) {
auto q = (butil::MPSCQueue<uint>*)arg;
while (true) {
uint i = g_index.fetch_add(1, butil::memory_order_relaxed);
if (i >= MAX_COUNT) {
break;
}
q->Enqueue(i);
}
return NULL;
}

butil::Mutex g_mutex;
bool g_counts[MAX_COUNT];
void Consume2(butil::MPSCQueue<uint>& q) {
uint empty_count = 0;
uint count = 0;
while (true) {
uint d;
if (!q.Dequeue(d)) {
ASSERT_LT(empty_count++, (const uint)10000);
::usleep(1 * 1000);
continue;
}
ASSERT_LT(d, MAX_COUNT);
{
BAIDU_SCOPED_LOCK(g_mutex);
ASSERT_FALSE(g_counts[d]);
g_counts[d] = true;
}
if (++count >= MAX_COUNT) {
break;
}
}
}

void* ConsumeThread2(void* arg) {
auto q = (butil::MPSCQueue<uint>*)arg;
Consume2(*q);
return NULL;
}

TEST(MPSCQueueTest, mpsc_multi_thread) {
butil::MPSCQueue<uint> q;

int thread_num = 8;
pthread_t threads[thread_num];
for (int i = 0; i < thread_num; ++i) {
ASSERT_EQ(0, pthread_create(&threads[i], NULL, MultiProduceThread, &q));
}

pthread_t consume_tid;
ASSERT_EQ(0, pthread_create(&consume_tid, NULL, ConsumeThread2, &q));

for (int i = 0; i < thread_num; ++i) {
pthread_join(threads[i], NULL);
}
pthread_join(consume_tid, NULL);

}


}

0 comments on commit dd1d2cb

Please sign in to comment.