Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support thread local object iteration #2632

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 29 additions & 52 deletions src/butil/thread_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pthread_mutex_t g_thread_key_mutex = PTHREAD_MUTEX_INITIALIZER;
static size_t g_id = 0;
static std::deque<size_t>* g_free_ids = NULL;
static std::vector<ThreadKeyInfo>* g_thread_keys = NULL;
static __thread std::vector<ThreadKeyTLS>* g_tls_data = NULL;
static __thread std::vector<ThreadKeyTLS>* thread_key_tls_data = NULL;

ThreadKey& ThreadKey::operator=(ThreadKey&& other) noexcept {
if (this == &other) {
Expand All @@ -56,58 +56,42 @@ bool ThreadKey::Valid() const {
}

static void DestroyTlsData() {
if (!g_tls_data) {
if (!thread_key_tls_data) {
return;
}
std::vector<ThreadKeyInfo> dummy_keys;
{
BAIDU_SCOPED_LOCK(g_thread_key_mutex);
if (BAIDU_LIKELY(g_thread_keys)) {
dummy_keys.insert(dummy_keys.end(), g_thread_keys->begin(), g_thread_keys->end());
}
dummy_keys.insert(dummy_keys.end(),
g_thread_keys->begin(),
g_thread_keys->end());
}
for (size_t i = 0; i < g_tls_data->size(); ++i) {
for (size_t i = 0; i < thread_key_tls_data->size(); ++i) {
if (!KEY_UNUSED(dummy_keys[i].seq) && dummy_keys[i].dtor) {
dummy_keys[i].dtor((*g_tls_data)[i].data);
dummy_keys[i].dtor((*thread_key_tls_data)[i].data);
}
}
delete g_tls_data;
g_tls_data = NULL;
}

static std::deque<size_t>* GetGlobalFreeIds() {
if (BAIDU_UNLIKELY(!g_free_ids)) {
g_free_ids = new (std::nothrow) std::deque<size_t>();
if (BAIDU_UNLIKELY(!g_free_ids)) {
abort();
}
}

return g_free_ids;
delete thread_key_tls_data;
thread_key_tls_data = NULL;
}

int thread_key_create(ThreadKey& thread_key, DtorFunction dtor) {
BAIDU_SCOPED_LOCK(g_thread_key_mutex);
size_t id;
auto free_ids = GetGlobalFreeIds();
if (!free_ids) {
return ENOMEM;
if (BAIDU_UNLIKELY(!g_free_ids)) {
g_free_ids = new std::deque<size_t>;
}

if (!free_ids->empty()) {
id = free_ids->back();
free_ids->pop_back();
size_t id;
if (!g_free_ids->empty()) {
id = g_free_ids->back();
g_free_ids->pop_back();
} else {
if (g_id >= ThreadKey::InvalidID) {
// No more available ids.
return EAGAIN;
}
id = g_id++;
if(BAIDU_UNLIKELY(!g_thread_keys)) {
g_thread_keys = new (std::nothrow) std::vector<ThreadKeyInfo>;
if(BAIDU_UNLIKELY(!g_thread_keys)) {
return ENOMEM;
}
if (BAIDU_UNLIKELY(!g_thread_keys)) {
g_thread_keys = new std::vector<ThreadKeyInfo>;
g_thread_keys->reserve(THREAD_KEY_RESERVE);
}
g_thread_keys->resize(id + 1);
Expand Down Expand Up @@ -136,14 +120,10 @@ int thread_key_delete(ThreadKey& thread_key) {
return EINVAL;
}

if (BAIDU_UNLIKELY(!GetGlobalFreeIds())) {
return ENOMEM;
}

++((*g_thread_keys)[id].seq);
// Collect the usable key id for reuse.
if (KEY_USABLE((*g_thread_keys)[id].seq)) {
GetGlobalFreeIds()->push_back(id);
g_free_ids->push_back(id);
}
thread_key.Reset();

Expand All @@ -156,22 +136,19 @@ int thread_setspecific(ThreadKey& thread_key, void* data) {
}
size_t id = thread_key._id;
size_t seq = thread_key._seq;
if (BAIDU_UNLIKELY(!g_tls_data)) {
g_tls_data = new (std::nothrow) std::vector<ThreadKeyTLS>;
if (BAIDU_UNLIKELY(!g_tls_data)) {
return ENOMEM;
}
g_tls_data->reserve(THREAD_KEY_RESERVE);
if (BAIDU_UNLIKELY(!thread_key_tls_data)) {
thread_key_tls_data = new std::vector<ThreadKeyTLS>;
thread_key_tls_data->reserve(THREAD_KEY_RESERVE);
// Register the destructor of tls_data in this thread.
butil::thread_atexit(DestroyTlsData);
}

if (id >= g_tls_data->size()) {
g_tls_data->resize(id + 1);
if (id >= thread_key_tls_data->size()) {
thread_key_tls_data->resize(id + 1);
}

(*g_tls_data)[id].seq = seq;
(*g_tls_data)[id].data = data;
(*thread_key_tls_data)[id].seq = seq;
(*thread_key_tls_data)[id].data = data;

return 0;
}
Expand All @@ -182,13 +159,13 @@ void* thread_getspecific(ThreadKey& thread_key) {
}
size_t id = thread_key._id;
size_t seq = thread_key._seq;
if (BAIDU_UNLIKELY(!g_tls_data ||
id >= g_tls_data->size() ||
(*g_tls_data)[id].seq != seq)){
if (BAIDU_UNLIKELY(!thread_key_tls_data ||
id >= thread_key_tls_data->size() ||
(*thread_key_tls_data)[id].seq != seq)){
return NULL;
}

return (*g_tls_data)[id].data;
return (*thread_key_tls_data)[id].data;
}

} // namespace butil
28 changes: 23 additions & 5 deletions src/butil/thread_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <stdlib.h>
#include <vector>
#include "butil/scoped_lock.h"
#include "butil/type_traits.h"

namespace butil {

Expand All @@ -38,7 +39,7 @@ class ThreadKey {
static constexpr size_t InvalidID = std::numeric_limits<size_t>::max();
static constexpr size_t InitSeq = 0;

constexpr ThreadKey() :_id(InvalidID), _seq(InitSeq) {}
constexpr ThreadKey() : _id(InvalidID), _seq(InitSeq) {}

~ThreadKey() {
Reset();
Expand All @@ -62,7 +63,7 @@ class ThreadKey {
_seq = InitSeq;
}

private:
private:
size_t _id; // Key id.
// Sequence number form g_thread_keys set in thread_key_create.
size_t _seq;
Expand Down Expand Up @@ -111,6 +112,20 @@ class ThreadLocal {

T& operator*() const { return *get(); }

// Iterate through all thread local objects.
// Callback, which must accept Args params and return void,
// will be called under a thread lock.
template <typename Callback>
void for_each(Callback&& callback) {
BAIDU_CASSERT(
(is_result_void<Callback, T*>::value),
"Callback must accept Args params and return void");
BAIDU_SCOPED_LOCK(_mutex);
for (auto ptr : ptrs) {
callback(ptr);
}
}

void reset(T* ptr);

void reset() {
Expand Down Expand Up @@ -177,6 +192,9 @@ T* ThreadLocal<T>::get() {
template <typename T>
void ThreadLocal<T>::reset(T* ptr) {
T* old_ptr = get();
if (ptr == old_ptr) {
return;
}
if (thread_setspecific(_key, ptr) != 0) {
return;
}
Expand All @@ -187,9 +205,9 @@ void ThreadLocal<T>::reset(T* ptr) {
}
// Remove and delete old_ptr.
if (old_ptr) {
auto iter = std::find(ptrs.begin(), ptrs.end(), old_ptr);
if (iter!=ptrs.end()) {
ptrs.erase(iter);
auto iter = std::remove(ptrs.begin(), ptrs.end(), old_ptr);
if (iter != ptrs.end()) {
ptrs.erase(iter, ptrs.end());
chenBright marked this conversation as resolved.
Show resolved Hide resolved
}
DefaultDtor(old_ptr);
}
Expand Down
34 changes: 34 additions & 0 deletions src/butil/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,40 @@ template <typename T> struct is_enum<const T> : is_enum<T> { };
template <typename T> struct is_enum<volatile T> : is_enum<T> { };
template <typename T> struct is_enum<const volatile T> : is_enum<T> { };

// Deduces the return type of an INVOKE expression
// at compile time.
// If the callable is non-static member function,
// the first argument should be the class type.
#if (__cplusplus >= 201703L)
// std::result_of is deprecated in C++17 and removed in C++20,
// use std::invoke_result instead.
template <typename>
struct result_of;
template <typename F, typename... Args>
struct result_of<F(Args...)> : std::invoke_result<F, Args...> {};
#elif (__cplusplus >= 201103L)
template <typename F>
using result_of = std::result_of<F>;
#else
#error Only C++11 or later is supported.
#endif

template <typename F>
using result_of_t = typename result_of<F>::type;

// Whether a callable returns type which is same as ReturnType.
template<typename ReturnType, typename F, typename... Args>
struct is_result_same
: public butil::is_same<ReturnType, result_of_t<F(Args...)>> {};

// Whether a callable returns void.
template<typename F, typename... Args>
struct is_result_void : public is_result_same<void, F, Args...> {};

// Whether a callable returns int.
template<typename F, typename... Args>
struct is_result_int : public is_result_same<int, F, Args...> {};

} // namespace butil

#endif // BUTIL_TYPE_TRAITS_H
4 changes: 2 additions & 2 deletions test/endpoint_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,11 @@ TEST(EndPointTest, tcp_connect) {
ASSERT_EQ(0, butil::hostname2endpoint(g_hostname, 80, &ep));
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL));
ASSERT_LE(0, sockfd);
ASSERT_LE(0, sockfd) << "errno=" << errno;
}
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1000));
ASSERT_LE(0, sockfd);
ASSERT_LE(0, sockfd) << "errno=" << errno;
}
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1));
Expand Down
44 changes: 42 additions & 2 deletions test/thread_key_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ TEST(ThreadLocalTest, thread_key_seq) {
}
}

void* THreadKeyCreateAndDeleteFunc(void* arg) {
void* THreadKeyCreateAndDeleteFunc(void*) {
while (!g_stopped) {
ThreadKey key;
EXPECT_EQ(0, butil::thread_key_create(key, NULL));
Expand Down Expand Up @@ -162,7 +162,7 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
ASSERT_EQ(0, pthread_create(&threads[i], NULL, ThreadLocalFunc, &args));
}

sleep(5);
sleep(2);
g_stopped = true;
for (const auto& thread : threads) {
pthread_join(thread, NULL);
Expand All @@ -172,6 +172,46 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
}
}

butil::atomic<int> g_counter(0);

void* ThreadLocalForEachFunc(void* arg) {
auto counter = static_cast<ThreadLocal<butil::atomic<int>>*>(arg);
auto local_counter = counter->get();
EXPECT_NE(nullptr, local_counter);
while (!g_stopped) {
local_counter->fetch_add(1, butil::memory_order_relaxed);
g_counter.fetch_add(1, butil::memory_order_relaxed);
if (butil::fast_rand_less_than(100) + 1 > 80) {
local_counter = new butil::atomic<int>(
local_counter->load(butil::memory_order_relaxed));
counter->reset(local_counter);
}
}
return NULL;
}

TEST(ThreadLocalTest, thread_local_for_each) {
g_stopped = false;
ThreadLocal<butil::atomic<int>> counter(false);
const int thread_num = 8;
pthread_t threads[thread_num];
for (int i = 0; i < thread_num; ++i) {
ASSERT_EQ(0, pthread_create(
&threads[i], NULL, ThreadLocalForEachFunc, &counter));
}

sleep(2);
g_stopped = true;
for (const auto& thread : threads) {
pthread_join(thread, NULL);
}
int count = 0;
counter.for_each([&count](butil::atomic<int>* c) {
count += c->load(butil::memory_order_relaxed);
});
ASSERT_EQ(count, g_counter.load(butil::memory_order_relaxed));
}

struct BAIDU_CACHELINE_ALIGNMENT ThreadKeyArg {
std::vector<ThreadKey*> thread_keys;
bool ready_delete = false;
Expand Down
Loading