Skip to content

Commit

Permalink
Support user interceptor of server (apache#2137)
Browse files Browse the repository at this point in the history
* Support user interceptor of server

* Optimize interceptor unittest

* Optimize interceptor

* interceptor return quickly
  • Loading branch information
Bright Chen authored and Yang Liming committed Jun 25, 2023
1 parent de7f3f9 commit 544818c
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/brpc/interceptor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#ifndef BRPC_INTERCEPTOR_H
#define BRPC_INTERCEPTOR_H

#include "brpc/controller.h"


namespace brpc {

class Interceptor {
public:
virtual ~Interceptor() = default;

// Returns true if accept request, reject request otherwise.
// When server rejects request, You can fill `error_code'
// and `error_txt' which will send to client.
virtual bool Accept(const brpc::Controller* controller,
int& error_code,
std::string& error_txt) const = 0;

};

} // namespace brpc


#endif //BRPC_INTERCEPTOR_H
6 changes: 6 additions & 0 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,12 @@ void ProcessRpcRequest(InputMessageBase* msg_base) {
google::protobuf::Service* svc = mp->service;
const google::protobuf::MethodDescriptor* method = mp->method;
accessor.set_method(method);


if (!server->AcceptRequest(cntl.get())) {
break;
}

if (span) {
span->ResetServerSpanName(method->full_name());
}
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/policy/http_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,9 @@ void ProcessHttpRequest(InputMessageBase *msg) {
" -usercode_in_pthread is on");
return;
}
if (!server->AcceptRequest(cntl)) {
return;
}
} else if (security_mode) {
cntl->SetFailed(EPERM, "Not allowed to access builtin services, try "
"ServerOptions.internal_port=%d instead if you're in"
Expand Down
5 changes: 5 additions & 0 deletions src/brpc/policy/hulu_pbrpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ void ProcessHuluRequest(InputMessageBase* msg_base) {
google::protobuf::Service* svc = sp->service;
const google::protobuf::MethodDescriptor* method = sp->method;
accessor.set_method(method);

if (!server->AcceptRequest(cntl.get())) {
break;
}

if (span) {
span->ResetServerSpanName(method->full_name());
}
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/policy/nshead_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) {
" -usercode_in_pthread is on");
break;
}
if (!server->AcceptRequest(cntl)) {
break;
}
} while (false);

msg.reset(); // optional, just release resource ASAP
Expand Down
5 changes: 5 additions & 0 deletions src/brpc/policy/sofa_pbrpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ void ProcessSofaRequest(InputMessageBase* msg_base) {
google::protobuf::Service* svc = sp->service;
const google::protobuf::MethodDescriptor* method = sp->method;
accessor.set_method(method);

if (!server->AcceptRequest(cntl.get())) {
break;
}

if (span) {
span->ResetServerSpanName(method->full_name());
}
Expand Down
4 changes: 4 additions & 0 deletions src/brpc/policy/thrift_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,10 @@ void ProcessThriftRequest(InputMessageBase* msg_base) {
" -usercode_in_pthread is on");
}

if (!server->AcceptRequest(cntl)) {
return;
}

msg.reset(); // optional, just release resource ASAP

if (span) {
Expand Down
25 changes: 25 additions & 0 deletions src/brpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ ServerOptions::ServerOptions()
, mongo_service_adaptor(NULL)
, auth(NULL)
, server_owns_auth(false)
, interceptor(NULL)
, server_owns_interceptor(false)
, num_threads(8)
, max_concurrency(0)
, session_local_data_factory(NULL)
Expand Down Expand Up @@ -450,6 +452,10 @@ Server::~Server() {
delete _options.auth;
_options.auth = NULL;
}
if (_options.server_owns_interceptor) {
delete _options.interceptor;
_options.interceptor = NULL;
}

delete _options.redis_service;
_options.redis_service = NULL;
Expand Down Expand Up @@ -2174,6 +2180,25 @@ int Server::MaxConcurrencyOf(google::protobuf::Service* service,
return MaxConcurrencyOf(service->GetDescriptor()->full_name(), method_name);
}

bool Server::AcceptRequest(Controller* cntl) const {
const Interceptor* interceptor = _options.interceptor;
if (!interceptor) {
return true;
}

int error_code = 0;
std::string error_text;
if (cntl &&
!interceptor->Accept(cntl, error_code, error_text)) {
cntl->SetFailed(error_code,
"Reject by Interceptor: %s",
error_text.c_str());
return false;
}

return true;
}

#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
int Server::SSLSwitchCTXByHostname(struct ssl_st* ssl,
int* al, Server* server) {
Expand Down
13 changes: 13 additions & 0 deletions src/brpc/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "brpc/adaptive_max_concurrency.h"
#include "brpc/http2.h"
#include "brpc/redis.h"
#include "brpc/interceptor.h"

namespace brpc {

Expand Down Expand Up @@ -91,6 +92,15 @@ struct ServerOptions {
// Default: false
bool server_owns_auth;

// Turn on request interception if `interceptor' is not NULL.
// Default: NULL
const Interceptor* interceptor;

// false: `interceptor' is not owned by server and must be valid when server is running.
// true: `interceptor' is owned by server and will be deleted when server is destructed.
// Default: false
bool server_owns_interceptor;

// Number of pthreads that server runs on. Notice that this is just a hint,
// you can't assume that the server uses exactly so many pthreads because
// pthread workers are shared by all servers and channels inside a
Expand Down Expand Up @@ -551,6 +561,9 @@ class Server {
return butil::subtle::NoBarrier_Load(&_concurrency);
};

// Returns true if accept request, reject request otherwise.
bool AcceptRequest(Controller* cntl) const;

private:
friend class StatusService;
friend class ProtobufsService;
Expand Down
201 changes: 201 additions & 0 deletions test/brpc_interceptor_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>
#include <gflags/gflags.h>
#include "brpc/policy/sofa_pbrpc_protocol.h"
#include "brpc/channel.h"
#include "brpc/server.h"
#include "brpc/interceptor.h"
#include "brpc/nshead_service.h"
#include "echo.pb.h"

namespace brpc {
namespace policy {
DECLARE_bool(use_http_error_code);
}
}

int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true);
return RUN_ALL_TESTS();
}

const int EREJECT = 4000;
int g_index = 0;
const int port = 8613;
const std::string EXP_REQUEST = "hello";
const std::string EXP_RESPONSE = "world";
const std::string NSHEAD_EXP_RESPONSE = "error";

class EchoServiceImpl : public ::test::EchoService {
public:
EchoServiceImpl() = default;
~EchoServiceImpl() override = default;
void Echo(google::protobuf::RpcController* cntl_base,
const ::test::EchoRequest* request,
::test::EchoResponse* response,
google::protobuf::Closure* done) override {
brpc::ClosureGuard done_guard(done);
ASSERT_EQ(EXP_REQUEST, request->message());
response->set_message(EXP_RESPONSE);
}
};

// Adapt your own nshead-based protocol to use brpc
class MyNsheadProtocol : public brpc::NsheadService {
public:
void ProcessNsheadRequest(const brpc::Server&,
brpc::Controller* cntl,
const brpc::NsheadMessage& request,
brpc::NsheadMessage* response,
brpc::NsheadClosure* done) {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
brpc::ClosureGuard done_guard(done);

response->head = request.head;
if (cntl->Failed()) {
ASSERT_TRUE(cntl->Failed());
ASSERT_EQ(EREJECT, cntl->ErrorCode());
response->body.append(NSHEAD_EXP_RESPONSE);
return;
}
response->body.append(EXP_RESPONSE);
}
};

class MyInterceptor : public brpc::Interceptor {
public:
MyInterceptor() = default;

~MyInterceptor() override = default;

bool Accept(const brpc::Controller* controller,
int& error_code,
std::string& error_txt) const override {
if (g_index % 2 == 0) {
error_code = EREJECT;
error_txt = "reject g_index=0";
return false;
}

return true;
}
};

class InterceptorTest : public ::testing::Test {
public:
InterceptorTest() {
EXPECT_EQ(0, _server.AddService(&_echo_svc,
brpc::SERVER_DOESNT_OWN_SERVICE));
brpc::ServerOptions options;
options.interceptor = new MyInterceptor;
options.nshead_service = new MyNsheadProtocol;
options.server_owns_interceptor = true;
EXPECT_EQ(0, _server.Start(port, &options));
}

~InterceptorTest() override = default;

static void CallMethod(test::EchoService_Stub& stub,
::test::EchoRequest& req,
::test::EchoResponse& res) {
for (g_index = 0; g_index < 1000; ++g_index) {
brpc::Controller cntl;
stub.Echo(&cntl, &req, &res, NULL);
if (g_index % 2 == 0) {
ASSERT_TRUE(cntl.Failed());
ASSERT_EQ(EREJECT, cntl.ErrorCode());
} else {
ASSERT_FALSE(cntl.Failed());
EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
}
}
}

private:
brpc::Server _server;
EchoServiceImpl _echo_svc;
};

TEST_F(InterceptorTest, sanity) {
::test::EchoRequest req;
::test::EchoResponse res;
req.set_message(EXP_REQUEST);

// PROTOCOL_BAIDU_STD
{
brpc::Channel channel;
brpc::ChannelOptions options;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
test::EchoService_Stub stub(&channel);
CallMethod(stub, req, res);
}

// PROTOCOL_HTTP
{
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = brpc::PROTOCOL_HTTP;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
test::EchoService_Stub stub(&channel);
// Set the x-bd-error-code header of http response to brpc error code.
brpc::policy::FLAGS_use_http_error_code = true;
CallMethod(stub, req, res);
}

// PROTOCOL_HULU_PBRPC
{
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = brpc::PROTOCOL_HULU_PBRPC;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
test::EchoService_Stub stub(&channel);
CallMethod(stub, req, res);
}

// PROTOCOL_SOFA_PBRPC
{
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = brpc::PROTOCOL_SOFA_PBRPC;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
test::EchoService_Stub stub(&channel);
CallMethod(stub, req, res);
}

// PROTOCOL_NSHEAD
{
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = brpc::PROTOCOL_NSHEAD;
ASSERT_EQ(0, channel.Init("localhost", port, &options));
brpc::NsheadMessage request;
for (g_index = 0; g_index < 1000; ++g_index) {
brpc::Controller cntl;
brpc::NsheadMessage response;
channel.CallMethod(NULL, &cntl, &request, &response, NULL);
if (g_index % 2 == 0) {
ASSERT_EQ(NSHEAD_EXP_RESPONSE, response.body.to_string());
} else {
ASSERT_EQ(EXP_RESPONSE, response.body.to_string());
}
}
}
}

0 comments on commit 544818c

Please sign in to comment.