Skip to content

Commit

Permalink
add uthread stack protect.
Browse files Browse the repository at this point in the history
  • Loading branch information
lynncui00 committed Sep 5, 2016
1 parent ca193c9 commit 8e6e694
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 92 deletions.
1 change: 1 addition & 0 deletions codegen/server_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ PackageName = $PackageName$
MaxConnections = 800000
MaxQueueLength = 20480
FastRejectThresholdMS = 20
FastRejectAdjustRate = 5
[ServerTimeout]
SocketTimeoutMS = 5000
Expand Down
2 changes: 1 addition & 1 deletion phxrpc/network/test_uthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void execute(UThreadRuntime & runtime, size_t count) {
}

void run(size_t count) {
UThreadRuntime runtime(64 * 1024);
UThreadRuntime runtime(64 * 1024, false);

execute(runtime, count);

Expand Down
4 changes: 2 additions & 2 deletions phxrpc/network/test_uthread_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ using namespace phxrpc;
void f1(void *);
void f2(void *);

UThreadContextSystem c1(64 * 1024, &f1, nullptr, nullptr);
UThreadContextSystem c2(64 * 1024, &f2, nullptr, nullptr);
UThreadContextSystem c1(64 * 1024, &f1, nullptr, nullptr, true);
UThreadContextSystem c2(64 * 1024, &f2, nullptr, nullptr, true);

int test_count = 0;

Expand Down
5 changes: 3 additions & 2 deletions phxrpc/network/uthread_context_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ namespace phxrpc {
ContextCreateFunc_t UThreadContext::context_create_func_ = nullptr;

UThreadContext * UThreadContext :: Create(size_t stack_size,
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback) {
UThreadFunc_t func, void * args,
UThreadDoneCallback_t callback, const bool need_stack_protect) {
if (context_create_func_ != nullptr) {
return context_create_func_(stack_size, func, args, callback);
return context_create_func_(stack_size, func, args, callback, need_stack_protect);
}
return nullptr;
}
Expand Down
5 changes: 3 additions & 2 deletions phxrpc/network/uthread_context_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ class UThreadContext;
typedef std::function< void(void *) > UThreadFunc_t;
typedef std::function< void() > UThreadDoneCallback_t;
typedef std::function< UThreadContext*
(size_t, UThreadFunc_t, void *, UThreadDoneCallback_t) > ContextCreateFunc_t;
(size_t, UThreadFunc_t, void *, UThreadDoneCallback_t, const bool) > ContextCreateFunc_t;

class UThreadContext {
public:
UThreadContext() { }
virtual ~UThreadContext() { }

static UThreadContext * Create(size_t stack_size,
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback);
UThreadFunc_t func, void * args,
UThreadDoneCallback_t callback, const bool need_stack_protect);
static void SetContextCreateFunc(ContextCreateFunc_t context_create_func);
static ContextCreateFunc_t GetContextCreateFunc();

Expand Down
20 changes: 8 additions & 12 deletions phxrpc/network/uthread_context_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,27 @@ See the AUTHORS file for names of contributors.

namespace phxrpc {

UThreadContextSystem :: UThreadContextSystem(size_t stack_size, UThreadFunc_t func, void * args, UThreadDoneCallback_t callback)
: func_(func), args_(args), stack_(nullptr), stack_size_(stack_size),
protect_page_(0), callback_(callback) {

stack_ = (char *)calloc(1, stack_size_);
assert(stack_ != nullptr);

UThreadContextSystem :: UThreadContextSystem(size_t stack_size, UThreadFunc_t func, void * args,
UThreadDoneCallback_t callback, const bool need_stack_protect)
: func_(func), args_(args), stack_(stack_size, need_stack_protect), callback_(callback) {
Make(func, args);
}

UThreadContextSystem :: ~UThreadContextSystem() {
free(stack_);
}

UThreadContext * UThreadContextSystem :: DoCreate(size_t stack_size,
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback) {
return new UThreadContextSystem(stack_size, func, args, callback);
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback,
const bool need_stack_protect) {
return new UThreadContextSystem(stack_size, func, args, callback, need_stack_protect);
}

void UThreadContextSystem :: Make(UThreadFunc_t func, void * args) {
func_ = func;
args_ = args;
getcontext(&context_);
context_.uc_stack.ss_sp = stack_;
context_.uc_stack.ss_size = stack_size_;
context_.uc_stack.ss_sp = stack_.top();
context_.uc_stack.ss_size = stack_.size();
context_.uc_stack.ss_flags = 0;
context_.uc_link = GetMainContext();
uintptr_t ptr = (uintptr_t)this;
Expand Down
11 changes: 6 additions & 5 deletions phxrpc/network/uthread_context_system.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@ See the AUTHORS file for names of contributors.
#include <assert.h>

#include "uthread_context_base.h"
#include "uthread_context_util.h"

namespace phxrpc {

class UThreadContextSystem : public UThreadContext {
public:
UThreadContextSystem(size_t stack_size, UThreadFunc_t func, void * args, UThreadDoneCallback_t callback);
UThreadContextSystem(size_t stack_size, UThreadFunc_t func, void * args,
UThreadDoneCallback_t callback, const bool need_stack_protect);
~UThreadContextSystem();

static UThreadContext * DoCreate(size_t stack_size,
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback);
UThreadFunc_t func, void * args, UThreadDoneCallback_t callback,
const bool need_stack_protect);

void Make(UThreadFunc_t func, void * args) override;
bool Resume() override;
Expand All @@ -50,9 +53,7 @@ class UThreadContextSystem : public UThreadContext {
ucontext_t context_;
UThreadFunc_t func_;
void * args_;
char * stack_;
size_t stack_size_;
int protect_page_;
UThreadStackMemory stack_;
UThreadDoneCallback_t callback_;
};

Expand Down
44 changes: 32 additions & 12 deletions phxrpc/network/uthread_context_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,47 @@ See the AUTHORS file for names of contributors.
#include "uthread_context_util.h"
#include <assert.h>
#include <unistd.h>
#include <sys/mman.h>

namespace phxrpc {

int UThreadProtectStack(void * stack_top, size_t stack_size) {
int page = STACK_PROTECT_PAGE;
UThreadStackMemory :: UThreadStackMemory(const size_t stack_size, const bool need_protect) :
raw_stack_(nullptr), stack_(nullptr), need_protect_(need_protect) {
int page_size = getpagesize();
assert(stack_size >= (size_t)page_size * (page + 1));
void * protect_addr = stack_top;
if ((size_t)protect_addr & (page_size - 1)) {
protect_addr = (void *)(((size_t)stack_top & (~(page_size - 1))) + page_size);
if ((stack_size % page_size) != 0) {
stack_size_ = (stack_size / page_size + 1) * page_size;
} else {
stack_size_ = stack_size;
}

if (need_protect) {
raw_stack_ = mmap(NULL, stack_size_ + page_size * 2,
PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
assert(raw_stack_ != nullptr);
assert(mprotect(raw_stack_, page_size, PROT_NONE) == 0);
assert(mprotect((void *)((char *)raw_stack_ + stack_size_ + page_size), page_size, PROT_NONE) == 0);
stack_ = (void *)((char *)raw_stack_ + page_size);
} else {
raw_stack_ = mmap(NULL, stack_size_, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
assert(raw_stack_ != nullptr);
stack_ = raw_stack_;

This comment has been minimized.

Copy link
@berockguo

berockguo Sep 8, 2016

十分漂亮的实现,用mprotect保护协程上下的两块内存,保证栈内容不会被修改。

}
return mprotect(protect_addr, page_size * page, PROT_NONE);
}

int UThreadUnProtectStack(void * stack_top, int page) {
void * protect_addr = stack_top;
UThreadStackMemory :: ~UThreadStackMemory() {
int page_size = getpagesize();
if ((size_t)protect_addr & (page_size - 1)) {
protect_addr = (void *)(((size_t)stack_top & (~(page_size - 1))) + page_size);
if (need_protect_) {
assert(mprotect(raw_stack_, page_size, PROT_READ | PROT_WRITE) == 0);
assert(mprotect((void *)((char *)raw_stack_ + stack_size_ + page_size), page_size, PROT_READ | PROT_WRITE) == 0);
}
return mprotect(protect_addr, page_size * page, PROT_READ | PROT_WRITE);
}

void * UThreadStackMemory :: top() {
return stack_;
}

size_t UThreadStackMemory :: size() {
return stack_size_;
}

} //namespace phxrpc
18 changes: 14 additions & 4 deletions phxrpc/network/uthread_context_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,19 @@ See the AUTHORS file for names of contributors.

namespace phxrpc {

#define STACK_PROTECT_PAGE 1

int UThreadProtectStack(void * stack_top, size_t stack_size);
int UThreadUnProtectStack(void * stack_top, int page);
class UThreadStackMemory {
public:
UThreadStackMemory(const size_t stack_size, const bool need_protect = true);
~UThreadStackMemory();

void * top();
size_t size();

private:
void * raw_stack_;
void * stack_;
size_t stack_size_;
int need_protect_;
};

} //namespace phxrpc
33 changes: 15 additions & 18 deletions phxrpc/network/uthread_epoll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ enum UThreadEpollREventStatus {
UThreadEpollREvent_Close = -2,
};

UThreadEpollScheduler::UThreadEpollScheduler(size_t stack_size, int max_task) :
epoll_wake_up_(this) {
runtime_ = new UThreadRuntime(stack_size);
UThreadEpollScheduler::UThreadEpollScheduler(size_t stack_size, int max_task, const bool need_stack_protect) :
runtime_(stack_size, need_stack_protect), epoll_wake_up_(this) {
max_task_ = max_task;

epoll_fd_ = epoll_create(max_task_);
Expand All @@ -125,8 +124,6 @@ UThreadEpollScheduler::UThreadEpollScheduler(size_t stack_size, int max_task) :
}

UThreadEpollScheduler::~UThreadEpollScheduler() {
delete runtime_;

close(epoll_fd_);
}

Expand All @@ -148,11 +145,11 @@ void UThreadEpollScheduler :: SetHandlerAcceptedFdFunc(UThreadHanderAcceptedFdFu
}

bool UThreadEpollScheduler::YieldTask() {
return runtime_->Yield();
return runtime_.Yield();
}

int UThreadEpollScheduler::GetCurrUThread() {
return runtime_->GetCurrUThread();
return runtime_.GetCurrUThread();
}

UThreadSocket_t * UThreadEpollScheduler::CreateSocket(int fd, int socket_timeout_ms,
Expand Down Expand Up @@ -181,8 +178,8 @@ UThreadSocket_t * UThreadEpollScheduler::CreateSocket(int fd, int socket_timeout
void UThreadEpollScheduler::ConsumeTodoList() {
while (!todo_list_.empty()) {
auto & it = todo_list_.front();
int id = runtime_->Create(it.first, it.second);
runtime_->Resume(id);
int id = runtime_.Create(it.first, it.second);
runtime_.Resume(id);

todo_list_.pop();
}
Expand All @@ -203,7 +200,7 @@ void UThreadEpollScheduler :: ResumeAll(int flag) {
std::vector<UThreadSocket_t*> exist_socket_list = timer_.GetSocketList();
for (auto & socket : exist_socket_list) {
socket->waited_events = flag;
runtime_->Resume(socket->uthread_id);
runtime_.Resume(socket->uthread_id);
}
}

Expand Down Expand Up @@ -231,21 +228,21 @@ bool UThreadEpollScheduler::Run() {

int next_timeout = timer_.GetNextTimeout();

for (; (run_forever_) || (!runtime_->IsAllDone());) {
for (; (run_forever_) || (!runtime_.IsAllDone());) {
int nfds = epoll_wait(epoll_fd_, events, max_task_, 4);
if (nfds != -1) {
for (int i = 0; i < nfds; i++) {
UThreadSocket_t * socket = (UThreadSocket_t*) events[i].data.ptr;
socket->waited_events = events[i].events;

runtime_->Resume(socket->uthread_id);
runtime_.Resume(socket->uthread_id);
}

//for server mode
if (active_socket_func_ != nullptr) {
UThreadSocket_t * socket = nullptr;
while ((socket = active_socket_func_()) != nullptr) {
runtime_->Resume(socket->uthread_id);
runtime_.Resume(socket->uthread_id);
}
}

Expand Down Expand Up @@ -298,7 +295,7 @@ void UThreadEpollScheduler::DealwithTimeout(int & next_timeout) {

UThreadSocket_t * socket = timer_.PopTimeout();
socket->waited_events = UThreadEpollREvent_Timeout;
runtime_->Resume(socket->uthread_id);
runtime_.Resume(socket->uthread_id);
}
}

Expand Down Expand Up @@ -503,22 +500,22 @@ void UThreadSetArgs(UThreadSocket_t & socket, void * args) {
socket.args = args;
}

void * UthreadGetArgs(UThreadSocket_t & socket) {
void * UThreadGetArgs(UThreadSocket_t & socket) {
return socket.args;
}

void UthreadWait(UThreadSocket_t & socket, int timeout_ms) {
void UThreadWait(UThreadSocket_t & socket, int timeout_ms) {
socket.uthread_id = socket.scheduler->GetCurrUThread();
socket.scheduler->AddTimer(&socket, timeout_ms);
socket.scheduler->YieldTask();
socket.scheduler->RemoveTimer(socket.timer_id);
}

void UthreadLazyDestory(UThreadSocket_t & socket) {
void UThreadLazyDestory(UThreadSocket_t & socket) {
socket.uthread_id = -1;
}

bool IsUthreadDestory(UThreadSocket_t & socket) {
bool IsUThreadDestory(UThreadSocket_t & socket) {
return socket.uthread_id == -1;
}

Expand Down
13 changes: 7 additions & 6 deletions phxrpc/network/uthread_epoll.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class EpollNotifier {

class UThreadEpollScheduler {
public:
UThreadEpollScheduler(size_t stack_size, int max_task);
UThreadEpollScheduler(size_t stack_size, int max_task, const bool need_stack_protect = true);
~UThreadEpollScheduler();

static UThreadEpollScheduler * Instance();
Expand Down Expand Up @@ -95,7 +95,7 @@ class UThreadEpollScheduler {
void StatEpollwaitEvents(const int event_count);

private:
UThreadRuntime * runtime_;
UThreadRuntime runtime_;
int max_task_;
TaskQueue todo_list_;
int epoll_fd_;
Expand Down Expand Up @@ -126,6 +126,7 @@ class __uthread {
};

#define uthread_begin phxrpc::UThreadEpollScheduler _uthread_scheduler(64 * 1024, 300);
#define uthread_begin_withargs(stack_size, max_task) phxrpc::UThreadEpollScheduler _uthread_scheduler(stack_size, max_task);
#define uthread_s _uthread_scheduler
#define uthread_t phxrpc::__uthread(_uthread_scheduler)-
#define uthread_end _uthread_scheduler.Run();
Expand Down Expand Up @@ -162,13 +163,13 @@ UThreadSocket_t * NewUThreadSocket();

void UThreadSetArgs(UThreadSocket_t & socket, void * args);

void * UthreadGetArgs(UThreadSocket_t & socket);
void * UThreadGetArgs(UThreadSocket_t & socket);

void UthreadWait(UThreadSocket_t & socket, int timeout_ms);
void UThreadWait(UThreadSocket_t & socket, int timeout_ms);

void UthreadLazyDestory(UThreadSocket_t & socket);
void UThreadLazyDestory(UThreadSocket_t & socket);

bool IsUthreadDestory(UThreadSocket_t & socket);
bool IsUThreadDestory(UThreadSocket_t & socket);

};

8 changes: 5 additions & 3 deletions phxrpc/network/uthread_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ enum {

namespace phxrpc {

UThreadRuntime :: UThreadRuntime(size_t stack_size)
UThreadRuntime :: UThreadRuntime(size_t stack_size, const bool need_stack_protect)
:stack_size_(stack_size), first_done_item_(-1),
current_uthread_(-1), unfinished_item_count_(0) {
current_uthread_(-1), unfinished_item_count_(0),
need_stack_protect_(need_stack_protect) {
if (UThreadContext::GetContextCreateFunc() == nullptr) {
UThreadContext::SetContextCreateFunc(UThreadContextSystem::DoCreate);
}
Expand All @@ -59,7 +60,8 @@ int UThreadRuntime :: Create(UThreadFunc_t func, void * args) {
} else {
index = context_list_.size();
auto new_context = UThreadContext::Create(stack_size_, func, args,
std::bind(&UThreadRuntime::UThreadDoneCallback, this));
std::bind(&UThreadRuntime::UThreadDoneCallback, this),
need_stack_protect_);
assert(new_context != nullptr);
ContextSlot context_slot;
context_slot.context = new_context;
Expand Down
Loading

0 comments on commit 8e6e694

Please sign in to comment.