Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user interceptor of server #2137

Merged
merged 4 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,22 +309,6 @@ void EndRunningCallMethodInPool(
return EndRunningUserCodeInPool(CallMethodInBackupThread, args);
};

// Returns true if accept request, reject request otherwise.
bool AcceptRequest(const Server* server, Controller* cntl) {
const Interceptor* interceptor = server->options().interceptor;
int error_code = 0;
std::string error_text;
if (interceptor &&
!interceptor->Accept(cntl, error_code, error_text)) {
cntl->SetFailed(error_code,
"Reject by Interceptor: %s",
error_text.c_str());
return false;
}

return true;
}

void ProcessRpcRequest(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();
DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base));
Expand Down Expand Up @@ -479,7 +463,8 @@ void ProcessRpcRequest(InputMessageBase* msg_base) {
const google::protobuf::MethodDescriptor* method = mp->method;
accessor.set_method(method);

if (!AcceptRequest(server, cntl.get())) {

if (!server->AcceptRequest(cntl.get())) {
break;
}

Expand Down
5 changes: 1 addition & 4 deletions src/brpc/policy/http_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,6 @@ void EndRunningCallMethodInPool(
::google::protobuf::Message* response,
::google::protobuf::Closure* done);

// Defined in baidu_rpc_protocol.cpp
bool AcceptRequest(const Server* server, Controller* cntl);

void ProcessHttpRequest(InputMessageBase *msg) {
const int64_t start_parse_us = butil::cpuwide_time_us();
DestroyingPtr<HttpContext> imsg_guard(static_cast<HttpContext*>(msg));
Expand Down Expand Up @@ -1433,7 +1430,7 @@ void ProcessHttpRequest(InputMessageBase *msg) {
" -usercode_in_pthread is on");
return;
}
if (!AcceptRequest(server, cntl)) {
if (!server->AcceptRequest(cntl)) {
return;
}
} else if (security_mode) {
Expand Down
5 changes: 1 addition & 4 deletions src/brpc/policy/hulu_pbrpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,6 @@ void EndRunningCallMethodInPool(
::google::protobuf::Message* response,
::google::protobuf::Closure* done);

// Defined in baidu_rpc_protocol.cpp
bool AcceptRequest(const Server* server, Controller* cntl);

void ProcessHuluRequest(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();
DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base));
Expand Down Expand Up @@ -473,7 +470,7 @@ void ProcessHuluRequest(InputMessageBase* msg_base) {
const google::protobuf::MethodDescriptor* method = sp->method;
accessor.set_method(method);

if (!AcceptRequest(server, cntl.get())) {
if (!server->AcceptRequest(cntl.get())) {
break;
}

Expand Down
5 changes: 1 addition & 4 deletions src/brpc/policy/nshead_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ static void EndRunningCallMethodInPool(NsheadService* service,
return EndRunningUserCodeInPool(CallMethodInBackupThread, args);
};

// Defined in baidu_rpc_protocol.cpp
bool AcceptRequest(const Server* server, Controller* cntl);

void ProcessNsheadRequest(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();

Expand Down Expand Up @@ -318,7 +315,7 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) {
" -usercode_in_pthread is on");
break;
}
if (!AcceptRequest(server, cntl)) {
if (!server->AcceptRequest(cntl)) {
break;
}
} while (false);
Expand Down
5 changes: 1 addition & 4 deletions src/brpc/policy/sofa_pbrpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,6 @@ void EndRunningCallMethodInPool(
::google::protobuf::Message* response,
::google::protobuf::Closure* done);

// Defined in baidu_rpc_protocol.cpp
bool AcceptRequest(const Server* server, Controller* cntl);

void ProcessSofaRequest(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();
DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base));
Expand Down Expand Up @@ -424,7 +421,7 @@ void ProcessSofaRequest(InputMessageBase* msg_base) {
const google::protobuf::MethodDescriptor* method = sp->method;
accessor.set_method(method);

if (!AcceptRequest(server, cntl.get())) {
if (!server->AcceptRequest(cntl.get())) {
break;
}

Expand Down
5 changes: 1 addition & 4 deletions src/brpc/policy/thrift_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,6 @@ static void EndRunningCallMethodInPool(ThriftService* service,
return EndRunningUserCodeInPool(CallMethodInBackupThread, args);
};

// Defined in baidu_rpc_protocol.cpp
bool AcceptRequest(const Server* server, Controller* cntl);

void ProcessThriftRequest(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();

Expand Down Expand Up @@ -541,7 +538,7 @@ void ProcessThriftRequest(InputMessageBase* msg_base) {
" -usercode_in_pthread is on");
}

if (!AcceptRequest(server, cntl)) {
if (!server->AcceptRequest(cntl)) {
return;
}

Expand Down
17 changes: 16 additions & 1 deletion src/brpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ Server::~Server() {
delete _options.auth;
_options.auth = NULL;
}
if (_options.interceptor) {
if (_options.server_owns_interceptor) {
delete _options.interceptor;
_options.interceptor = NULL;
}
Expand Down Expand Up @@ -2197,6 +2197,21 @@ int Server::MaxConcurrencyOf(google::protobuf::Service* service,
return MaxConcurrencyOf(service->GetDescriptor()->full_name(), method_name);
}

bool Server::AcceptRequest(Controller* cntl) const {
const Interceptor* interceptor = _options.interceptor;
chenBright marked this conversation as resolved.
Show resolved Hide resolved
int error_code = 0;
std::string error_text;
if (cntl && interceptor &&
!interceptor->Accept(cntl, error_code, error_text)) {
cntl->SetFailed(error_code,
chenBright marked this conversation as resolved.
Show resolved Hide resolved
"Reject by Interceptor: %s",
error_text.c_str());
return false;
}

return true;
}

#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server) {
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ class Server {
return butil::subtle::NoBarrier_Load(&_concurrency);
};

// Returns true if accept request, reject request otherwise.
bool AcceptRequest(Controller* cntl) const;

private:
friend class StatusService;
friend class ProtobufsService;
Expand Down
54 changes: 49 additions & 5 deletions test/brpc_interceptor_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "brpc/channel.h"
#include "brpc/server.h"
#include "brpc/interceptor.h"
#include "brpc/nshead_service.h"
#include "echo.pb.h"

namespace brpc {
Expand All @@ -40,6 +41,7 @@ int g_index = 0;
const int port = 8613;
const std::string EXP_REQUEST = "hello";
const std::string EXP_RESPONSE = "world";
const std::string NSHEAD_EXP_RESPONSE = "error";

class EchoServiceImpl : public ::test::EchoService {
public:
Expand All @@ -50,11 +52,34 @@ class EchoServiceImpl : public ::test::EchoService {
::test::EchoResponse* response,
google::protobuf::Closure* done) override {
brpc::ClosureGuard done_guard(done);
EXPECT_EQ(EXP_REQUEST, request->message());
ASSERT_EQ(EXP_REQUEST, request->message());
response->set_message(EXP_RESPONSE);
}
};

// Adapt your own nshead-based protocol to use brpc
class MyNsheadProtocol : public brpc::NsheadService {
public:
void ProcessNsheadRequest(const brpc::Server&,
brpc::Controller* cntl,
const brpc::NsheadMessage& request,
brpc::NsheadMessage* response,
brpc::NsheadClosure* done) {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
brpc::ClosureGuard done_guard(done);

response->head = request.head;
if (cntl->Failed()) {
ASSERT_TRUE(cntl->Failed());
ASSERT_EQ(EREJECT, cntl->ErrorCode());
response->body.append(NSHEAD_EXP_RESPONSE);
return;
}
response->body.append(EXP_RESPONSE);
}
};

class MyInterceptor : public brpc::Interceptor {
public:
MyInterceptor() = default;
Expand All @@ -81,15 +106,16 @@ class InterceptorTest : public ::testing::Test {
brpc::SERVER_DOESNT_OWN_SERVICE));
brpc::ServerOptions options;
options.interceptor = new MyInterceptor;
options.nshead_service = new MyNsheadProtocol;
options.server_owns_interceptor = true;
EXPECT_EQ(0, _server.Start(port, &options));
}

~InterceptorTest() override = default;

void CallMethod(test::EchoService_Stub& stub,
::test::EchoRequest& req,
::test::EchoResponse& res) {
static void CallMethod(test::EchoService_Stub& stub,
::test::EchoRequest& req,
::test::EchoResponse& res) {
for (g_index = 0; g_index < 1000; ++g_index) {
brpc::Controller cntl;
stub.Echo(&cntl, &req, &res, NULL);
Expand Down Expand Up @@ -144,7 +170,6 @@ TEST_F(InterceptorTest, sanity) {
CallMethod(stub, req, res);
}

g_index = 0;
// PROTOCOL_SOFA_PBRPC
{
brpc::Channel channel;
Expand All @@ -154,4 +179,23 @@ TEST_F(InterceptorTest, sanity) {
test::EchoService_Stub stub(&channel);
CallMethod(stub, req, res);
}

// PROTOCOL_NSHEAD
{
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = brpc::PROTOCOL_NSHEAD;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
brpc::NsheadMessage request;
for (g_index = 0; g_index < 1000; ++g_index) {
brpc::Controller cntl;
brpc::NsheadMessage response;
channel.CallMethod(NULL, &cntl, &request, &response, NULL);
if (g_index % 2 == 0) {
ASSERT_EQ(NSHEAD_EXP_RESPONSE, response.body.to_string());
} else {
ASSERT_EQ(EXP_RESPONSE, response.body.to_string());
}
}
}
}